mirror of https://github.com/coder/coder.git
514 lines
15 KiB
Go
514 lines
15 KiB
Go
package provisionerdserver_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"golang.org/x/exp/slices"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbmem"
|
|
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
|
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
|
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m)
|
|
}
|
|
|
|
// TestAcquirer_Store tests that a database.Store is accepted as a provisionerdserver.AcquirerStore
|
|
func TestAcquirer_Store(t *testing.T) {
|
|
t.Parallel()
|
|
db := dbmem.New()
|
|
ps := pubsub.NewInMemory()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
_ = provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps)
|
|
}
|
|
|
|
func TestAcquirer_Single(t *testing.T) {
|
|
t.Parallel()
|
|
fs := newFakeOrderedStore()
|
|
ps := pubsub.NewInMemory()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
|
|
|
|
workerID := uuid.New()
|
|
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
|
|
tags := provisionerdserver.Tags{
|
|
"foo": "bar",
|
|
}
|
|
acquiree := newTestAcquiree(t, workerID, pt, tags)
|
|
jobID := uuid.New()
|
|
err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
|
|
require.NoError(t, err)
|
|
acquiree.startAcquire(ctx, uut)
|
|
job := acquiree.success(ctx)
|
|
require.Equal(t, jobID, job.ID)
|
|
require.Len(t, fs.params, 1)
|
|
require.Equal(t, workerID, fs.params[0].WorkerID.UUID)
|
|
}
|
|
|
|
// TestAcquirer_MultipleSameDomain tests multiple acquirees with the same provisioners and tags
|
|
func TestAcquirer_MultipleSameDomain(t *testing.T) {
|
|
t.Parallel()
|
|
fs := newFakeOrderedStore()
|
|
ps := pubsub.NewInMemory()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
|
|
|
|
acquirees := make([]*testAcquiree, 0, 10)
|
|
jobIDs := make(map[uuid.UUID]bool)
|
|
workerIDs := make(map[uuid.UUID]bool)
|
|
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
|
|
tags := provisionerdserver.Tags{
|
|
"foo": "bar",
|
|
}
|
|
for i := 0; i < 10; i++ {
|
|
wID := uuid.New()
|
|
workerIDs[wID] = true
|
|
a := newTestAcquiree(t, wID, pt, tags)
|
|
acquirees = append(acquirees, a)
|
|
a.startAcquire(ctx, uut)
|
|
}
|
|
for i := 0; i < 10; i++ {
|
|
jobID := uuid.New()
|
|
jobIDs[jobID] = true
|
|
err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
|
|
require.NoError(t, err)
|
|
}
|
|
gotJobIDs := make(map[uuid.UUID]bool)
|
|
for i := 0; i < 10; i++ {
|
|
j := acquirees[i].success(ctx)
|
|
gotJobIDs[j.ID] = true
|
|
}
|
|
require.Equal(t, jobIDs, gotJobIDs)
|
|
require.Len(t, fs.overlaps, 0)
|
|
gotWorkerCalls := make(map[uuid.UUID]bool)
|
|
for _, params := range fs.params {
|
|
gotWorkerCalls[params.WorkerID.UUID] = true
|
|
}
|
|
require.Equal(t, workerIDs, gotWorkerCalls)
|
|
}
|
|
|
|
// TestAcquirer_WaitsOnNoJobs tests that after a call that returns no jobs, Acquirer waits for a new
|
|
// job posting before retrying
|
|
func TestAcquirer_WaitsOnNoJobs(t *testing.T) {
|
|
t.Parallel()
|
|
fs := newFakeOrderedStore()
|
|
ps := pubsub.NewInMemory()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
|
|
|
|
workerID := uuid.New()
|
|
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
|
|
tags := provisionerdserver.Tags{
|
|
"foo": "bar",
|
|
}
|
|
acquiree := newTestAcquiree(t, workerID, pt, tags)
|
|
jobID := uuid.New()
|
|
err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows)
|
|
require.NoError(t, err)
|
|
err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
|
|
require.NoError(t, err)
|
|
acquiree.startAcquire(ctx, uut)
|
|
require.Eventually(t, func() bool {
|
|
fs.mu.Lock()
|
|
defer fs.mu.Unlock()
|
|
return len(fs.params) == 1
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
acquiree.requireBlocked()
|
|
|
|
// First send in some with incompatible tags & types
|
|
postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{
|
|
"cool": "tapes",
|
|
"strong": "bad",
|
|
})
|
|
postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{
|
|
"foo": "fighters",
|
|
})
|
|
postJob(t, ps, database.ProvisionerTypeTerraform, provisionerdserver.Tags{
|
|
"foo": "bar",
|
|
})
|
|
acquiree.requireBlocked()
|
|
|
|
// compatible tags
|
|
postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{})
|
|
job := acquiree.success(ctx)
|
|
require.Equal(t, jobID, job.ID)
|
|
}
|
|
|
|
// TestAcquirer_RetriesPending tests that if we get a job posting while a db call is in progress
|
|
// we retry to acquire a job immediately, even if the first call returned no jobs. We want this
|
|
// behavior since the query that found no jobs could have resolved before the job was posted, but
|
|
// the query result could reach us later than the posting over the pubsub.
|
|
func TestAcquirer_RetriesPending(t *testing.T) {
|
|
t.Parallel()
|
|
fs := newFakeOrderedStore()
|
|
ps := pubsub.NewInMemory()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
|
|
|
|
workerID := uuid.New()
|
|
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
|
|
tags := provisionerdserver.Tags{
|
|
"foo": "bar",
|
|
}
|
|
acquiree := newTestAcquiree(t, workerID, pt, tags)
|
|
jobID := uuid.New()
|
|
|
|
acquiree.startAcquire(ctx, uut)
|
|
require.Eventually(t, func() bool {
|
|
fs.mu.Lock()
|
|
defer fs.mu.Unlock()
|
|
return len(fs.params) == 1
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
|
|
// First call to DB is in progress. Send in posting
|
|
postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{})
|
|
// there is a race between the posting being processed and the DB call
|
|
// returning. In either case we should retry, but we're trying to hit the
|
|
// case where the posting is processed first, so sleep a little bit to give
|
|
// it a chance.
|
|
time.Sleep(testutil.IntervalMedium)
|
|
|
|
// Now, when first DB call returns ErrNoRows we retry.
|
|
err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows)
|
|
require.NoError(t, err)
|
|
err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
|
|
require.NoError(t, err)
|
|
|
|
job := acquiree.success(ctx)
|
|
require.Equal(t, jobID, job.ID)
|
|
}
|
|
|
|
// TestAcquirer_DifferentDomains tests that acquirees with different tags don't block each other
|
|
func TestAcquirer_DifferentDomains(t *testing.T) {
|
|
t.Parallel()
|
|
fs := newFakeTaggedStore(t)
|
|
ps := pubsub.NewInMemory()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
|
|
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
|
|
worker0 := uuid.New()
|
|
tags0 := provisionerdserver.Tags{
|
|
"worker": "0",
|
|
}
|
|
acquiree0 := newTestAcquiree(t, worker0, pt, tags0)
|
|
worker1 := uuid.New()
|
|
tags1 := provisionerdserver.Tags{
|
|
"worker": "1",
|
|
}
|
|
acquiree1 := newTestAcquiree(t, worker1, pt, tags1)
|
|
jobID := uuid.New()
|
|
fs.jobs = []database.ProvisionerJob{
|
|
{ID: jobID, Provisioner: database.ProvisionerTypeEcho, Tags: database.StringMap{"worker": "1"}},
|
|
}
|
|
|
|
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
|
|
|
|
ctx0, cancel0 := context.WithCancel(ctx)
|
|
defer cancel0()
|
|
acquiree0.startAcquire(ctx0, uut)
|
|
select {
|
|
case params := <-fs.params:
|
|
require.Equal(t, worker0, params.WorkerID.UUID)
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for call to database from worker0")
|
|
}
|
|
acquiree0.requireBlocked()
|
|
|
|
// worker1 should not be blocked by worker0, as they are different tags
|
|
acquiree1.startAcquire(ctx, uut)
|
|
job := acquiree1.success(ctx)
|
|
require.Equal(t, jobID, job.ID)
|
|
|
|
cancel0()
|
|
acquiree0.requireCanceled(ctx)
|
|
}
|
|
|
|
func TestAcquirer_BackupPoll(t *testing.T) {
|
|
t.Parallel()
|
|
fs := newFakeOrderedStore()
|
|
ps := pubsub.NewInMemory()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
uut := provisionerdserver.NewAcquirer(
|
|
ctx, logger.Named("acquirer"), fs, ps,
|
|
provisionerdserver.TestingBackupPollDuration(testutil.IntervalMedium),
|
|
)
|
|
|
|
workerID := uuid.New()
|
|
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
|
|
tags := provisionerdserver.Tags{
|
|
"foo": "bar",
|
|
}
|
|
acquiree := newTestAcquiree(t, workerID, pt, tags)
|
|
jobID := uuid.New()
|
|
err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows)
|
|
require.NoError(t, err)
|
|
err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
|
|
require.NoError(t, err)
|
|
acquiree.startAcquire(ctx, uut)
|
|
job := acquiree.success(ctx)
|
|
require.Equal(t, jobID, job.ID)
|
|
}
|
|
|
|
// TestAcquirer_UnblockOnCancel tests that a canceled call doesn't block a call
|
|
// from the same domain.
|
|
func TestAcquirer_UnblockOnCancel(t *testing.T) {
|
|
t.Parallel()
|
|
fs := newFakeOrderedStore()
|
|
ps := pubsub.NewInMemory()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
|
|
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
|
|
worker0 := uuid.New()
|
|
tags := provisionerdserver.Tags{
|
|
"foo": "bar",
|
|
}
|
|
acquiree0 := newTestAcquiree(t, worker0, pt, tags)
|
|
worker1 := uuid.New()
|
|
acquiree1 := newTestAcquiree(t, worker1, pt, tags)
|
|
jobID := uuid.New()
|
|
|
|
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
|
|
|
|
// queue up 2 responses --- we may not need both, since acquiree0 will
|
|
// usually cancel before calling, but cancel is async, so it might call.
|
|
for i := 0; i < 2; i++ {
|
|
err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
ctx0, cancel0 := context.WithCancel(ctx)
|
|
cancel0()
|
|
acquiree0.startAcquire(ctx0, uut)
|
|
acquiree1.startAcquire(ctx, uut)
|
|
job := acquiree1.success(ctx)
|
|
require.Equal(t, jobID, job.ID)
|
|
}
|
|
|
|
func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) {
|
|
t.Helper()
|
|
msg, err := json.Marshal(provisionerjobs.JobPosting{
|
|
ProvisionerType: pt,
|
|
Tags: tags,
|
|
})
|
|
require.NoError(t, err)
|
|
err = ps.Publish(provisionerjobs.EventJobPosted, msg)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// fakeOrderedStore is a fake store that lets tests send AcquireProvisionerJob
|
|
// results in order over a channel, and tests for overlapped calls.
|
|
type fakeOrderedStore struct {
|
|
jobs chan database.ProvisionerJob
|
|
errors chan error
|
|
|
|
mu sync.Mutex
|
|
params []database.AcquireProvisionerJobParams
|
|
|
|
// inflight and overlaps track whether any calls from workers overlap with
|
|
// one another
|
|
inflight map[uuid.UUID]bool
|
|
overlaps [][]uuid.UUID
|
|
}
|
|
|
|
func newFakeOrderedStore() *fakeOrderedStore {
|
|
return &fakeOrderedStore{
|
|
// buffer the channels so that we can queue up lots of responses to
|
|
// occur nearly simultaneously
|
|
jobs: make(chan database.ProvisionerJob, 100),
|
|
errors: make(chan error, 100),
|
|
inflight: make(map[uuid.UUID]bool),
|
|
}
|
|
}
|
|
|
|
func (s *fakeOrderedStore) AcquireProvisionerJob(
|
|
_ context.Context, params database.AcquireProvisionerJobParams,
|
|
) (
|
|
database.ProvisionerJob, error,
|
|
) {
|
|
s.mu.Lock()
|
|
s.params = append(s.params, params)
|
|
for workerID := range s.inflight {
|
|
s.overlaps = append(s.overlaps, []uuid.UUID{workerID, params.WorkerID.UUID})
|
|
}
|
|
s.inflight[params.WorkerID.UUID] = true
|
|
s.mu.Unlock()
|
|
|
|
job := <-s.jobs
|
|
err := <-s.errors
|
|
|
|
s.mu.Lock()
|
|
delete(s.inflight, params.WorkerID.UUID)
|
|
s.mu.Unlock()
|
|
|
|
return job, err
|
|
}
|
|
|
|
func (s *fakeOrderedStore) sendCtx(ctx context.Context, job database.ProvisionerJob, err error) error {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case s.jobs <- job:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case s.errors <- err:
|
|
// OK
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// fakeTaggedStore is a test store that allows tests to specify which jobs are
|
|
// available, and returns them to callers with the appropriate provisioner type
|
|
// and tags. It doesn't care about the order.
|
|
type fakeTaggedStore struct {
|
|
t *testing.T
|
|
mu sync.Mutex
|
|
jobs []database.ProvisionerJob
|
|
params chan database.AcquireProvisionerJobParams
|
|
}
|
|
|
|
func newFakeTaggedStore(t *testing.T) *fakeTaggedStore {
|
|
return &fakeTaggedStore{
|
|
t: t,
|
|
params: make(chan database.AcquireProvisionerJobParams, 100),
|
|
}
|
|
}
|
|
|
|
func (s *fakeTaggedStore) AcquireProvisionerJob(
|
|
_ context.Context, params database.AcquireProvisionerJobParams,
|
|
) (
|
|
database.ProvisionerJob, error,
|
|
) {
|
|
defer func() { s.params <- params }()
|
|
var tags provisionerdserver.Tags
|
|
err := json.Unmarshal(params.Tags, &tags)
|
|
if !assert.NoError(s.t, err) {
|
|
return database.ProvisionerJob{}, err
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
jobLoop:
|
|
for i, job := range s.jobs {
|
|
if !slices.Contains(params.Types, job.Provisioner) {
|
|
continue
|
|
}
|
|
for k, v := range job.Tags {
|
|
pv, ok := tags[k]
|
|
if !ok {
|
|
continue jobLoop
|
|
}
|
|
if v != pv {
|
|
continue jobLoop
|
|
}
|
|
}
|
|
// found a job!
|
|
s.jobs = append(s.jobs[:i], s.jobs[i+1:]...)
|
|
return job, nil
|
|
}
|
|
return database.ProvisionerJob{}, sql.ErrNoRows
|
|
}
|
|
|
|
// testAcquiree is a helper type that handles asynchronously calling AcquireJob
|
|
// and asserting whether or not it returns, blocks, or is canceled.
|
|
type testAcquiree struct {
|
|
t *testing.T
|
|
workerID uuid.UUID
|
|
pt []database.ProvisionerType
|
|
tags provisionerdserver.Tags
|
|
ec chan error
|
|
jc chan database.ProvisionerJob
|
|
}
|
|
|
|
func newTestAcquiree(t *testing.T, workerID uuid.UUID, pt []database.ProvisionerType, tags provisionerdserver.Tags) *testAcquiree {
|
|
return &testAcquiree{
|
|
t: t,
|
|
workerID: workerID,
|
|
pt: pt,
|
|
tags: tags,
|
|
ec: make(chan error, 1),
|
|
jc: make(chan database.ProvisionerJob, 1),
|
|
}
|
|
}
|
|
|
|
func (a *testAcquiree) startAcquire(ctx context.Context, uut *provisionerdserver.Acquirer) {
|
|
go func() {
|
|
j, e := uut.AcquireJob(ctx, a.workerID, a.pt, a.tags)
|
|
a.ec <- e
|
|
a.jc <- j
|
|
}()
|
|
}
|
|
|
|
func (a *testAcquiree) success(ctx context.Context) database.ProvisionerJob {
|
|
select {
|
|
case <-ctx.Done():
|
|
a.t.Fatal("timeout waiting for AcquireJob error")
|
|
case err := <-a.ec:
|
|
require.NoError(a.t, err)
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
a.t.Fatal("timeout waiting for AcquireJob job")
|
|
case job := <-a.jc:
|
|
return job
|
|
}
|
|
// unhittable
|
|
return database.ProvisionerJob{}
|
|
}
|
|
|
|
func (a *testAcquiree) requireBlocked() {
|
|
select {
|
|
case <-a.ec:
|
|
a.t.Fatal("AcquireJob should block")
|
|
default:
|
|
// OK
|
|
}
|
|
}
|
|
|
|
func (a *testAcquiree) requireCanceled(ctx context.Context) {
|
|
select {
|
|
case err := <-a.ec:
|
|
require.ErrorIs(a.t, err, context.Canceled)
|
|
case <-ctx.Done():
|
|
a.t.Fatal("timed out waiting for AcquireJob")
|
|
}
|
|
select {
|
|
case job := <-a.jc:
|
|
require.Equal(a.t, uuid.Nil, job.ID)
|
|
case <-ctx.Done():
|
|
a.t.Fatal("timed out waiting for AcquireJob")
|
|
}
|
|
}
|