coder/coderd/provisionerdserver/acquirer_test.go

514 lines
15 KiB
Go

package provisionerdserver_test
import (
"context"
"database/sql"
"encoding/json"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/exp/slices"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/testutil"
)
// TestMain hands control of the package's test run to goleak.VerifyTestMain,
// so the tests in this package execute under goleak's goroutine-leak
// verification.
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
// TestAcquirer_Store checks that a database.Store can be passed to
// provisionerdserver.NewAcquirer where an AcquirerStore is expected. The
// returned Acquirer is discarded; constructing it is the whole test.
func TestAcquirer_Store(t *testing.T) {
	t.Parallel()

	store := dbmem.New()
	bus := pubsub.NewInMemory()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()

	log := slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("acquirer")
	_ = provisionerdserver.NewAcquirer(ctx, log, store, bus)
}
// TestAcquirer_Single tests a lone acquiree: a job queued in the fake store is
// returned to it, and the store sees exactly one call carrying the acquiree's
// worker ID.
func TestAcquirer_Single(t *testing.T) {
	t.Parallel()

	store := newFakeOrderedStore()
	bus := pubsub.NewInMemory()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	uut := provisionerdserver.NewAcquirer(ctx, log.Named("acquirer"), store, bus)

	worker := uuid.New()
	acquiree := newTestAcquiree(t, worker,
		[]database.ProvisionerType{database.ProvisionerTypeEcho},
		provisionerdserver.Tags{"foo": "bar"},
	)

	// Queue the store's response before the acquiree starts, so its first
	// store call returns the job.
	wantJob := uuid.New()
	require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{ID: wantJob}, nil))

	acquiree.startAcquire(ctx, uut)
	got := acquiree.success(ctx)

	require.Equal(t, wantJob, got.ID)
	require.Len(t, store.params, 1)
	require.Equal(t, worker, store.params[0].WorkerID.UUID)
}
// TestAcquirer_MultipleSameDomain runs several acquirees that share the same
// provisioner types and tags. It requires that every queued job is handed out,
// that every worker reached the store, and that the fake store never saw two
// calls in flight at the same time.
func TestAcquirer_MultipleSameDomain(t *testing.T) {
	t.Parallel()

	// Number of acquirees, and of jobs queued for them.
	const count = 10

	store := newFakeOrderedStore()
	bus := pubsub.NewInMemory()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	uut := provisionerdserver.NewAcquirer(ctx, log.Named("acquirer"), store, bus)

	var (
		acquirees   = make([]*testAcquiree, 0, count)
		wantJobs    = make(map[uuid.UUID]bool)
		wantWorkers = make(map[uuid.UUID]bool)
		types       = []database.ProvisionerType{database.ProvisionerTypeEcho}
		tags        = provisionerdserver.Tags{"foo": "bar"}
	)

	// Start every acquiree before any response is queued.
	for i := 0; i < count; i++ {
		worker := uuid.New()
		wantWorkers[worker] = true
		acq := newTestAcquiree(t, worker, types, tags)
		acquirees = append(acquirees, acq)
		acq.startAcquire(ctx, uut)
	}

	// Queue one successful response per acquiree.
	for i := 0; i < count; i++ {
		id := uuid.New()
		wantJobs[id] = true
		require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{ID: id}, nil))
	}

	// Which acquiree receives which job is not fixed, so compare as sets.
	gotJobs := make(map[uuid.UUID]bool)
	for _, acq := range acquirees {
		gotJobs[acq.success(ctx).ID] = true
	}
	require.Equal(t, wantJobs, gotJobs)
	require.Len(t, store.overlaps, 0)

	gotWorkers := make(map[uuid.UUID]bool)
	for _, call := range store.params {
		gotWorkers[call.WorkerID.UUID] = true
	}
	require.Equal(t, wantWorkers, gotWorkers)
}
// TestAcquirer_WaitsOnNoJobs tests that once a store call comes back with no
// jobs, the Acquirer stays blocked until a job posting compatible with the
// acquiree is published, and only then retries.
func TestAcquirer_WaitsOnNoJobs(t *testing.T) {
	t.Parallel()

	store := newFakeOrderedStore()
	bus := pubsub.NewInMemory()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	uut := provisionerdserver.NewAcquirer(ctx, log.Named("acquirer"), store, bus)

	acquiree := newTestAcquiree(t, uuid.New(),
		[]database.ProvisionerType{database.ProvisionerTypeEcho},
		provisionerdserver.Tags{"foo": "bar"},
	)

	// Script the store: the first call finds nothing, the second returns the job.
	wantJob := uuid.New()
	require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows))
	require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{ID: wantJob}, nil))

	acquiree.startAcquire(ctx, uut)

	// Wait until the first store call has been recorded.
	firstCallMade := func() bool {
		store.mu.Lock()
		defer store.mu.Unlock()
		return len(store.params) == 1
	}
	require.Eventually(t, firstCallMade, testutil.WaitShort, testutil.IntervalFast)
	acquiree.requireBlocked()

	// Postings whose tags or provisioner type do not match the acquiree must
	// not wake it.
	incompatible := []struct {
		pt   database.ProvisionerType
		tags provisionerdserver.Tags
	}{
		{database.ProvisionerTypeEcho, provisionerdserver.Tags{
			"cool":   "tapes",
			"strong": "bad",
		}},
		{database.ProvisionerTypeEcho, provisionerdserver.Tags{
			"foo": "fighters",
		}},
		{database.ProvisionerTypeTerraform, provisionerdserver.Tags{
			"foo": "bar",
		}},
	}
	for _, posting := range incompatible {
		postJob(t, bus, posting.pt, posting.tags)
	}
	acquiree.requireBlocked()

	// A posting with a matching type and compatible (empty) tags triggers the retry.
	postJob(t, bus, database.ProvisionerTypeEcho, provisionerdserver.Tags{})
	got := acquiree.success(ctx)
	require.Equal(t, wantJob, got.ID)
}
// TestAcquirer_RetriesPending tests that a job posting received while a store
// call is still in progress causes an immediate retry, even when that call
// then reports no jobs. This matters because the query that found nothing may
// have resolved before the job was posted, while its result reaches us only
// after the posting has arrived over the pubsub.
func TestAcquirer_RetriesPending(t *testing.T) {
	t.Parallel()

	store := newFakeOrderedStore()
	bus := pubsub.NewInMemory()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	uut := provisionerdserver.NewAcquirer(ctx, log.Named("acquirer"), store, bus)

	acquiree := newTestAcquiree(t, uuid.New(),
		[]database.ProvisionerType{database.ProvisionerTypeEcho},
		provisionerdserver.Tags{"foo": "bar"},
	)
	wantJob := uuid.New()

	acquiree.startAcquire(ctx, uut)
	firstCallMade := func() bool {
		store.mu.Lock()
		defer store.mu.Unlock()
		return len(store.params) == 1
	}
	require.Eventually(t, firstCallMade, testutil.WaitShort, testutil.IntervalFast)

	// The first store call is now blocked waiting for a response; publish a
	// posting while it is in flight.
	postJob(t, bus, database.ProvisionerTypeEcho, provisionerdserver.Tags{})

	// Processing of the posting races with the store call returning. A retry
	// is expected either way, but the case of interest is the posting being
	// processed first, so pause briefly to make that ordering likely.
	time.Sleep(testutil.IntervalMedium)

	// Let the first call finish with ErrNoRows; the retry then gets the job.
	require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows))
	require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{ID: wantJob}, nil))

	got := acquiree.success(ctx)
	require.Equal(t, wantJob, got.ID)
}
// TestAcquirer_DifferentDomains tests that acquirees with different tags do
// not block one another: while worker0 is blocked with no matching job,
// worker1 still acquires the job tagged for it.
func TestAcquirer_DifferentDomains(t *testing.T) {
	t.Parallel()

	store := newFakeTaggedStore(t)
	bus := pubsub.NewInMemory()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	types := []database.ProvisionerType{database.ProvisionerTypeEcho}

	blockedID := uuid.New()
	blocked := newTestAcquiree(t, blockedID, types, provisionerdserver.Tags{"worker": "0"})
	servedID := uuid.New()
	served := newTestAcquiree(t, servedID, types, provisionerdserver.Tags{"worker": "1"})

	// The only available job is tagged for worker 1.
	wantJob := uuid.New()
	store.jobs = []database.ProvisionerJob{{
		ID:          wantJob,
		Provisioner: database.ProvisionerTypeEcho,
		Tags:        database.StringMap{"worker": "1"},
	}}
	uut := provisionerdserver.NewAcquirer(ctx, log.Named("acquirer"), store, bus)

	// worker0 gets its own cancelable context so it can be stopped at the end.
	blockedCtx, stopBlocked := context.WithCancel(ctx)
	defer stopBlocked()
	blocked.startAcquire(blockedCtx, uut)
	select {
	case call := <-store.params:
		require.Equal(t, blockedID, call.WorkerID.UUID)
	case <-ctx.Done():
		t.Fatal("timed out waiting for call to database from worker0")
	}
	blocked.requireBlocked()

	// worker1 has different tags, so worker0 being blocked must not hold it up.
	served.startAcquire(ctx, uut)
	got := served.success(ctx)
	require.Equal(t, wantJob, got.ID)

	stopBlocked()
	blocked.requireCanceled(ctx)
}
// TestAcquirer_BackupPoll tests that an acquiree still gets a job after an
// initial "no jobs" result when no job posting is ever published. The Acquirer
// is built with a shortened backup poll duration, which is presumably what
// triggers the second store call.
func TestAcquirer_BackupPoll(t *testing.T) {
	t.Parallel()

	store := newFakeOrderedStore()
	bus := pubsub.NewInMemory()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	uut := provisionerdserver.NewAcquirer(
		ctx, log.Named("acquirer"), store, bus,
		provisionerdserver.TestingBackupPollDuration(testutil.IntervalMedium),
	)

	acquiree := newTestAcquiree(t, uuid.New(),
		[]database.ProvisionerType{database.ProvisionerTypeEcho},
		provisionerdserver.Tags{"foo": "bar"},
	)

	// Script the store: nothing on the first call, the job on the second.
	wantJob := uuid.New()
	require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows))
	require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{ID: wantJob}, nil))

	// Note that postJob is never called in this test.
	acquiree.startAcquire(ctx, uut)
	got := acquiree.success(ctx)
	require.Equal(t, wantJob, got.ID)
}
// TestAcquirer_UnblockOnCancel tests that an acquiree whose context is already
// canceled does not prevent another acquiree in the same domain from
// acquiring a job.
func TestAcquirer_UnblockOnCancel(t *testing.T) {
	t.Parallel()

	store := newFakeOrderedStore()
	bus := pubsub.NewInMemory()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

	// Both acquirees share types and tags, so they are in the same domain.
	types := []database.ProvisionerType{database.ProvisionerTypeEcho}
	tags := provisionerdserver.Tags{"foo": "bar"}
	canceled := newTestAcquiree(t, uuid.New(), types, tags)
	live := newTestAcquiree(t, uuid.New(), types, tags)

	wantJob := uuid.New()
	uut := provisionerdserver.NewAcquirer(ctx, log.Named("acquirer"), store, bus)

	// Queue two responses. Usually one is enough, because the canceled
	// acquiree gives up before calling the store; cancellation is
	// asynchronous though, so it may still make a call.
	const responses = 2
	for i := 0; i < responses; i++ {
		require.NoError(t, store.sendCtx(ctx, database.ProvisionerJob{ID: wantJob}, nil))
	}

	// Cancel the first acquiree's context before it even starts.
	canceledCtx, cancelNow := context.WithCancel(ctx)
	cancelNow()
	canceled.startAcquire(canceledCtx, uut)

	live.startAcquire(ctx, uut)
	got := live.success(ctx)
	require.Equal(t, wantJob, got.ID)
}
// postJob publishes a JSON-encoded provisionerjobs.JobPosting for the given
// provisioner type and tags on the EventJobPosted channel of ps. It fails the
// test if encoding or publishing returns an error.
func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) {
	t.Helper()

	posting := provisionerjobs.JobPosting{
		ProvisionerType: pt,
		Tags:            tags,
	}
	payload, err := json.Marshal(posting)
	require.NoError(t, err)

	require.NoError(t, ps.Publish(provisionerjobs.EventJobPosted, payload))
}
// fakeOrderedStore is a fake store whose AcquireProvisionerJob results are
// scripted by the test: each call returns the next job/error pair queued with
// sendCtx, in order. It also records every call and detects calls that
// overlap in time.
type fakeOrderedStore struct {
	// mu is held by AcquireProvisionerJob while it touches params, inflight
	// and overlaps; tests lock it when reading params concurrently.
	mu sync.Mutex
	// params holds the arguments of every call, in arrival order.
	params []database.AcquireProvisionerJobParams
	// inflight is the set of worker IDs with a call currently in progress.
	inflight map[uuid.UUID]bool
	// overlaps gets one {already-in-flight worker, new worker} pair for each
	// call that starts while another call is still in progress.
	overlaps [][]uuid.UUID

	// jobs and errors carry the queued results; a call receives one value
	// from each.
	jobs   chan database.ProvisionerJob
	errors chan error
}
// newFakeOrderedStore returns an empty fakeOrderedStore ready for use.
func newFakeOrderedStore() *fakeOrderedStore {
	// The result channels are buffered so a test can queue many responses up
	// front and have them consumed nearly simultaneously.
	const queueDepth = 100

	store := &fakeOrderedStore{
		inflight: make(map[uuid.UUID]bool),
	}
	store.jobs = make(chan database.ProvisionerJob, queueDepth)
	store.errors = make(chan error, queueDepth)
	return store
}
// AcquireProvisionerJob records the call, then blocks until the test has
// queued a result with sendCtx and returns that job and error. The context
// argument is ignored, so the call returns only once a result is available.
func (s *fakeOrderedStore) AcquireProvisionerJob(
	_ context.Context, params database.AcquireProvisionerJobParams,
) (
	database.ProvisionerJob, error,
) {
	caller := params.WorkerID.UUID

	// Log the call, note an overlap with every call already in progress,
	// then mark this worker as in flight.
	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.params = append(s.params, params)
		for other := range s.inflight {
			s.overlaps = append(s.overlaps, []uuid.UUID{other, caller})
		}
		s.inflight[caller] = true
	}()

	// Wait for the scripted result without holding the lock, so other calls
	// can be recorded (and flagged as overlapping) in the meantime.
	result := <-s.jobs
	resultErr := <-s.errors

	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		delete(s.inflight, caller)
	}()
	return result, resultErr
}
// sendCtx queues one result for a future AcquireProvisionerJob call: the job
// first, then the error. It returns ctx.Err() if ctx is done before either
// value could be queued, and nil once both have been queued.
func (s *fakeOrderedStore) sendCtx(ctx context.Context, job database.ProvisionerJob, err error) error {
	select {
	case s.jobs <- job:
		// queued; continue with the error
	case <-ctx.Done():
		return ctx.Err()
	}

	select {
	case s.errors <- err:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// fakeTaggedStore is a test store holding a set of available jobs. A caller
// is given a job whose provisioner type and tags match its request; the order
// of calls does not matter.
type fakeTaggedStore struct {
	// t receives assertion failures raised inside AcquireProvisionerJob.
	t *testing.T
	// params receives the arguments of each AcquireProvisionerJob call as
	// the call returns.
	params chan database.AcquireProvisionerJobParams

	// mu is held by AcquireProvisionerJob while it searches and modifies jobs.
	mu sync.Mutex
	// jobs are the jobs still available to be acquired.
	jobs []database.ProvisionerJob
}
// newFakeTaggedStore returns a fakeTaggedStore with no jobs. Its params
// channel is buffered, so up to 100 calls can be reported without a reader.
func newFakeTaggedStore(t *testing.T) *fakeTaggedStore {
	const paramsDepth = 100

	store := &fakeTaggedStore{t: t}
	store.params = make(chan database.AcquireProvisionerJobParams, paramsDepth)
	return store
}
// AcquireProvisionerJob removes and returns the first available job whose
// provisioner type is among params.Types and whose tags all appear, with
// equal values, in the caller's tags. It returns sql.ErrNoRows when no job
// matches, and the decode error when params.Tags is not valid JSON. On every
// return path the call's params are sent on s.params.
func (s *fakeTaggedStore) AcquireProvisionerJob(
	_ context.Context, params database.AcquireProvisionerJobParams,
) (
	database.ProvisionerJob, error,
) {
	// Registered first so it runs last, after the mutex below is released.
	defer func() { s.params <- params }()

	var callerTags provisionerdserver.Tags
	if err := json.Unmarshal(params.Tags, &callerTags); !assert.NoError(s.t, err) {
		return database.ProvisionerJob{}, err
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	for idx, candidate := range s.jobs {
		if !slices.Contains(params.Types, candidate.Provisioner) {
			continue
		}

		// Every tag on the job must be present on the caller with the same value.
		matches := true
		for key, want := range candidate.Tags {
			if have, ok := callerTags[key]; !ok || have != want {
				matches = false
				break
			}
		}
		if !matches {
			continue
		}

		// Found a job: take it out of the available set and hand it over.
		s.jobs = append(s.jobs[:idx], s.jobs[idx+1:]...)
		return candidate, nil
	}
	return database.ProvisionerJob{}, sql.ErrNoRows
}
// testAcquiree is a test helper that calls AcquireJob asynchronously and lets
// the test assert that the call succeeded, is still blocked, or was canceled.
type testAcquiree struct {
	t *testing.T

	// Arguments passed to AcquireJob.
	workerID uuid.UUID
	pt       []database.ProvisionerType
	tags     provisionerdserver.Tags

	// Results of the AcquireJob call: the error is delivered on ec first,
	// then the job on jc.
	ec chan error
	jc chan database.ProvisionerJob
}
// newTestAcquiree returns a testAcquiree that will call AcquireJob with the
// given worker ID, provisioner types and tags once startAcquire is invoked.
func newTestAcquiree(t *testing.T, workerID uuid.UUID, pt []database.ProvisionerType, tags provisionerdserver.Tags) *testAcquiree {
	// Each result channel has room for one value, so a single AcquireJob
	// result can be delivered without a reader waiting.
	errs := make(chan error, 1)
	jobs := make(chan database.ProvisionerJob, 1)

	return &testAcquiree{
		t:        t,
		workerID: workerID,
		pt:       pt,
		tags:     tags,
		ec:       errs,
		jc:       jobs,
	}
}
// startAcquire calls uut.AcquireJob in a new goroutine using the acquiree's
// worker ID, provisioner types and tags. When the call returns, its error is
// sent on a.ec and then its job on a.jc.
func (a *testAcquiree) startAcquire(ctx context.Context, uut *provisionerdserver.Acquirer) {
	acquire := func() {
		job, err := uut.AcquireJob(ctx, a.workerID, a.pt, a.tags)
		a.ec <- err
		a.jc <- job
	}
	go acquire()
}
// success waits for the AcquireJob call started by startAcquire to finish,
// requires that it returned no error, and returns the acquired job. The test
// fails if ctx is done before both the error and the job have been received.
func (a *testAcquiree) success(ctx context.Context) database.ProvisionerJob {
	// Mark this method as a helper so failures are reported at the calling
	// test's line rather than inside this function.
	a.t.Helper()

	select {
	case <-ctx.Done():
		a.t.Fatal("timeout waiting for AcquireJob error")
	case err := <-a.ec:
		require.NoError(a.t, err)
	}
	select {
	case <-ctx.Done():
		a.t.Fatal("timeout waiting for AcquireJob job")
	case job := <-a.jc:
		return job
	}
	// Not reached: Fatal stops the test goroutine. The compiler still needs
	// a return here because Fatal is an ordinary call.
	return database.ProvisionerJob{}
}
// requireBlocked fails the test if the AcquireJob call has already delivered
// a result on a.ec. The check is non-blocking: it only observes whether a
// result is available at the moment it is called.
func (a *testAcquiree) requireBlocked() {
	// Mark this method as a helper so a failure is reported at the calling
	// test's line rather than inside this function.
	a.t.Helper()

	select {
	case <-a.ec:
		a.t.Fatal("AcquireJob should block")
	default:
		// OK
	}
}
// requireCanceled waits for the AcquireJob call to finish and requires that
// it returned an error matching context.Canceled together with a job whose ID
// is uuid.Nil. The test fails if ctx is done before both results arrive.
func (a *testAcquiree) requireCanceled(ctx context.Context) {
	// Mark this method as a helper so failures are reported at the calling
	// test's line rather than inside this function.
	a.t.Helper()

	select {
	case err := <-a.ec:
		require.ErrorIs(a.t, err, context.Canceled)
	case <-ctx.Done():
		a.t.Fatal("timed out waiting for AcquireJob")
	}
	select {
	case job := <-a.jc:
		require.Equal(a.t, uuid.Nil, job.ID)
	case <-ctx.Done():
		a.t.Fatal("timed out waiting for AcquireJob")
	}
}