package provisionerd_test import ( "archive/tar" "bytes" "context" "fmt" "io" "os" "path/filepath" "sync" "testing" "time" "github.com/hashicorp/yamux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "go.uber.org/goleak" "golang.org/x/xerrors" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/coder/v2/provisionerd" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionersdk" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/testutil" ) func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } func closedWithin(c chan struct{}, d time.Duration) func() bool { return func() bool { select { case <-c: return true case <-time.After(d): return false } } } func TestProvisionerd(t *testing.T) { t.Parallel() noopUpdateJob := func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { return &proto.UpdateJobResponse{}, nil } t.Run("InstantClose", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil }, provisionerd.LocalProvisioners{}) require.NoError(t, closer.Close()) }) t.Run("ConnectErrorClose", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) completeChan := make(chan struct{}) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { defer close(completeChan) return nil, xerrors.New("an error") }, provisionerd.LocalProvisioners{}) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, closer.Close()) }) t.Run("CloseCancelsJob", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) completeChan := make(chan struct{}) var completed sync.Once var closer io.Closer var closerMutex sync.Mutex closerMutex.Lock() closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { err := stream.Send(&proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_TemplateImport_{ TemplateImport: &proto.AcquiredJob_TemplateImport{ Metadata: &sdkproto.Metadata{}, }, }, }) assert.NoError(t, err) return nil }, updateJob: noopUpdateJob, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { completed.Do(func() { close(completeChan) }) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func(_ *provisionersdk.Session, _ *sdkproto.ParseRequest, _ <-chan struct{}) *sdkproto.ParseComplete { closerMutex.Lock() defer closerMutex.Unlock() err := closer.Close() c := &sdkproto.ParseComplete{} if err != nil { c.Error = err.Error() } return c }, }), }) closerMutex.Unlock() require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, closer.Close()) }) t.Run("MaliciousTar", func(t *testing.T) { // Ensures tars with "../../../etc/passwd" as the path // are not allowed to run, and will fail the job. t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( completeChan = make(chan struct{}) completeOnce sync.Once acq = newAcquireOne(t, &proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "../../../etc/passwd": "content", }), Type: &proto.AcquiredJob_TemplateImport_{ TemplateImport: &proto.AcquiredJob_TemplateImport{ Metadata: &sdkproto.Metadata{}, }, }, }) ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: acq.acquireWithCancel, updateJob: noopUpdateJob, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { completeOnce.Do(func() { close(completeChan) }) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}), }) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, closer.Close()) }) t.Run("RunningPeriodicUpdate", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( completeChan = make(chan struct{}) completeOnce sync.Once ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { err := stream.Send(&proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_TemplateImport_{ TemplateImport: &proto.AcquiredJob_TemplateImport{ Metadata: &sdkproto.Metadata{}, }, }, }) assert.NoError(t, err) return nil }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { completeOnce.Do(func() { close(completeChan) }) return &proto.UpdateJobResponse{}, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func( _ *provisionersdk.Session, _ *sdkproto.ParseRequest, cancelOrComplete <-chan struct{}, ) *sdkproto.ParseComplete { <-cancelOrComplete return &sdkproto.ParseComplete{} }, }), }) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, closer.Close()) }) t.Run("TemplateImport", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( didComplete atomic.Bool didLog atomic.Bool didReadme atomic.Bool acq = newAcquireOne(t, &proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", provisionersdk.ReadmeFile: "# A cool template 😎\n", }), Type: &proto.AcquiredJob_TemplateImport_{ TemplateImport: &proto.AcquiredJob_TemplateImport{ Metadata: &sdkproto.Metadata{}, }, }, }) ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: acq.acquireWithCancel, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { if len(update.Logs) > 0 { didLog.Store(true) } if len(update.Readme) > 0 { didReadme.Store(true) } return &proto.UpdateJobResponse{}, nil }, completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { didComplete.Store(true) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func( s *provisionersdk.Session, _ *sdkproto.ParseRequest, cancelOrComplete <-chan struct{}, ) *sdkproto.ParseComplete { data, err := os.ReadFile(filepath.Join(s.WorkDirectory, "test.txt")) require.NoError(t, err) require.Equal(t, "content", string(data)) s.ProvisionLog(sdkproto.LogLevel_INFO, "hello") return &sdkproto.ParseComplete{} }, plan: func( s *provisionersdk.Session, _ *sdkproto.PlanRequest, cancelOrComplete <-chan struct{}, ) *sdkproto.PlanComplete { s.ProvisionLog(sdkproto.LogLevel_INFO, "hello") return &sdkproto.PlanComplete{ Resources: []*sdkproto.Resource{}, } }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { t.Error("dry run should not apply") return &sdkproto.ApplyComplete{} }, }), }) require.Condition(t, closedWithin(acq.complete, testutil.WaitShort)) require.NoError(t, closer.Close()) assert.True(t, didLog.Load(), "should log some updates") assert.True(t, didComplete.Load(), "should complete the job") }) t.Run("TemplateDryRun", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( didComplete atomic.Bool didLog atomic.Bool metadata = &sdkproto.Metadata{} acq = newAcquireOne(t, &proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_TemplateDryRun_{ TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{ Metadata: metadata, }, }, }) ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: acq.acquireWithCancel, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { if len(update.Logs) == 0 { t.Log("provisionerDaemonTestServer: no log messages") return &proto.UpdateJobResponse{}, nil } didLog.Store(true) for _, msg := range update.Logs { t.Log("provisionerDaemonTestServer", "msg:", msg) } return &proto.UpdateJobResponse{}, nil }, completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { didComplete.Store(true) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( _ *provisionersdk.Session, _ *sdkproto.PlanRequest, _ <-chan struct{}, ) *sdkproto.PlanComplete { return &sdkproto.PlanComplete{ Resources: []*sdkproto.Resource{}, } }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { t.Error("dry run should not apply") return &sdkproto.ApplyComplete{} }, }), }) require.Condition(t, closedWithin(acq.complete, testutil.WaitShort)) require.NoError(t, closer.Close()) assert.True(t, didLog.Load(), "should log some updates") assert.True(t, didComplete.Load(), "should complete the job") }) t.Run("WorkspaceBuild", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( didComplete atomic.Bool didLog atomic.Bool acq = newAcquireOne(t, &proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ Metadata: &sdkproto.Metadata{}, }, }, }) ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: acq.acquireWithCancel, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { if len(update.Logs) != 0 { didLog.Store(true) } return &proto.UpdateJobResponse{}, nil }, completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { didComplete.Store(true) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, _ *sdkproto.PlanRequest, cancelOrComplete <-chan struct{}, ) *sdkproto.PlanComplete { s.ProvisionLog(sdkproto.LogLevel_DEBUG, "wow") return &sdkproto.PlanComplete{} }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { return &sdkproto.ApplyComplete{} }, }), }) require.Condition(t, closedWithin(acq.complete, testutil.WaitShort)) require.NoError(t, closer.Close()) assert.True(t, didLog.Load(), "should log some updates") assert.True(t, didComplete.Load(), "should complete the job") }) t.Run("WorkspaceBuildQuotaExceeded", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( didComplete atomic.Bool didLog atomic.Bool didFail atomic.Bool acq = newAcquireOne(t, &proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ Metadata: &sdkproto.Metadata{}, }, }, }) ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: acq.acquireWithCancel, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { if len(update.Logs) != 0 { didLog.Store(true) } return &proto.UpdateJobResponse{}, nil }, completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { didComplete.Store(true) return &proto.Empty{}, nil }, commitQuota: func(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) { return &proto.CommitQuotaResponse{ Ok: com.DailyCost < 20, }, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { didFail.Store(true) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, _ *sdkproto.PlanRequest, cancelOrComplete <-chan struct{}, ) *sdkproto.PlanComplete { s.ProvisionLog(sdkproto.LogLevel_DEBUG, "wow") return &sdkproto.PlanComplete{ Resources: []*sdkproto.Resource{ { DailyCost: 10, }, { DailyCost: 15, }, }, } }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { t.Error("should not apply when resources exceed quota") return &sdkproto.ApplyComplete{ Resources: []*sdkproto.Resource{ { DailyCost: 10, }, { DailyCost: 15, }, }, } }, }), }) require.Condition(t, closedWithin(acq.complete, testutil.WaitShort)) require.NoError(t, closer.Close()) assert.True(t, didLog.Load(), "should log some updates") assert.False(t, didComplete.Load(), "should not complete the job") assert.True(t, didFail.Load(), "should fail the job") }) t.Run("WorkspaceBuildFailComplete", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( didFail atomic.Bool acq = newAcquireOne(t, &proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ Metadata: &sdkproto.Metadata{}, }, }, }) ) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: acq.acquireWithCancel, updateJob: noopUpdateJob, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { didFail.Store(true) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, _ *sdkproto.PlanRequest, cancelOrComplete <-chan struct{}, ) *sdkproto.PlanComplete { return &sdkproto.PlanComplete{ Error: "some error", } }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { t.Error("should not apply when plan errors") return &sdkproto.ApplyComplete{ Error: "some error", } }, }), }) require.Condition(t, closedWithin(acq.complete, testutil.WaitShort)) require.NoError(t, closer.Close()) assert.True(t, didFail.Load(), "should fail the job") }) // Simulates when there is no coderd to connect to. So the client connection // will never be established. t.Run("ShutdownNoCoderd", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) connectAttemptedClose := sync.Once{} connectAttempted := make(chan struct{}) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { // This is the dial out to Coderd, which in this unit test will always fail. connectAttemptedClose.Do(func() { close(connectAttempted) }) return nil, xerrors.New("client connection always fails") }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}), }) // Wait for at least 1 attempt to connect to ensure the connect go routine // is running. require.Condition(t, closedWithin(connectAttempted, testutil.WaitShort)) // The test is ensuring this Shutdown call does not block indefinitely. // If it does, the context will return with an error, and the test will // fail. shutdownCtx := testutil.Context(t, testutil.WaitShort) err := server.Shutdown(shutdownCtx, true) require.NoError(t, err, "shutdown did not unblock. Failed to close the server gracefully.") require.NoError(t, server.Close()) }) t.Run("Shutdown", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var updated sync.Once var completed sync.Once updateChan := make(chan struct{}) completeChan := make(chan struct{}) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { err := stream.Send(&proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ Metadata: &sdkproto.Metadata{}, }, }, }) assert.NoError(t, err) return nil }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { if len(update.Logs) > 0 { for _, log := range update.Logs { if log.Source != proto.LogSource_PROVISIONER { continue } // Close on a log so we know when the job is in progress! updated.Do(func() { close(updateChan) }) break } } return &proto.UpdateJobResponse{}, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { completed.Do(func() { close(completeChan) }) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, _ *sdkproto.PlanRequest, canceledOrComplete <-chan struct{}, ) *sdkproto.PlanComplete { s.ProvisionLog(sdkproto.LogLevel_DEBUG, "in progress") <-canceledOrComplete return &sdkproto.PlanComplete{ Error: "some error", } }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { t.Error("should not apply when shut down during plan") return &sdkproto.ApplyComplete{} }, }), }) require.Condition(t, closedWithin(updateChan, testutil.WaitShort)) err := server.Shutdown(context.Background(), true) require.NoError(t, err) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, server.Close()) }) t.Run("ShutdownFromJob", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var completed sync.Once var updated sync.Once updateChan := make(chan struct{}) completeChan := make(chan struct{}) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { err := stream.Send(&proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ Metadata: &sdkproto.Metadata{}, }, }, }) assert.NoError(t, err) return nil }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { resp := &proto.UpdateJobResponse{} if len(update.Logs) > 0 { for _, log := range update.Logs { if log.Source != proto.LogSource_PROVISIONER { continue } // Close on a log so we know when the job is in progress! updated.Do(func() { close(updateChan) }) break } } // start returning Canceled once we've gotten at least one log. select { case <-updateChan: resp.Canceled = true default: // pass } return resp, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { completed.Do(func() { close(completeChan) }) return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, _ *sdkproto.PlanRequest, canceledOrComplete <-chan struct{}, ) *sdkproto.PlanComplete { s.ProvisionLog(sdkproto.LogLevel_DEBUG, "in progress") <-canceledOrComplete return &sdkproto.PlanComplete{ Error: "some error", } }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { t.Error("should not apply when shut down during plan") return &sdkproto.ApplyComplete{} }, }), }) require.Condition(t, closedWithin(updateChan, testutil.WaitShort)) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) t.Run("ReconnectAndFail", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( second atomic.Bool failChan = make(chan struct{}) failOnce sync.Once failedChan = make(chan struct{}) failedOnce sync.Once completeChan = make(chan struct{}) completeOnce sync.Once ) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { job := &proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ Metadata: &sdkproto.Metadata{}, }, }, } if second.Load() { job = &proto.AcquiredJob{} _, err := stream.Recv() assert.NoError(t, err) } err := stream.Send(job) assert.NoError(t, err) return nil }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { return &proto.UpdateJobResponse{}, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { assert.Equal(t, job.JobId, "test") if second.Load() { completeOnce.Do(func() { close(completeChan) }) return &proto.Empty{}, nil } failOnce.Do(func() { close(failChan) }) <-failedChan return &proto.Empty{}, nil }, }) if !second.Load() { go func() { <-failChan _ = client.DRPCConn().Close() second.Store(true) time.Sleep(50 * time.Millisecond) failedOnce.Do(func() { close(failedChan) }) }() } return client, nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( _ *provisionersdk.Session, _ *sdkproto.PlanRequest, _ <-chan struct{}, ) *sdkproto.PlanComplete { return &sdkproto.PlanComplete{ Error: "some error", } }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { t.Error("should not apply when error during plan") return &sdkproto.ApplyComplete{} }, }), }) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) t.Run("ReconnectAndComplete", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) var ( second atomic.Bool failChan = make(chan struct{}) failOnce sync.Once failedChan = make(chan struct{}) failedOnce sync.Once completeChan = make(chan struct{}) completeOnce sync.Once ) server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { if second.Load() { completeOnce.Do(func() { close(completeChan) }) _, err := stream.Recv() assert.NoError(t, err) return nil } job := &proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ Metadata: &sdkproto.Metadata{}, }, }, } err := stream.Send(job) assert.NoError(t, err) return nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { return nil, yamux.ErrSessionShutdown }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { return &proto.UpdateJobResponse{}, nil }, completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { if second.Load() { return &proto.Empty{}, nil } failOnce.Do(func() { close(failChan) }) <-failedChan return &proto.Empty{}, nil }, }) if !second.Load() { go func() { <-failChan _ = client.DRPCConn().Close() second.Store(true) time.Sleep(50 * time.Millisecond) failedOnce.Do(func() { close(failedChan) }) }() } return client, nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( _ *provisionersdk.Session, _ *sdkproto.PlanRequest, _ <-chan struct{}, ) *sdkproto.PlanComplete { return &sdkproto.PlanComplete{} }, apply: func( _ *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { return &sdkproto.ApplyComplete{} }, }), }) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) t.Log("completeChan closed") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) t.Run("UpdatesBeforeComplete", func(t *testing.T) { t.Parallel() done := make(chan struct{}) t.Cleanup(func() { close(done) }) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) m := sync.Mutex{} var ops []string completeChan := make(chan struct{}) completeOnce := sync.Once{} server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{ acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { m.Lock() defer m.Unlock() logger.Info(ctx, "provisioner stage: AcquiredJob") if len(ops) > 0 { _, err := stream.Recv() assert.NoError(t, err) err = stream.Send(&proto.AcquiredJob{}) assert.NoError(t, err) return nil } ops = append(ops, "AcquireJob") err := stream.Send(&proto.AcquiredJob{ JobId: "test", Provisioner: "someprovisioner", TemplateSourceArchive: createTar(t, map[string]string{ "test.txt": "content", }), Type: &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ Metadata: &sdkproto.Metadata{}, }, }, }) assert.NoError(t, err) return nil }, updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { m.Lock() defer m.Unlock() logger.Info(ctx, "provisioner stage: UpdateJob") ops = append(ops, "UpdateJob") for _, log := range update.Logs { ops = append(ops, fmt.Sprintf("Log: %s | %s", log.Stage, log.Output)) } return &proto.UpdateJobResponse{}, nil }, completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { m.Lock() defer m.Unlock() logger.Info(ctx, "provisioner stage: CompleteJob") ops = append(ops, "CompleteJob") completeOnce.Do(func() { close(completeChan) }) return &proto.Empty{}, nil }, failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { m.Lock() defer m.Unlock() logger.Info(ctx, "provisioner stage: FailJob") ops = append(ops, "FailJob") return &proto.Empty{}, nil }, }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, _ *sdkproto.PlanRequest, _ <-chan struct{}, ) *sdkproto.PlanComplete { s.ProvisionLog(sdkproto.LogLevel_DEBUG, "wow") return &sdkproto.PlanComplete{} }, apply: func( s *provisionersdk.Session, _ *sdkproto.ApplyRequest, _ <-chan struct{}, ) *sdkproto.ApplyComplete { s.ProvisionLog(sdkproto.LogLevel_DEBUG, "wow") return &sdkproto.ApplyComplete{} }, }), }) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) assert.Equal(t, ops[len(ops)-1], "CompleteJob") assert.Contains(t, ops[0:len(ops)-1], "Log: Cleaning Up | ") }) } // Creates an in-memory tar of the files provided. func createTar(t *testing.T, files map[string]string) []byte { var buffer bytes.Buffer writer := tar.NewWriter(&buffer) for path, content := range files { err := writer.WriteHeader(&tar.Header{ Name: path, Size: int64(len(content)), }) require.NoError(t, err) _, err = writer.Write([]byte(content)) require.NoError(t, err) } err := writer.Flush() require.NoError(t, err) return buffer.Bytes() } // Creates a provisionerd implementation with the provided dialer and provisioners. func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, connector provisionerd.LocalProvisioners) *provisionerd.Server { server := provisionerd.New(dialer, &provisionerd.Options{ Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("provisionerd").Leveled(slog.LevelDebug), UpdateInterval: 50 * time.Millisecond, Connector: connector, }) t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() _ = server.Shutdown(ctx, true) _ = server.Close() }) return server } // Creates a provisionerd protobuf client that's connected // to the server implementation provided. func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient { t.Helper() if server.failJob == nil { // Default to asserting the error from the failure, otherwise // it can be lost in tests! server.failJob = func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { assert.Fail(t, job.Error) return &proto.Empty{}, nil } } clientPipe, serverPipe := drpc.MemTransportPipe() t.Cleanup(func() { _ = clientPipe.Close() _ = serverPipe.Close() }) mux := drpcmux.New() err := proto.DRPCRegisterProvisionerDaemon(mux, &server) require.NoError(t, err) srv := drpcserver.New(mux) ctx, cancelFunc := context.WithCancel(context.Background()) closed := make(chan struct{}) go func() { defer close(closed) _ = srv.Serve(ctx, serverPipe) }() t.Cleanup(func() { cancelFunc() <-closed select { case <-done: t.Error("createProvisionerDaemonClient cleanup after test was done!") default: } }) select { case <-done: t.Error("called createProvisionerDaemonClient after test was done!") default: } return proto.NewDRPCProvisionerDaemonClient(clientPipe) } // Creates a provisioner protobuf client that's connected // to the server implementation provided. func createProvisionerClient(t *testing.T, done <-chan struct{}, server provisionerTestServer) sdkproto.DRPCProvisionerClient { t.Helper() clientPipe, serverPipe := drpc.MemTransportPipe() t.Cleanup(func() { _ = clientPipe.Close() _ = serverPipe.Close() }) ctx, cancelFunc := context.WithCancel(context.Background()) closed := make(chan struct{}) tempDir := t.TempDir() go func() { defer close(closed) _ = provisionersdk.Serve(ctx, &server, &provisionersdk.ServeOptions{ Listener: serverPipe, Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("test-provisioner"), WorkDirectory: tempDir, }) }() t.Cleanup(func() { cancelFunc() <-closed select { case <-done: t.Error("createProvisionerClient cleanup after test was done!") default: } }) select { case <-done: t.Error("called createProvisionerClient after test was done!") default: } return sdkproto.NewDRPCProvisionerClient(clientPipe) } type provisionerTestServer struct { parse func(s *provisionersdk.Session, r *sdkproto.ParseRequest, canceledOrComplete <-chan struct{}) *sdkproto.ParseComplete plan func(s *provisionersdk.Session, r *sdkproto.PlanRequest, canceledOrComplete <-chan struct{}) *sdkproto.PlanComplete apply func(s *provisionersdk.Session, r *sdkproto.ApplyRequest, canceledOrComplete <-chan struct{}) *sdkproto.ApplyComplete } func (p *provisionerTestServer) Parse(s *provisionersdk.Session, r *sdkproto.ParseRequest, canceledOrComplete <-chan struct{}) *sdkproto.ParseComplete { return p.parse(s, r, canceledOrComplete) } func (p *provisionerTestServer) Plan(s *provisionersdk.Session, r *sdkproto.PlanRequest, canceledOrComplete <-chan struct{}) *sdkproto.PlanComplete { return p.plan(s, r, canceledOrComplete) } func (p *provisionerTestServer) Apply(s *provisionersdk.Session, r *sdkproto.ApplyRequest, canceledOrComplete <-chan struct{}) *sdkproto.ApplyComplete { return p.apply(s, r, canceledOrComplete) } // Fulfills the protobuf interface for a ProvisionerDaemon with // passable functions for dynamic functionality. type provisionerDaemonTestServer struct { acquireJobWithCancel func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error commitQuota func(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) updateJob func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) failJob func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) completeJob func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) } func (*provisionerDaemonTestServer) AcquireJob(context.Context, *proto.Empty) (*proto.AcquiredJob, error) { return nil, xerrors.New("deprecated!") } func (p *provisionerDaemonTestServer) AcquireJobWithCancel(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { if p.acquireJobWithCancel != nil { return p.acquireJobWithCancel(stream) } // default behavior is to wait for cancel _, _ = stream.Recv() _ = stream.Send(&proto.AcquiredJob{}) return nil } func (p *provisionerDaemonTestServer) CommitQuota(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) { if p.commitQuota == nil { return &proto.CommitQuotaResponse{ Ok: true, }, nil } return p.commitQuota(ctx, com) } func (p *provisionerDaemonTestServer) UpdateJob(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) { return p.updateJob(ctx, update) } func (p *provisionerDaemonTestServer) FailJob(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) { return p.failJob(ctx, job) } func (p *provisionerDaemonTestServer) CompleteJob(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) { return p.completeJob(ctx, job) } // acquireOne provides a function that returns a single provisioner job, then subsequent calls block until canceled. // The complete channel is closed on the 2nd call. type acquireOne struct { t *testing.T mu sync.Mutex job *proto.AcquiredJob called int complete chan struct{} } func newAcquireOne(t *testing.T, job *proto.AcquiredJob) *acquireOne { return &acquireOne{ t: t, job: job, complete: make(chan struct{}), } } func (a *acquireOne) acquireWithCancel(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error { a.mu.Lock() defer a.mu.Unlock() a.called++ if a.called == 2 { close(a.complete) } if a.called > 1 { _, _ = stream.Recv() _ = stream.Send(&proto.AcquiredJob{}) return nil } err := stream.Send(a.job) assert.NoError(a.t, err) return nil }