// Package provisionerdserver_test exercises the provisionerd DRPC server:
// job acquisition (deprecated AcquireJob and AcquireJobWithCancel), job
// updates, failures, and completion, against an in-memory database.
package provisionerdserver_test

import (
	"context"
	"database/sql"
	"encoding/json"
	"io"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"golang.org/x/xerrors"
	"storj.io/drpc"

	"cdr.dev/slog"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.opentelemetry.io/otel/trace"
	"golang.org/x/oauth2"

	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/v2/cli/clibase"
	"github.com/coder/coder/v2/coderd/audit"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbgen"
	"github.com/coder/coder/v2/coderd/database/dbmem"
	"github.com/coder/coder/v2/coderd/database/dbtime"
	"github.com/coder/coder/v2/coderd/database/pubsub"
	"github.com/coder/coder/v2/coderd/externalauth"
	"github.com/coder/coder/v2/coderd/provisionerdserver"
	"github.com/coder/coder/v2/coderd/schedule"
	"github.com/coder/coder/v2/coderd/schedule/cron"
	"github.com/coder/coder/v2/coderd/telemetry"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/provisionerd/proto"
	"github.com/coder/coder/v2/provisionersdk"
	sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
	"github.com/coder/coder/v2/testutil"
)

// mockAuditor returns an atomic pointer preloaded with a mock auditor,
// matching the *atomic.Pointer[audit.Auditor] shape the server consumes.
func mockAuditor() *atomic.Pointer[audit.Auditor] {
	ptr := &atomic.Pointer[audit.Auditor]{}
	mock := audit.Auditor(audit.NewMock())
	ptr.Store(&mock)
	return ptr
}

// testTemplateScheduleStore returns an atomic pointer preloaded with the
// AGPL template schedule store.
func testTemplateScheduleStore() *atomic.Pointer[schedule.TemplateScheduleStore] {
	ptr := &atomic.Pointer[schedule.TemplateScheduleStore]{}
	store := schedule.NewAGPLTemplateScheduleStore()
	ptr.Store(&store)
	return ptr
}

// testUserQuietHoursScheduleStore returns an atomic pointer preloaded with
// the AGPL user quiet-hours schedule store.
func testUserQuietHoursScheduleStore() *atomic.Pointer[schedule.UserQuietHoursScheduleStore] {
	ptr := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{}
	store := schedule.NewAGPLUserQuietHoursScheduleStore()
	ptr.Store(&store)
	return ptr
}

// TestAcquireJob_LongPoll asserts that the deprecated long-polling AcquireJob
// returns an empty AcquiredJob (not an error) when no job appears within the
// poll duration.
func TestAcquireJob_LongPoll(t *testing.T) {
	t.Parallel()
	// A microsecond poll duration forces the "no job available" path.
	srv, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond})
	job, err := srv.AcquireJob(context.Background(), nil)
	require.NoError(t, err)
	require.Equal(t, &proto.AcquiredJob{}, job)
}

// TestAcquireJobWithCancel_Cancel asserts that canceling the streaming
// variant returns cleanly with an empty job (empty JobId) rather than an
// error.
func TestAcquireJobWithCancel_Cancel(t *testing.T) {
	t.Parallel()
	srv, _, _ := setup(t, false, nil)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	fs := newFakeStream(ctx)
	errCh := make(chan error)
	go func() {
		errCh <- srv.AcquireJobWithCancel(fs)
	}()
	// Cancel the stream while the server is waiting for a job.
	fs.cancel()
	select {
	case <-ctx.Done():
		t.Fatal("timed out waiting for AcquireJobWithCancel")
	case err := <-errCh:
		require.NoError(t, err)
	}
	job, err := fs.waitForJob()
	require.NoError(t, err)
	require.NotNil(t, job)
	// An empty JobId signals "no job acquired" on cancellation.
	require.Equal(t, "", job.JobId)
}

// TestAcquireJob covers acquiring each provisioner job type through both
// acquisition APIs.
func TestAcquireJob(t *testing.T) {
	t.Parallel()
	// These test acquiring a single job without canceling, and tests both AcquireJob (deprecated) and
	// AcquireJobWithCancel as the way to get the job.
	cases := []struct {
		name    string
		acquire func(context.Context, proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error)
	}{
		{name: "Deprecated", acquire: func(ctx context.Context, srv proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error) {
			return srv.AcquireJob(ctx, nil)
		}},
		{name: "WithCancel", acquire: func(ctx context.Context, srv proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error) {
			fs := newFakeStream(ctx)
			err := srv.AcquireJobWithCancel(fs)
			if err != nil {
				return nil, err
			}
			return fs.waitForJob()
		}},
	}
	for _, tc := range cases {
		tc := tc
		// A pending job whose initiator does not exist surfaces the
		// underlying sql.ErrNoRows to the acquirer.
		t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) {
			t.Parallel()
			srv, db, _ := setup(t, false, nil)
			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
			defer cancel()
			_, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
				ID:            uuid.New(),
				InitiatorID:   uuid.New(), // random UUID: no such user row
				Provisioner:   database.ProvisionerTypeEcho,
				StorageMethod: database.ProvisionerStorageMethodFile,
				Type:          database.ProvisionerJobTypeTemplateVersionDryRun,
			})
			require.NoError(t, err)
			_, err = tc.acquire(ctx, srv)
			require.ErrorContains(t, err, "sql: no rows in result set")
		})
		// Full workspace-build acquisition: verifies variable values,
		// external auth tokens, metadata, and session token lifecycle
		// (created on start, deleted on stop).
		t.Run(tc.name+"_WorkspaceBuildJob", func(t *testing.T) {
			t.Parallel()
			// Set the max session token lifetime so we can assert we
			// create an API key with an expiration within the bounds of the
			// deployment config.
			dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)}
			gitAuthProvider := "github"
			srv, db, ps := setup(t, false, &overrides{
				deploymentValues: dv,
				externalAuthConfigs: []*externalauth.Config{{
					ID:           gitAuthProvider,
					OAuth2Config: &testutil.OAuth2Config{},
				}},
			})
			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
			defer cancel()

			user := dbgen.User(t, db, database.User{})
			link := dbgen.UserLink(t, db, database.UserLink{
				LoginType:        database.LoginTypeOIDC,
				UserID:           user.ID,
				OAuthExpiry:      dbtime.Now().Add(time.Hour),
				OAuthAccessToken: "access-token",
			})
			dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
				ProviderID: gitAuthProvider,
				UserID:     user.ID,
			})
			template := dbgen.Template(t, db, database.Template{
				Name:        "template",
				Provisioner: database.ProvisionerTypeEcho,
			})
			file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
			versionFile := dbgen.File(t, db, database.File{CreatedBy: user.ID})
			version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
				TemplateID: uuid.NullUUID{
					UUID:  template.ID,
					Valid: true,
				},
				JobID: uuid.New(),
			})
			err := db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
				JobID:                 version.JobID,
				ExternalAuthProviders: []string{gitAuthProvider},
				UpdatedAt:             dbtime.Now(),
			})
			require.NoError(t, err)
			// Import version job
			_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
				ID:            version.JobID,
				InitiatorID:   user.ID,
				FileID:        versionFile.ID,
				Provisioner:   database.ProvisionerTypeEcho,
				StorageMethod: database.ProvisionerStorageMethodFile,
				Type:          database.ProvisionerJobTypeTemplateVersionImport,
				Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
					TemplateVersionID: version.ID,
					UserVariableValues: []codersdk.VariableValue{
						{Name: "second", Value: "bah"},
					},
				})),
			})
			_ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{
				TemplateVersionID: version.ID,
				Name:              "first",
				Value:             "first_value",
				DefaultValue:      "default_value",
				Sensitive:         true,
			})
			_ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{
				TemplateVersionID: version.ID,
				Name:              "second",
				Value:             "second_value",
				DefaultValue:      "default_value",
				Required:          true,
				Sensitive:         false,
			})
			workspace := dbgen.Workspace(t, db, database.Workspace{
				TemplateID: template.ID,
				OwnerID:    user.ID,
			})
			build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
				WorkspaceID:       workspace.ID,
				BuildNumber:       1,
				JobID:             uuid.New(),
				TemplateVersionID: version.ID,
				Transition:        database.WorkspaceTransitionStart,
				Reason:            database.BuildReasonInitiator,
			})
			_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
				ID:            build.ID,
				InitiatorID:   user.ID,
				Provisioner:   database.ProvisionerTypeEcho,
				StorageMethod: database.ProvisionerStorageMethodFile,
				FileID:        file.ID,
				Type:          database.ProvisionerJobTypeWorkspaceBuild,
				Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
					WorkspaceBuildID: build.ID,
				})),
			})

			startPublished := make(chan struct{})
			var closed bool
			// Guard against double-close: the channel may receive more
			// than one notification for this workspace.
			closeStartSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
				if !closed {
					close(startPublished)
					closed = true
				}
			})
			require.NoError(t, err)
			defer closeStartSubscribe()

			var job *proto.AcquiredJob
			for {
				// Grab jobs until we find the workspace build job. There is also
				// an import version job that we need to ignore.
				job, err = tc.acquire(ctx, srv)
				require.NoError(t, err)
				if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok {
					break
				}
			}
			<-startPublished
			got, err := json.Marshal(job.Type)
			require.NoError(t, err)

			// Validate that a session token is generated during the job.
			sessionToken := job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
			require.NotEmpty(t, sessionToken)
			// API keys have the form "<id>-<secret>".
			toks := strings.Split(sessionToken, "-")
			require.Len(t, toks, 2, "invalid api key")
			key, err := db.GetAPIKeyByID(ctx, toks[0])
			require.NoError(t, err)
			require.Equal(t, int64(dv.MaxTokenLifetime.Value().Seconds()), key.LifetimeSeconds)
			require.WithinDuration(t, time.Now().Add(dv.MaxTokenLifetime.Value()), key.ExpiresAt, time.Minute)

			want, err := json.Marshal(&proto.AcquiredJob_WorkspaceBuild_{
				WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
					WorkspaceBuildId: build.ID.String(),
					WorkspaceName:    workspace.Name,
					VariableValues: []*sdkproto.VariableValue{
						{
							Name:      "first",
							Value:     "first_value",
							Sensitive: true,
						},
						{
							Name:  "second",
							Value: "second_value",
						},
					},
					// "access_token" is what testutil.OAuth2Config hands
					// back after refresh, distinct from the stored
					// "access-token" OIDC link above.
					ExternalAuthProviders: []*sdkproto.ExternalAuthProvider{{
						Id:          gitAuthProvider,
						AccessToken: "access_token",
					}},
					Metadata: &sdkproto.Metadata{
						CoderUrl:                      (&url.URL{}).String(),
						WorkspaceTransition:           sdkproto.WorkspaceTransition_START,
						WorkspaceName:                 workspace.Name,
						WorkspaceOwner:                user.Username,
						WorkspaceOwnerEmail:           user.Email,
						WorkspaceOwnerOidcAccessToken: link.OAuthAccessToken,
						WorkspaceId:                   workspace.ID.String(),
						WorkspaceOwnerId:              user.ID.String(),
						TemplateId:                    template.ID.String(),
						TemplateName:                  template.Name,
						TemplateVersion:               version.Name,
						WorkspaceOwnerSessionToken:    sessionToken,
					},
				},
			})
			require.NoError(t, err)
			require.JSONEq(t, string(want), string(got))

			// Assert that we delete the session token whenever
			// a stop is issued.
			stopbuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
				WorkspaceID:       workspace.ID,
				BuildNumber:       2,
				JobID:             uuid.New(),
				TemplateVersionID: version.ID,
				Transition:        database.WorkspaceTransitionStop,
				Reason:            database.BuildReasonInitiator,
			})
			_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
				ID:            stopbuild.ID,
				InitiatorID:   user.ID,
				Provisioner:   database.ProvisionerTypeEcho,
				StorageMethod: database.ProvisionerStorageMethodFile,
				FileID:        file.ID,
				Type:          database.ProvisionerJobTypeWorkspaceBuild,
				Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
					WorkspaceBuildID: stopbuild.ID,
				})),
			})

			stopPublished := make(chan struct{})
			closeStopSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
				close(stopPublished)
			})
			require.NoError(t, err)
			defer closeStopSubscribe()

			// Grab jobs until we find the workspace build job. There is also
			// an import version job that we need to ignore.
			job, err = tc.acquire(ctx, srv)
			require.NoError(t, err)
			_, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_)
			require.True(t, ok, "acquired job not a workspace build?")
			<-stopPublished

			// Validate that a session token is deleted during a stop job.
			sessionToken = job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
			require.Empty(t, sessionToken)
			_, err = db.GetAPIKeyByID(ctx, key.ID)
			require.ErrorIs(t, err, sql.ErrNoRows)
		})
		// Dry-run jobs carry the target workspace name in their metadata.
		t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) {
			t.Parallel()
			srv, db, ps := setup(t, false, nil)
			ctx := context.Background()

			user := dbgen.User(t, db, database.User{})
			version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
			file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
			_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
				InitiatorID:   user.ID,
				Provisioner:   database.ProvisionerTypeEcho,
				StorageMethod: database.ProvisionerStorageMethodFile,
				FileID:        file.ID,
				Type:          database.ProvisionerJobTypeTemplateVersionDryRun,
				Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{
					TemplateVersionID: version.ID,
					WorkspaceName:     "testing",
				})),
			})

			job, err := tc.acquire(ctx, srv)
			require.NoError(t, err)

			got, err := json.Marshal(job.Type)
			require.NoError(t, err)

			want, err := json.Marshal(&proto.AcquiredJob_TemplateDryRun_{
				TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{
					Metadata: &sdkproto.Metadata{
						CoderUrl:      (&url.URL{}).String(),
						WorkspaceName: "testing",
					},
				},
			})
			require.NoError(t, err)
			require.JSONEq(t, string(want), string(got))
		})
		// Import job with no user variables: metadata only.
		t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) {
			t.Parallel()
			srv, db, ps := setup(t, false, nil)
			ctx := context.Background()

			user := dbgen.User(t, db, database.User{})
			file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
			_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
				FileID:        file.ID,
				InitiatorID:   user.ID,
				Provisioner:   database.ProvisionerTypeEcho,
				StorageMethod: database.ProvisionerStorageMethodFile,
				Type:          database.ProvisionerJobTypeTemplateVersionImport,
			})

			job, err := tc.acquire(ctx, srv)
			require.NoError(t, err)

			got, err := json.Marshal(job.Type)
			require.NoError(t, err)

			want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{
				TemplateImport: &proto.AcquiredJob_TemplateImport{
					Metadata: &sdkproto.Metadata{
						CoderUrl: (&url.URL{}).String(),
					},
				},
			})
			require.NoError(t, err)
			require.JSONEq(t, string(want), string(got))
		})
		// Import job with user-supplied variables: expects them back,
		// marked sensitive.
		t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) {
			t.Parallel()
			srv, db, ps := setup(t, false, nil)

			user := dbgen.User(t, db, database.User{})
			version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
			file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
			_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
				FileID:        file.ID,
				InitiatorID:   user.ID,
				Provisioner:   database.ProvisionerTypeEcho,
				StorageMethod: database.ProvisionerStorageMethodFile,
				Type:          database.ProvisionerJobTypeTemplateVersionImport,
				Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
					TemplateVersionID: version.ID,
					UserVariableValues: []codersdk.VariableValue{
						{Name: "first", Value: "first_value"},
					},
				})),
			})

			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
			defer cancel()

			job, err := tc.acquire(ctx, srv)
			require.NoError(t, err)

			got, err := json.Marshal(job.Type)
			require.NoError(t, err)

			want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{
				TemplateImport: &proto.AcquiredJob_TemplateImport{
					UserVariableValues: []*sdkproto.VariableValue{
						{Name: "first", Sensitive: true, Value: "first_value"},
					},
					Metadata: &sdkproto.Metadata{
						CoderUrl: (&url.URL{}).String(),
					},
				},
			})
			require.NoError(t, err)
			require.JSONEq(t, string(want), string(got))
		})
	}
}

// TestUpdateJob covers validation, ownership, and the side effects of
// UpdateJob: log publishing, readme persistence, and template variables.
func TestUpdateJob(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		srv, _, _ := setup(t, false, nil)
		_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
			JobId: "hello",
		})
		require.ErrorContains(t, err, "invalid UUID")

		_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
			JobId: uuid.NewString(),
		})
		require.ErrorContains(t, err, "no rows in result set")
	})
	t.Run("NotRunning", func(t *testing.T) {
		t.Parallel()
		srv, db, _ := setup(t, false, nil)
		// Inserted but never acquired: no worker owns it yet.
		job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
			ID:            uuid.New(),
			Provisioner:   database.ProvisionerTypeEcho,
			StorageMethod: database.ProvisionerStorageMethodFile,
			Type:          database.ProvisionerJobTypeTemplateVersionDryRun,
		})
		require.NoError(t, err)
		_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
			JobId: job.ID.String(),
		})
		require.ErrorContains(t, err, "job isn't running yet")
	})
	// This test prevents runners from updating jobs they don't own!
	t.Run("NotOwner", func(t *testing.T) {
		t.Parallel()
		srv, db, _ := setup(t, false, nil)
		job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
			ID:            uuid.New(),
			Provisioner:   database.ProvisionerTypeEcho,
			StorageMethod: database.ProvisionerStorageMethodFile,
			Type:          database.ProvisionerJobTypeTemplateVersionDryRun,
		})
		require.NoError(t, err)
		// Acquire with a worker ID that is NOT the server under test.
		_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
			WorkerID: uuid.NullUUID{
				UUID:  uuid.New(),
				Valid: true,
			},
			Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
		})
		require.NoError(t, err)
		_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
			JobId: job.ID.String(),
		})
		require.ErrorContains(t, err, "you don't own this job")
	})
	// setupJob inserts a job and acquires it as worker srvID, returning the
	// job ID so the srvID-owned server can update it.
	setupJob := func(t *testing.T, db database.Store, srvID uuid.UUID) uuid.UUID {
		job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
			ID:            uuid.New(),
			Provisioner:   database.ProvisionerTypeEcho,
			Type:          database.ProvisionerJobTypeTemplateVersionImport,
			StorageMethod: database.ProvisionerStorageMethodFile,
		})
		require.NoError(t, err)
		_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
			WorkerID: uuid.NullUUID{
				UUID:  srvID,
				Valid: true,
			},
			Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
		})
		require.NoError(t, err)
		return job.ID
	}
	t.Run("Success", func(t *testing.T) {
		t.Parallel()
		srvID := uuid.New()
		srv, db, _ := setup(t, false, &overrides{id: &srvID})
		job := setupJob(t, db, srvID)
		_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
			JobId: job.String(),
		})
		require.NoError(t, err)
	})
	t.Run("Logs", func(t *testing.T) {
		t.Parallel()
		srvID := uuid.New()
		srv, db, ps := setup(t, false, &overrides{id: &srvID})
		job := setupJob(t, db, srvID)

		// Logs sent in an update must be published to the job's log channel.
		published := make(chan struct{})
		closeListener, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) {
			close(published)
		})
		require.NoError(t, err)
		defer closeListener()

		_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
			JobId: job.String(),
			Logs: []*proto.Log{{
				Source: proto.LogSource_PROVISIONER,
				Level:  sdkproto.LogLevel_INFO,
				Output: "hi",
			}},
		})
		require.NoError(t, err)

		<-published
	})
	t.Run("Readme", func(t *testing.T) {
		t.Parallel()
		srvID := uuid.New()
		srv, db, _ := setup(t, false, &overrides{id: &srvID})
		job := setupJob(t, db, srvID)
		versionID := uuid.New()
		err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
			ID:    versionID,
			JobID: job,
		})
		require.NoError(t, err)
		_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
			JobId:  job.String(),
			Readme: []byte("# hello world"),
		})
		require.NoError(t, err)

		// The readme is persisted on the template version, not the job.
		version, err := db.GetTemplateVersionByID(ctx, versionID)
		require.NoError(t, err)
		require.Equal(t, "# hello world", version.Readme)
	})
	t.Run("TemplateVariables", func(t *testing.T) {
		t.Parallel()
		t.Run("Valid", func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
			defer cancel()

			srvID := uuid.New()
			srv, db, _ := setup(t, false, &overrides{id: &srvID})
			job := setupJob(t, db, srvID)
			versionID := uuid.New()
			err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
				ID:    versionID,
				JobID: job,
			})
			require.NoError(t, err)
			firstTemplateVariable := &sdkproto.TemplateVariable{
				Name:         "first",
				Type:         "string",
				DefaultValue: "default_value",
				Sensitive:    true,
			}
			secondTemplateVariable := &sdkproto.TemplateVariable{
				Name:      "second",
				Type:      "string",
				Required:  true,
				Sensitive: true,
			}
			// "first" falls back to its default; "second" is satisfied by
			// the user-supplied value.
			response, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
				JobId: job.String(),
				TemplateVariables: []*sdkproto.TemplateVariable{
					firstTemplateVariable,
					secondTemplateVariable,
				},
				UserVariableValues: []*sdkproto.VariableValue{
					{
						Name:  "second",
						Value: "foobar",
					},
				},
			})
			require.NoError(t, err)
			require.Len(t, response.VariableValues, 2)

			templateVariables, err := db.GetTemplateVersionVariables(ctx, versionID)
			require.NoError(t, err)
			require.Len(t, templateVariables, 2)
			require.Equal(t, templateVariables[0].Value, firstTemplateVariable.DefaultValue)
			require.Equal(t, templateVariables[1].Value, "foobar")
		})
		t.Run("Missing required value", func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
			defer cancel()

			srvID := uuid.New()
			srv, db, _ := setup(t, false, &overrides{id: &srvID})
			job := setupJob(t, db, srvID)
			versionID := uuid.New()
			err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
				ID:    versionID,
				JobID: job,
			})
			require.NoError(t, err)
			firstTemplateVariable := &sdkproto.TemplateVariable{
				Name:         "first",
				Type:         "string",
				DefaultValue: "default_value",
				Sensitive:    true,
			}
			secondTemplateVariable := &sdkproto.TemplateVariable{
				Name:      "second",
				Type:      "string",
				Required:  true,
				Sensitive: true,
			}
			// "second" is required but no value is supplied.
			response, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
				JobId: job.String(),
				TemplateVariables: []*sdkproto.TemplateVariable{
					firstTemplateVariable,
					secondTemplateVariable,
				},
			})
			require.Error(t, err) // required template variables need values
			require.Nil(t, response)

			// Even though there is an error returned, variables are stored in the database
			// to show the schema in the site UI.
			templateVariables, err := db.GetTemplateVersionVariables(ctx, versionID)
			require.NoError(t, err)
			require.Len(t, templateVariables, 2)
			require.Equal(t, templateVariables[0].Value, firstTemplateVariable.DefaultValue)
			require.Equal(t, templateVariables[1].Value, "")
		})
	})
}

// TestFailJob covers validation, ownership, double-completion, and the
// workspace-build failure path (state persistence + pubsub notifications).
func TestFailJob(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		srv, _, _ := setup(t, false, nil)
		_, err := srv.FailJob(ctx, &proto.FailedJob{
			JobId: "hello",
		})
		require.ErrorContains(t, err, "invalid UUID")

		// NOTE(review): this second call exercises UpdateJob, not FailJob —
		// looks like a copy-paste slip; consider srv.FailJob here so the
		// FailJob not-found path is actually covered.
		_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
			JobId: uuid.NewString(),
		})
		require.ErrorContains(t, err, "no rows in result set")
	})
	// This test prevents runners from updating jobs they don't own!
	t.Run("NotOwner", func(t *testing.T) {
		t.Parallel()
		srv, db, _ := setup(t, false, nil)
		job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
			ID:            uuid.New(),
			Provisioner:   database.ProvisionerTypeEcho,
			StorageMethod: database.ProvisionerStorageMethodFile,
			Type:          database.ProvisionerJobTypeTemplateVersionImport,
		})
		require.NoError(t, err)
		// Acquired by a different worker than the server under test.
		_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
			WorkerID: uuid.NullUUID{
				UUID:  uuid.New(),
				Valid: true,
			},
			Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
		})
		require.NoError(t, err)
		_, err = srv.FailJob(ctx, &proto.FailedJob{
			JobId: job.ID.String(),
		})
		require.ErrorContains(t, err, "you don't own this job")
	})
	t.Run("AlreadyCompleted", func(t *testing.T) {
		t.Parallel()
		srvID := uuid.New()
		srv, db, _ := setup(t, false, &overrides{id: &srvID})
		job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
			ID:            uuid.New(),
			Provisioner:   database.ProvisionerTypeEcho,
			Type:          database.ProvisionerJobTypeTemplateVersionImport,
			StorageMethod: database.ProvisionerStorageMethodFile,
		})
		require.NoError(t, err)
		_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
			WorkerID: uuid.NullUUID{
				UUID:  srvID,
				Valid: true,
			},
			Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
		})
		require.NoError(t, err)
		// Mark complete before failing: FailJob must refuse.
		err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
			ID: job.ID,
			CompletedAt: sql.NullTime{
				Time:  dbtime.Now(),
				Valid: true,
			},
		})
		require.NoError(t, err)
		_, err = srv.FailJob(ctx, &proto.FailedJob{
			JobId: job.ID.String(),
		})
		require.ErrorContains(t, err, "job already completed")
	})
	t.Run("WorkspaceBuild", func(t *testing.T) {
		t.Parallel()
		// Ignore log errors because we get:
		//
		//	(*Server).FailJob     audit log - get build {"error": "sql: no rows in result set"}
		ignoreLogErrors := true
		srvID := uuid.New()
		srv, db, ps := setup(t, ignoreLogErrors, &overrides{id: &srvID})
		workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{
			ID:               uuid.New(),
			AutomaticUpdates: database.AutomaticUpdatesNever,
		})
		require.NoError(t, err)
		buildID := uuid.New()
		err = db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
			ID:          buildID,
			WorkspaceID: workspace.ID,
			Transition:  database.WorkspaceTransitionStart,
			Reason:      database.BuildReasonInitiator,
		})
		require.NoError(t, err)
		input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
			WorkspaceBuildID: buildID,
		})
		require.NoError(t, err)
		job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
			ID:            uuid.New(),
			Input:         input,
			Provisioner:   database.ProvisionerTypeEcho,
			Type:          database.ProvisionerJobTypeWorkspaceBuild,
			StorageMethod: database.ProvisionerStorageMethodFile,
		})
		require.NoError(t, err)
		_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
			WorkerID: uuid.NullUUID{
				UUID:  srvID,
				Valid: true,
			},
			Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
		})
		require.NoError(t, err)

		// Failing a build must notify both the workspace channel and the
		// job logs channel.
		publishedWorkspace := make(chan struct{})
		closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
			close(publishedWorkspace)
		})
		require.NoError(t, err)
		defer closeWorkspaceSubscribe()
		publishedLogs := make(chan struct{})
		closeLogsSubscribe, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
			close(publishedLogs)
		})
		require.NoError(t, err)
		defer closeLogsSubscribe()

		_, err = srv.FailJob(ctx, &proto.FailedJob{
			JobId: job.ID.String(),
			Type: &proto.FailedJob_WorkspaceBuild_{
				WorkspaceBuild: &proto.FailedJob_WorkspaceBuild{
					State: []byte("some state"),
				},
			},
		})
		require.NoError(t, err)
		<-publishedWorkspace
		<-publishedLogs
		// Provisioner state supplied on failure is persisted on the build.
		build, err := db.GetWorkspaceBuildByID(ctx, buildID)
		require.NoError(t, err)
		require.Equal(t, "some state", string(build.ProvisionerState))
	})
}

// TestCompleteJob covers validation, ownership, and the completion paths for
// template imports and workspace builds.
func TestCompleteJob(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		srv, _, _ := setup(t, false, nil)
		_, err := srv.CompleteJob(ctx, &proto.CompletedJob{
			JobId: "hello",
		})
		require.ErrorContains(t, err, "invalid UUID")

		_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
			JobId: uuid.NewString(),
		})
		require.ErrorContains(t, err, "no rows in result set")
	})
	// This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) { t.Parallel() srv, db, _ := setup(t, false, nil) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, StorageMethod: database.ProvisionerStorageMethodFile, Type: database.ProvisionerJobTypeWorkspaceBuild, }) require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ UUID: uuid.New(), Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, }) require.NoError(t, err) _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ JobId: job.ID.String(), }) require.ErrorContains(t, err, "you don't own this job") }) t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) { t.Parallel() srvID := uuid.New() srv, db, _ := setup(t, false, &overrides{id: &srvID}) jobID := uuid.New() versionID := uuid.New() err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, JobID: jobID, }) require.NoError(t, err) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: jobID, Provisioner: database.ProvisionerTypeEcho, Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), StorageMethod: database.ProvisionerStorageMethodFile, Type: database.ProvisionerJobTypeWorkspaceBuild, }) require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, }) require.NoError(t, err) completeJob := func() { _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ JobId: job.ID.String(), Type: &proto.CompletedJob_TemplateImport_{ TemplateImport: &proto.CompletedJob_TemplateImport{ StartResources: []*sdkproto.Resource{{ Name: "hello", Type: "aws_instance", }}, StopResources: []*sdkproto.Resource{}, ExternalAuthProviders: []string{"github"}, }, }, }) require.NoError(t, err) } 
completeJob() job, err = db.GetProvisionerJobByID(ctx, job.ID) require.NoError(t, err) require.Contains(t, job.Error.String, `external auth provider "github" is not configured`) }) t.Run("TemplateImport_WithGitAuth", func(t *testing.T) { t.Parallel() srvID := uuid.New() srv, db, _ := setup(t, false, &overrides{ id: &srvID, externalAuthConfigs: []*externalauth.Config{{ ID: "github", }}, }) jobID := uuid.New() versionID := uuid.New() err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, JobID: jobID, }) require.NoError(t, err) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: jobID, Provisioner: database.ProvisionerTypeEcho, Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), StorageMethod: database.ProvisionerStorageMethodFile, Type: database.ProvisionerJobTypeWorkspaceBuild, }) require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ UUID: srvID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, }) require.NoError(t, err) completeJob := func() { _, err = srv.CompleteJob(ctx, &proto.CompletedJob{ JobId: job.ID.String(), Type: &proto.CompletedJob_TemplateImport_{ TemplateImport: &proto.CompletedJob_TemplateImport{ StartResources: []*sdkproto.Resource{{ Name: "hello", Type: "aws_instance", }}, StopResources: []*sdkproto.Resource{}, ExternalAuthProviders: []string{"github"}, }, }, }) require.NoError(t, err) } completeJob() job, err = db.GetProvisionerJobByID(ctx, job.ID) require.NoError(t, err) require.False(t, job.Error.Valid) }) // TODO(@dean): remove this legacy test for MaxTTL t.Run("WorkspaceBuildLegacy", func(t *testing.T) { t.Parallel() cases := []struct { name string templateAllowAutostop bool templateDefaultTTL time.Duration templateMaxTTL time.Duration workspaceTTL time.Duration transition database.WorkspaceTransition // The TTL is actually a deadline time on the 
workspace_build row,
// so during the test this will be compared to be within 15 seconds
// of the expected value.
expectedTTL    time.Duration
expectedMaxTTL time.Duration
}{
	{
		name:                  "OK",
		templateAllowAutostop: true,
		templateDefaultTTL:    0,
		templateMaxTTL:        0,
		workspaceTTL:          0,
		transition:            database.WorkspaceTransitionStart,
		expectedTTL:           0,
		expectedMaxTTL:        0,
	},
	{
		name:                  "Delete",
		templateAllowAutostop: true,
		templateDefaultTTL:    0,
		templateMaxTTL:        0,
		workspaceTTL:          0,
		transition:            database.WorkspaceTransitionDelete,
		expectedTTL:           0,
		expectedMaxTTL:        0,
	},
	{
		name:                  "WorkspaceTTL",
		templateAllowAutostop: true,
		templateDefaultTTL:    0,
		templateMaxTTL:        0,
		workspaceTTL:          time.Hour,
		transition:            database.WorkspaceTransitionStart,
		expectedTTL:           time.Hour,
		expectedMaxTTL:        0,
	},
	{
		name:                  "TemplateDefaultTTLIgnored",
		templateAllowAutostop: true,
		templateDefaultTTL:    time.Hour,
		templateMaxTTL:        0,
		workspaceTTL:          0,
		transition:            database.WorkspaceTransitionStart,
		expectedTTL:           0,
		expectedMaxTTL:        0,
	},
	{
		name:                  "WorkspaceTTLOverridesTemplateDefaultTTL",
		templateAllowAutostop: true,
		templateDefaultTTL:    2 * time.Hour,
		templateMaxTTL:        0,
		workspaceTTL:          time.Hour,
		transition:            database.WorkspaceTransitionStart,
		expectedTTL:           time.Hour,
		expectedMaxTTL:        0,
	},
	{
		name:                  "TemplateMaxTTL",
		templateAllowAutostop: true,
		templateDefaultTTL:    0,
		templateMaxTTL:        time.Hour,
		workspaceTTL:          0,
		transition:            database.WorkspaceTransitionStart,
		expectedTTL:           time.Hour,
		expectedMaxTTL:        time.Hour,
	},
	{
		name:                  "TemplateMaxTTLOverridesWorkspaceTTL",
		templateAllowAutostop: true,
		templateDefaultTTL:    0,
		templateMaxTTL:        2 * time.Hour,
		workspaceTTL:          3 * time.Hour,
		transition:            database.WorkspaceTransitionStart,
		expectedTTL:           2 * time.Hour,
		expectedMaxTTL:        2 * time.Hour,
	},
	{
		name:                  "TemplateMaxTTLOverridesTemplateDefaultTTL",
		templateAllowAutostop: true,
		templateDefaultTTL:    3 * time.Hour,
		templateMaxTTL:        2 * time.Hour,
		workspaceTTL:          0,
		transition:            database.WorkspaceTransitionStart,
		expectedTTL:           2 * time.Hour,
		expectedMaxTTL:        2 * time.Hour,
	},
	{
		name:                  "TemplateBlockWorkspaceTTL",
		templateAllowAutostop: false,
		templateDefaultTTL:    3 * time.Hour,
		templateMaxTTL:        6 * time.Hour,
		workspaceTTL:          4 * time.Hour,
		transition:            database.WorkspaceTransitionStart,
		expectedTTL:           3 * time.Hour,
		expectedMaxTTL:        6 * time.Hour,
	},
}
// Each case builds a real workspace build + provisioner job, completes the
// job through the server, then asserts the deadline/max_deadline written to
// the workspace_build row.
for _, c := range cases {
	c := c
	t.Run(c.name, func(t *testing.T) {
		t.Parallel()

		srvID := uuid.New()
		tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
		srv, db, ps := setup(t, false, &overrides{id: &srvID, templateScheduleStore: tss})

		// Install a mock template schedule store that reflects the case's
		// TTL knobs; it is stored after setup so setup's default is replaced.
		var store schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
			GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) {
				return schedule.TemplateScheduleOptions{
					UserAutostartEnabled:   false,
					UserAutostopEnabled:    c.templateAllowAutostop,
					DefaultTTL:             c.templateDefaultTTL,
					MaxTTL:                 c.templateMaxTTL,
					UseAutostopRequirement: false,
				}, nil
			},
		}
		tss.Store(&store)

		// Seed a user, template, and matching schedule row. Note: ctx is
		// captured from the enclosing test scope, not declared here.
		user := dbgen.User(t, db, database.User{})
		template := dbgen.Template(t, db, database.Template{
			Name:        "template",
			Provisioner: database.ProvisionerTypeEcho,
		})
		err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
			ID:        template.ID,
			UpdatedAt: dbtime.Now(),
			// NOTE(review): AllowUserAutostart is populated from
			// c.templateAllowAutostop — looks like it may have been intended
			// for AllowUserAutostop; confirm against the schema.
			AllowUserAutostart: c.templateAllowAutostop,
			DefaultTTL:         int64(c.templateDefaultTTL),
			MaxTTL:             int64(c.templateMaxTTL),
		})
		require.NoError(t, err)
		file := dbgen.File(t, db, database.File{CreatedBy: user.ID})

		// A zero workspace TTL is represented as a NULL column.
		workspaceTTL := sql.NullInt64{}
		if c.workspaceTTL != 0 {
			workspaceTTL = sql.NullInt64{
				Int64: int64(c.workspaceTTL),
				Valid: true,
			}
		}
		workspace := dbgen.Workspace(t, db, database.Workspace{
			TemplateID: template.ID,
			Ttl:        workspaceTTL,
		})
		version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
			TemplateID: uuid.NullUUID{
				UUID:  template.ID,
				Valid: true,
			},
			JobID: uuid.New(),
		})
		build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
			WorkspaceID:       workspace.ID,
			TemplateVersionID: version.ID,
			Transition:        c.transition,
			Reason:            database.BuildReasonInitiator,
		})
		job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
			FileID: file.ID,
			Type:   database.ProvisionerJobTypeWorkspaceBuild,
			Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
				WorkspaceBuildID: build.ID,
			})),
		})
		// The server only completes jobs acquired by its own worker ID.
		_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
			WorkerID: uuid.NullUUID{
				UUID:  srvID,
				Valid: true,
			},
			Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
		})
		require.NoError(t, err)

		// Subscribe to the workspace and job-log pubsub channels so the test
		// can block until CompleteJob has published both notifications.
		publishedWorkspace := make(chan struct{})
		closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) {
			close(publishedWorkspace)
		})
		require.NoError(t, err)
		defer closeWorkspaceSubscribe()
		publishedLogs := make(chan struct{})
		closeLogsSubscribe, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
			close(publishedLogs)
		})
		require.NoError(t, err)
		defer closeLogsSubscribe()

		_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
			JobId: job.ID.String(),
			Type: &proto.CompletedJob_WorkspaceBuild_{
				WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{
					State: []byte{},
					Resources: []*sdkproto.Resource{{
						Name: "example",
						Type: "aws_instance",
					}},
				},
			},
		})
		require.NoError(t, err)
		<-publishedWorkspace
		<-publishedLogs

		workspace, err = db.GetWorkspaceByID(ctx, workspace.ID)
		require.NoError(t, err)
		// A delete transition must mark the workspace deleted.
		require.Equal(t, c.transition == database.WorkspaceTransitionDelete, workspace.Deleted)

		workspaceBuild, err := db.GetWorkspaceBuildByID(ctx, build.ID)
		require.NoError(t, err)
		// Deadlines are compared with a generous 15s window because they are
		// computed relative to "now" inside the server.
		if c.expectedTTL == 0 {
			require.True(t, workspaceBuild.Deadline.IsZero())
		} else {
			require.WithinDuration(t, time.Now().Add(c.expectedTTL), workspaceBuild.Deadline, 15*time.Second, "deadline does not match expected")
		}
		if c.expectedMaxTTL == 0 {
			require.True(t, workspaceBuild.MaxDeadline.IsZero())
		} else {
			require.WithinDuration(t, time.Now().Add(c.expectedMaxTTL), workspaceBuild.MaxDeadline, 15*time.Second, "max deadline does not match expected")
			require.GreaterOrEqual(t, workspaceBuild.MaxDeadline.Unix(), workspaceBuild.Deadline.Unix(), "max deadline is smaller than deadline")
		}
	})
}
})

// WorkspaceBuild exercises deadline/max_deadline derived from template
// autostop requirements and user quiet-hours schedules.
t.Run("WorkspaceBuild", func(t *testing.T) {
	t.Parallel()

	now := time.Now()

	// NOTE: if you're looking for more in-depth deadline/max_deadline
	// calculation testing, see the schedule package. The provisionerdserver
	// package calls `schedule.CalculateAutostop()` to generate the deadline
	// and max_deadline.

	// Wednesday the 8th of February 2023 at midnight. This date was
	// specifically chosen as it doesn't fall on an applicable week for both
	// fortnightly and triweekly autostop requirements.
	wednesdayMidnightUTC := time.Date(2023, 2, 8, 0, 0, 0, 0, time.UTC)

	sydneyQuietHours := "CRON_TZ=Australia/Sydney 0 0 * * *"
	sydneyLoc, err := time.LoadLocation("Australia/Sydney")
	require.NoError(t, err)
	// 12am on Saturday the 11th of February 2023 in Sydney.
	saturdayMidnightSydney := time.Date(2023, 2, 11, 0, 0, 0, 0, sydneyLoc)

	t.Log("now", now)
	t.Log("wednesdayMidnightUTC", wednesdayMidnightUTC)
	t.Log("saturdayMidnightSydney", saturdayMidnightSydney)

	cases := []struct {
		name         string
		now          time.Time
		workspaceTTL time.Duration
		transition   database.WorkspaceTransition

		// These fields are only used when testing max deadline.
		userQuietHoursSchedule      string
		templateAutostopRequirement schedule.TemplateAutostopRequirement

		expectedDeadline    time.Time
		expectedMaxDeadline time.Time
	}{
		{
			name:                        "OK",
			now:                         now,
			templateAutostopRequirement: schedule.TemplateAutostopRequirement{},
			workspaceTTL:                0,
			transition:                  database.WorkspaceTransitionStart,
			expectedDeadline:            time.Time{},
			expectedMaxDeadline:         time.Time{},
		},
		{
			name:                        "Delete",
			now:                         now,
			templateAutostopRequirement: schedule.TemplateAutostopRequirement{},
			workspaceTTL:                0,
			transition:                  database.WorkspaceTransitionDelete,
			expectedDeadline:            time.Time{},
			expectedMaxDeadline:         time.Time{},
		},
		{
			name:                        "WorkspaceTTL",
			now:                         now,
			templateAutostopRequirement: schedule.TemplateAutostopRequirement{},
			workspaceTTL:                time.Hour,
			transition:                  database.WorkspaceTransitionStart,
			expectedDeadline:            now.Add(time.Hour),
			expectedMaxDeadline:         time.Time{},
		},
		{
			name:                   "TemplateAutostopRequirement",
			now:                    wednesdayMidnightUTC,
			userQuietHoursSchedule: sydneyQuietHours,
			templateAutostopRequirement: schedule.TemplateAutostopRequirement{
				DaysOfWeek: 0b00100000, // Saturday
				Weeks:      0,          // weekly
			},
			workspaceTTL: 0,
			transition:   database.WorkspaceTransitionStart,
			// expectedDeadline is copied from expectedMaxDeadline.
			expectedMaxDeadline: saturdayMidnightSydney.In(time.UTC),
		},
	}

	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			srvID := uuid.New()

			// Simulate the given time starting from now.
			require.False(t, c.now.IsZero())
			start := time.Now()
			tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
			uqhss := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{}
			srv, db, ps := setup(t, false, &overrides{
				// timeNowFn advances from c.now at real-time rate so the
				// server sees a consistent simulated clock.
				timeNowFn: func() time.Time {
					return c.now.Add(time.Since(start))
				},
				templateScheduleStore:       tss,
				userQuietHoursScheduleStore: uqhss,
				id:                          &srvID,
			})

			var templateScheduleStore schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
				GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) {
					return schedule.TemplateScheduleOptions{
						UserAutostartEnabled:   false,
						UserAutostopEnabled:    true,
						DefaultTTL:             0,
						UseAutostopRequirement: true,
						AutostopRequirement:    c.templateAutostopRequirement,
					}, nil
				},
			}
			tss.Store(&templateScheduleStore)

			var userQuietHoursScheduleStore schedule.UserQuietHoursScheduleStore = schedule.MockUserQuietHoursScheduleStore{
				GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.UserQuietHoursScheduleOptions, error) {
					// An empty schedule string means no quiet-hours cron.
					if c.userQuietHoursSchedule == "" {
						return schedule.UserQuietHoursScheduleOptions{
							Schedule: nil,
						}, nil
					}
					sched, err := cron.Daily(c.userQuietHoursSchedule)
					if !assert.NoError(t, err) {
						return schedule.UserQuietHoursScheduleOptions{}, err
					}
					return schedule.UserQuietHoursScheduleOptions{
						Schedule: sched,
						UserSet:  false,
					}, nil
				},
			}
			uqhss.Store(&userQuietHoursScheduleStore)

			user := dbgen.User(t, db, database.User{
				QuietHoursSchedule: c.userQuietHoursSchedule,
			})
			template := dbgen.Template(t, db, database.Template{
				Name:        "template",
				Provisioner: database.ProvisionerTypeEcho,
			})
			err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
				ID:                            template.ID,
				UpdatedAt:                     dbtime.Now(),
				AllowUserAutostart:            false,
				AllowUserAutostop:             true,
				DefaultTTL:                    0,
				AutostopRequirementDaysOfWeek: int16(c.templateAutostopRequirement.DaysOfWeek),
				AutostopRequirementWeeks:      c.templateAutostopRequirement.Weeks,
			})
			require.NoError(t, err)
			template, err = db.GetTemplateByID(ctx, template.ID)
			require.NoError(t, err)
			file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
			workspaceTTL := sql.NullInt64{}
			if c.workspaceTTL != 0 {
				workspaceTTL = sql.NullInt64{
					Int64: int64(c.workspaceTTL),
					Valid: true,
				}
			}
			workspace := dbgen.Workspace(t, db, database.Workspace{
				TemplateID: template.ID,
				Ttl:        workspaceTTL,
				OwnerID:    user.ID,
			})
			version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
				TemplateID: uuid.NullUUID{
					UUID:  template.ID,
					Valid: true,
				},
				JobID: uuid.New(),
			})
			build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
				WorkspaceID:       workspace.ID,
				TemplateVersionID: version.ID,
				Transition:        c.transition,
				Reason:            database.BuildReasonInitiator,
			})
			job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
				FileID: file.ID,
				Type:   database.ProvisionerJobTypeWorkspaceBuild,
				Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
					WorkspaceBuildID: build.ID,
				})),
			})
			_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
				WorkerID: uuid.NullUUID{
					UUID:  srvID,
					Valid: true,
				},
				Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
			})
			require.NoError(t, err)

			// Block completion of the test on both pubsub notifications so
			// the DB assertions below run after the server's writes.
			publishedWorkspace := make(chan struct{})
			closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) {
				close(publishedWorkspace)
			})
			require.NoError(t, err)
			defer closeWorkspaceSubscribe()
			publishedLogs := make(chan struct{})
			closeLogsSubscribe, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
				close(publishedLogs)
			})
			require.NoError(t, err)
			defer closeLogsSubscribe()

			_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
				JobId: job.ID.String(),
				Type: &proto.CompletedJob_WorkspaceBuild_{
					WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{
						State: []byte{},
						Resources: []*sdkproto.Resource{{
							Name: "example",
							Type: "aws_instance",
						}},
					},
				},
			})
			require.NoError(t, err)
			<-publishedWorkspace
			<-publishedLogs

			workspace, err = db.GetWorkspaceByID(ctx, workspace.ID)
			require.NoError(t, err)
			require.Equal(t, c.transition == database.WorkspaceTransitionDelete, workspace.Deleted)

			workspaceBuild, err := db.GetWorkspaceBuildByID(ctx, build.ID)
			require.NoError(t, err)

			// If the max deadline is set, the deadline should also be set.
			// Default to the max deadline if the deadline is not set.
			if c.expectedDeadline.IsZero() {
				c.expectedDeadline = c.expectedMaxDeadline
			}

			if c.expectedDeadline.IsZero() {
				require.True(t, workspaceBuild.Deadline.IsZero())
			} else {
				require.WithinDuration(t, c.expectedDeadline, workspaceBuild.Deadline, 15*time.Second, "deadline does not match expected")
			}
			if c.expectedMaxDeadline.IsZero() {
				require.True(t, workspaceBuild.MaxDeadline.IsZero())
			} else {
				require.WithinDuration(t, c.expectedMaxDeadline, workspaceBuild.MaxDeadline, 15*time.Second, "max deadline does not match expected")
				require.GreaterOrEqual(t, workspaceBuild.MaxDeadline.Unix(), workspaceBuild.Deadline.Unix(), "max deadline is smaller than deadline")
			}
		})
	}
})

// TemplateDryRun verifies a dry-run job can be completed with no
// workspace-side effects asserted.
t.Run("TemplateDryRun", func(t *testing.T) {
	t.Parallel()

	srvID := uuid.New()
	srv, db, _ := setup(t, false, &overrides{id: &srvID})
	job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
		ID:            uuid.New(),
		Provisioner:   database.ProvisionerTypeEcho,
		Type:          database.ProvisionerJobTypeTemplateVersionDryRun,
		StorageMethod: database.ProvisionerStorageMethodFile,
	})
	require.NoError(t, err)
	_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
		WorkerID: uuid.NullUUID{
			UUID:  srvID,
			Valid: true,
		},
		Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
	})
	require.NoError(t, err)

	_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
		JobId: job.ID.String(),
		Type: &proto.CompletedJob_TemplateDryRun_{
			TemplateDryRun: &proto.CompletedJob_TemplateDryRun{
				Resources: []*sdkproto.Resource{{
					Name: "something",
					Type: "aws_instance",
				}},
			},
		},
	})
	require.NoError(t, err)
})
}

func
TestInsertWorkspaceResource(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	// insert wires a resource into the DB via the exported helper, always
	// using the start transition and a throwaway telemetry snapshot.
	insert := func(db database.Store, jobID uuid.UUID, resource *sdkproto.Resource) error {
		return provisionerdserver.InsertWorkspaceResource(ctx, db, jobID, database.WorkspaceTransitionStart, resource, &telemetry.Snapshot{})
	}

	t.Run("NoAgents", func(t *testing.T) {
		t.Parallel()
		db := dbmem.New()
		job := uuid.New()
		err := insert(db, job, &sdkproto.Resource{
			Name: "something",
			Type: "aws_instance",
		})
		require.NoError(t, err)
		resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
		require.NoError(t, err)
		require.Len(t, resources, 1)
	})

	t.Run("InvalidAgentToken", func(t *testing.T) {
		t.Parallel()
		// The agent auth token must be a UUID; "bananas" is rejected.
		err := insert(dbmem.New(), uuid.New(), &sdkproto.Resource{
			Name: "something",
			Type: "aws_instance",
			Agents: []*sdkproto.Agent{{
				Auth: &sdkproto.Agent_Token{
					Token: "bananas",
				},
			}},
		})
		require.ErrorContains(t, err, "invalid UUID length")
	})

	t.Run("DuplicateApps", func(t *testing.T) {
		t.Parallel()
		// Two apps with the same slug on one agent must be rejected.
		err := insert(dbmem.New(), uuid.New(), &sdkproto.Resource{
			Name: "something",
			Type: "aws_instance",
			Agents: []*sdkproto.Agent{{
				Apps: []*sdkproto.App{{
					Slug: "a",
				}, {
					Slug: "a",
				}},
			}},
		})
		require.ErrorContains(t, err, "duplicate app slug")
	})

	t.Run("Success", func(t *testing.T) {
		t.Parallel()
		db := dbmem.New()
		job := uuid.New()
		err := insert(db, job, &sdkproto.Resource{
			Name:      "something",
			Type:      "aws_instance",
			DailyCost: 10,
			Agents: []*sdkproto.Agent{{
				Name: "dev",
				Env: map[string]string{
					"something": "test",
				},
				OperatingSystem: "linux",
				Architecture:    "amd64",
				Auth: &sdkproto.Agent_Token{
					Token: uuid.NewString(),
				},
				Apps: []*sdkproto.App{{
					Slug: "a",
				}},
				Scripts: []*sdkproto.Script{{
					DisplayName: "Startup",
					Icon:        "/test.png",
				}},
				DisplayApps: &sdkproto.DisplayApps{
					Vscode:               true,
					PortForwardingHelper: true,
					SshHelper:            true,
				},
			}},
		})
		require.NoError(t, err)
		resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
		require.NoError(t, err)
		require.Len(t, resources, 1)
		require.EqualValues(t, 10, resources[0].DailyCost)
		agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID})
		require.NoError(t, err)
		require.Len(t, agents, 1)
		agent := agents[0]
		require.Equal(t, "amd64", agent.Architecture)
		require.Equal(t, "linux", agent.OperatingSystem)
		// The env map is persisted as JSON; compare the raw encodings.
		want, err := json.Marshal(map[string]string{
			"something": "test",
		})
		require.NoError(t, err)
		got, err := agent.EnvironmentVariables.RawMessage.MarshalJSON()
		require.NoError(t, err)
		require.Equal(t, want, got)
		// Only the three display apps enabled above should be stored.
		require.ElementsMatch(t, []database.DisplayApp{
			database.DisplayAppPortForwardingHelper,
			database.DisplayAppSSHHelper,
			database.DisplayAppVscode,
		}, agent.DisplayApps)
	})

	t.Run("AllDisplayApps", func(t *testing.T) {
		t.Parallel()
		db := dbmem.New()
		job := uuid.New()
		err := insert(db, job, &sdkproto.Resource{
			Name: "something",
			Type: "aws_instance",
			Agents: []*sdkproto.Agent{{
				DisplayApps: &sdkproto.DisplayApps{
					Vscode:               true,
					VscodeInsiders:       true,
					SshHelper:            true,
					PortForwardingHelper: true,
					WebTerminal:          true,
				},
			}},
		})
		require.NoError(t, err)
		resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
		require.NoError(t, err)
		require.Len(t, resources, 1)
		agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID})
		require.NoError(t, err)
		require.Len(t, agents, 1)
		agent := agents[0]
		require.ElementsMatch(t, database.AllDisplayAppValues(), agent.DisplayApps)
	})

	t.Run("DisableDefaultApps", func(t *testing.T) {
		t.Parallel()
		db := dbmem.New()
		job := uuid.New()
		err := insert(db, job, &sdkproto.Resource{
			Name: "something",
			Type: "aws_instance",
			Agents: []*sdkproto.Agent{{
				DisplayApps: &sdkproto.DisplayApps{},
			}},
		})
		require.NoError(t, err)
		resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
		require.NoError(t, err)
		require.Len(t, resources, 1)
		agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID})
		require.NoError(t, err)
		require.Len(t, agents, 1)
		agent := agents[0]
		// An empty array (as opposed to nil) should be returned to indicate
		// that all apps are disabled.
		require.Equal(t, []database.DisplayApp{}, agent.DisplayApps)
	})
}

// overrides lets individual tests customize the server built by setup.
// Nil/zero fields keep setup's defaults.
type overrides struct {
	deploymentValues            *codersdk.DeploymentValues
	externalAuthConfigs         []*externalauth.Config
	id                          *uuid.UUID // worker/server ID; defaults to a fresh UUID
	templateScheduleStore       *atomic.Pointer[schedule.TemplateScheduleStore]
	userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
	timeNowFn                   func() time.Time // simulated clock for the server
	acquireJobLongPollDuration  time.Duration
}

// setup constructs a provisionerd server backed by an in-memory database and
// pubsub, applying any non-nil overrides from ov. It returns the server, the
// database, and the pubsub so tests can seed data and observe notifications.
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	db := dbmem.New()
	ps := pubsub.NewInMemory()
	deploymentValues := &codersdk.DeploymentValues{}
	var externalAuthConfigs []*externalauth.Config
	srvID := uuid.New()
	tss := testTemplateScheduleStore()
	uqhss := testUserQuietHoursScheduleStore()
	var timeNowFn func() time.Time
	pollDur := time.Duration(0)
	if ov != nil {
		if ov.deploymentValues != nil {
			deploymentValues = ov.deploymentValues
		}
		if ov.externalAuthConfigs != nil {
			externalAuthConfigs = ov.externalAuthConfigs
		}
		if ov.id != nil {
			srvID = *ov.id
		}
		if ov.templateScheduleStore != nil {
			ttss := tss.Load()
			// keep the initial test value if the override hasn't set the atomic pointer.
			tss = ov.templateScheduleStore
			if tss.Load() == nil {
				swapped := tss.CompareAndSwap(nil, ttss)
				require.True(t, swapped)
			}
		}
		if ov.userQuietHoursScheduleStore != nil {
			tuqhss := uqhss.Load()
			// keep the initial test value if the override hasn't set the atomic pointer.
			uqhss = ov.userQuietHoursScheduleStore
			if uqhss.Load() == nil {
				swapped := uqhss.CompareAndSwap(nil, tuqhss)
				require.True(t, swapped)
			}
		}
		if ov.timeNowFn != nil {
			timeNowFn = ov.timeNowFn
		}
		pollDur = ov.acquireJobLongPollDuration
	}

	srv, err := provisionerdserver.NewServer(
		ctx,
		&url.URL{},
		srvID,
		slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
		[]database.ProvisionerType{database.ProvisionerTypeEcho},
		provisionerdserver.Tags{},
		db,
		ps,
		provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps),
		telemetry.NewNoop(),
		trace.NewNoopTracerProvider().Tracer("noop"),
		&atomic.Pointer[proto.QuotaCommitter]{},
		mockAuditor(),
		tss,
		uqhss,
		deploymentValues,
		provisionerdserver.Options{
			ExternalAuthConfigs:   externalAuthConfigs,
			TimeNowFn:             timeNowFn,
			OIDCConfig:            &oauth2.Config{},
			AcquireJobLongPollDur: pollDur,
		},
	)
	require.NoError(t, err)
	return srv, db, ps
}

// must unwraps (value, error) pairs in test fixtures, panicking on error.
func must[T any](value T, err error) T {
	if err != nil {
		panic(err)
	}
	return value
}

var (
	errUnimplemented = xerrors.New("unimplemented")
	errClosed        = xerrors.New("closed")
)

// fakeStream is an in-process stand-in for the AcquireJobWithCancel dRPC
// stream. All state is guarded by the condition variable's lock.
type fakeStream struct {
	ctx        context.Context
	c          *sync.Cond
	closed     bool
	canceled   bool
	sendCalled bool
	job        *proto.AcquiredJob
}

// newFakeStream returns a fakeStream whose Context() is ctx.
func newFakeStream(ctx context.Context) *fakeStream {
	return &fakeStream{
		ctx: ctx,
		c:   sync.NewCond(&sync.Mutex{}),
	}
}

// Send records the job delivered by the server and wakes any waiter.
func (s *fakeStream) Send(j *proto.AcquiredJob) error {
	s.c.L.Lock()
	defer s.c.L.Unlock()
	s.sendCalled = true
	s.job = j
	s.c.Broadcast()
	return nil
}

// Recv blocks until cancel() or Close() is called, returning a CancelAcquire
// message on cancellation or io.EOF once the stream is closed.
func (s *fakeStream) Recv() (*proto.CancelAcquire, error) {
	s.c.L.Lock()
	defer s.c.L.Unlock()
	for !(s.canceled || s.closed) {
		s.c.Wait()
	}
	if s.canceled {
		return &proto.CancelAcquire{}, nil
	}
	return nil, io.EOF
}

// Context returns the context associated with the stream. It is canceled
// when the Stream is closed and no more messages will ever be sent or
// received on it.
func (s *fakeStream) Context() context.Context {
	return s.ctx
}

// MsgSend sends the Message to the remote.
func (*fakeStream) MsgSend(drpc.Message, drpc.Encoding) error {
	return errUnimplemented
}

// MsgRecv receives a Message from the remote.
func (*fakeStream) MsgRecv(drpc.Message, drpc.Encoding) error {
	return errUnimplemented
}

// CloseSend signals to the remote that we will no longer send any messages.
func (*fakeStream) CloseSend() error {
	return errUnimplemented
}

// Close marks the stream closed and wakes everything blocked on the
// condition variable so it can observe the closed state.
func (s *fakeStream) Close() error {
	s.c.L.Lock()
	s.closed = true
	s.c.Broadcast()
	s.c.L.Unlock()
	return nil
}

// waitForJob blocks until either Send has delivered a job or the stream has
// been closed; a close with no job yields errClosed.
func (s *fakeStream) waitForJob() (*proto.AcquiredJob, error) {
	s.c.L.Lock()
	defer s.c.L.Unlock()
	for !s.sendCalled && !s.closed {
		s.c.Wait()
	}
	if !s.sendCalled {
		return nil, errClosed
	}
	return s.job, nil
}

// cancel flags the stream as canceled and wakes any blocked Recv.
func (s *fakeStream) cancel() {
	s.c.L.Lock()
	s.canceled = true
	s.c.Broadcast()
	s.c.L.Unlock()
}