From 213b768785d7a07b0c5f227ea7d4ada0cd4ee54e Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Mon, 18 Dec 2023 16:44:52 +0000 Subject: [PATCH] feat(coderd): insert provisioner daemons (#11207) * Adds UpdateProvisionerDaemonLastSeenAt * Adds heartbeat to provisioner daemons * Inserts provisioner daemons to database upon start * Ensures TagOwner is an empty string and not nil * Adds COALESCE() in idx_provisioner_daemons_name_owner_key --- cli/testdata/coder_list_--output_json.golden | 1 + coderd/coderd.go | 28 +- coderd/coderdtest/coderdtest.go | 2 +- coderd/database/dbauthz/dbauthz.go | 8 + coderd/database/dbauthz/dbauthz_test.go | 12 + coderd/database/dbauthz/setup_test.go | 2 +- coderd/database/dbmem/dbmem.go | 22 ++ coderd/database/dbmetrics/dbmetrics.go | 7 + coderd/database/dbmock/dbmock.go | 14 + coderd/database/dump.sql | 4 +- ...esce_provisioner_daemon_idx_owner.down.sql | 8 + ...alesce_provisioner_daemon_idx_owner.up.sql | 8 + coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 22 +- .../database/queries/provisionerdaemons.sql | 11 +- coderd/database/unique_constraint.go | 2 +- .../provisionerdserver/provisionerdserver.go | 92 ++++++- .../provisionerdserver_test.go | 254 +++++++++++------- coderd/rbac/roles.go | 2 + codersdk/organizations.go | 3 - enterprise/cli/provisionerdaemons_test.go | 81 +++++- enterprise/coderd/provisionerdaemons.go | 45 +++- enterprise/coderd/provisionerdaemons_test.go | 32 +++ provisionerd/provisionerd.go | 8 + provisionersdk/provisionertags.go | 10 +- provisionersdk/provisionertags_test.go | 6 + 26 files changed, 548 insertions(+), 137 deletions(-) create mode 100644 coderd/database/migrations/000181_coalesce_provisioner_daemon_idx_owner.down.sql create mode 100644 coderd/database/migrations/000181_coalesce_provisioner_daemon_idx_owner.up.sql diff --git a/cli/testdata/coder_list_--output_json.golden b/cli/testdata/coder_list_--output_json.golden index 55b6606948..d1874e6f7c 100644 --- a/cli/testdata/coder_list_--output_json.golden +++ b/cli/testdata/coder_list_--output_json.golden @@ -36,6 +36,7 @@ "worker_id": "[workspace build worker ID]", "file_id": "[workspace build file ID]", "tags": { + "owner": "", "scope": "organization" }, "queue_position": 0, diff --git a/coderd/coderd.go b/coderd/coderd.go index 67d1593b4f..898dcb36d5 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "database/sql" "flag" "fmt" "io" @@ -49,6 +50,7 @@ import ( "github.com/coder/coder/v2/coderd/batchstats" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/gitsshkey" @@ -1178,22 +1180,32 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string } }() - tags := provisionerdserver.Tags{ - provisionersdk.TagScope: provisionersdk.ScopeOrganization, + //nolint:gocritic // in-memory provisioners are owned by system + daemon, err := api.Database.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{ + Name: name, + CreatedAt: dbtime.Now(), + Provisioners: []database.ProvisionerType{ + database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform, + }, + Tags: provisionersdk.MutateTags(uuid.Nil, nil), + LastSeenAt: sql.NullTime{Time: dbtime.Now(), Valid: true}, + Version: buildinfo.Version(), + APIVersion: "1.0", + }) + if err != nil { + return nil, xerrors.Errorf("failed to create in-memory provisioner daemon: %w", err) } mux := drpcmux.New() api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name)) logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name)) srv, err := provisionerdserver.NewServer( - api.ctx, + api.ctx, // use the same ctx as the API api.AccessURL, - uuid.New(), + daemon.ID, logger, - []database.ProvisionerType{ - database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform, - }, - tags, + daemon.Provisioners, + provisionerdserver.Tags(daemon.Tags), api.Database, api.Pubsub, api.Acquirer, diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 8de475c784..55060a0998 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -533,7 +533,7 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer { }() daemon := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { - return coderAPI.CreateInMemoryProvisionerDaemon(ctx, t.Name()) + return coderAPI.CreateInMemoryProvisionerDaemon(ctx, "test") }, &provisionerd.Options{ Logger: coderAPI.Logger.Named("provisionerd").Leveled(slog.LevelDebug), UpdateInterval: 250 * time.Millisecond, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 5ba613044f..0a1858f559 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -232,6 +232,7 @@ var ( rbac.ResourceOrganization.Type: {rbac.ActionCreate}, rbac.ResourceOrganizationMember.Type: {rbac.ActionCreate}, rbac.ResourceOrgRoleAssignment.Type: {rbac.ActionCreate}, + rbac.ResourceProvisionerDaemon.Type: {rbac.ActionCreate, rbac.ActionUpdate}, rbac.ResourceUser.Type: {rbac.ActionCreate, rbac.ActionUpdate, rbac.ActionDelete}, rbac.ResourceUserData.Type: {rbac.ActionCreate, rbac.ActionUpdate}, rbac.ResourceWorkspace.Type: {rbac.ActionUpdate}, @@ -2499,6 +2500,13 @@ func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemb return q.db.UpdateMemberRoles(ctx, arg) } +func (q *querier) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceProvisionerDaemon); err != nil { + return err + } + return q.db.UpdateProvisionerDaemonLastSeenAt(ctx, arg) +} + // TODO: We need to create a ProvisionerJob resource type func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error { // if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 5f40fe936c..41aecb48fa 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1592,6 +1592,18 @@ func (s *MethodTestSuite) TestExtraMethods() { s.NoError(err, "insert provisioner daemon") check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete) })) + s.Run("UpdateProvisionerDaemonLastSeenAt", s.Subtest(func(db database.Store, check *expects) { + d, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{ + Tags: database.StringMap(map[string]string{ + provisionersdk.TagScope: provisionersdk.ScopeOrganization, + }), + }) + s.NoError(err, "insert provisioner daemon") + check.Args(database.UpdateProvisionerDaemonLastSeenAtParams{ + ID: d.ID, + LastSeenAt: sql.NullTime{Time: dbtime.Now(), Valid: true}, + }).Asserts(rbac.ResourceProvisionerDaemon, rbac.ActionUpdate) + })) } // All functions in this method test suite are not implemented in dbmem, but diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go index 3c54d8be4e..403d23d508 100644 --- a/coderd/database/dbauthz/setup_test.go +++ b/coderd/database/dbauthz/setup_test.go @@ -28,7 +28,7 @@ import ( "github.com/coder/coder/v2/coderd/util/slice" ) -var errMatchAny = errors.New("match any error") +var errMatchAny = xerrors.New("match any error") var skipMethods = map[string]string{ "InTx": "Not relevant", diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 95e5528dbb..216e718939 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -5945,6 +5945,28 @@ func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMe return database.OrganizationMember{}, sql.ErrNoRows } +func (q *FakeQuerier) UpdateProvisionerDaemonLastSeenAt(_ context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for idx := range q.provisionerDaemons { + if q.provisionerDaemons[idx].ID != arg.ID { + continue + } + if q.provisionerDaemons[idx].LastSeenAt.Time.After(arg.LastSeenAt.Time) { + continue + } + q.provisionerDaemons[idx].LastSeenAt = arg.LastSeenAt + return nil + } + return sql.ErrNoRows +} + func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { if err := validateDatabaseType(arg); err != nil { return err diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 6ea0b6d615..1b1aa1e631 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -1593,6 +1593,13 @@ func (m metricsStore) UpdateMemberRoles(ctx context.Context, arg database.Update return member, err } +func (m metricsStore) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error { + start := time.Now() + r0 := m.s.UpdateProvisionerDaemonLastSeenAt(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateProvisionerDaemonLastSeenAt").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error { start := time.Now() err := m.s.UpdateProvisionerJobByID(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 19c8e76365..d8fa998ee6 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3362,6 +3362,20 @@ func (mr *MockStoreMockRecorder) UpdateMemberRoles(arg0, arg1 interface{}) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMemberRoles", reflect.TypeOf((*MockStore)(nil).UpdateMemberRoles), arg0, arg1) } +// UpdateProvisionerDaemonLastSeenAt mocks base method. +func (m *MockStore) UpdateProvisionerDaemonLastSeenAt(arg0 context.Context, arg1 database.UpdateProvisionerDaemonLastSeenAtParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateProvisionerDaemonLastSeenAt", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateProvisionerDaemonLastSeenAt indicates an expected call of UpdateProvisionerDaemonLastSeenAt. +func (mr *MockStoreMockRecorder) UpdateProvisionerDaemonLastSeenAt(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProvisionerDaemonLastSeenAt", reflect.TypeOf((*MockStore)(nil).UpdateProvisionerDaemonLastSeenAt), arg0, arg1) +} + // UpdateProvisionerJobByID mocks base method. func (m *MockStore) UpdateProvisionerJobByID(arg0 context.Context, arg1 database.UpdateProvisionerJobByIDParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index b2f03e0183..0dd504ce2f 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1417,9 +1417,9 @@ CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name); CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)); -CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower((tags ->> 'owner'::text))); +CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower(COALESCE((tags ->> 'owner'::text), ''::text))); -COMMENT ON INDEX idx_provisioner_daemons_name_owner_key IS 'Relax uniqueness constraint for provisioner daemon names'; +COMMENT ON INDEX idx_provisioner_daemons_name_owner_key IS 'Allow unique provisioner daemon names by user'; CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id); diff --git a/coderd/database/migrations/000181_coalesce_provisioner_daemon_idx_owner.down.sql b/coderd/database/migrations/000181_coalesce_provisioner_daemon_idx_owner.down.sql new file mode 100644 index 0000000000..e28371910c --- /dev/null +++ b/coderd/database/migrations/000181_coalesce_provisioner_daemon_idx_owner.down.sql @@ -0,0 +1,8 @@ +DROP INDEX IF EXISTS idx_provisioner_daemons_name_owner_key; + +CREATE UNIQUE INDEX IF NOT EXISTS idx_provisioner_daemons_name_owner_key + ON provisioner_daemons + USING btree (name, lower((tags->>'owner')::text)); + +COMMENT ON INDEX idx_provisioner_daemons_name_owner_key + IS 'Relax uniqueness constraint for provisioner daemon names'; diff --git a/coderd/database/migrations/000181_coalesce_provisioner_daemon_idx_owner.up.sql b/coderd/database/migrations/000181_coalesce_provisioner_daemon_idx_owner.up.sql new file mode 100644 index 0000000000..146f73a23e --- /dev/null +++ b/coderd/database/migrations/000181_coalesce_provisioner_daemon_idx_owner.up.sql @@ -0,0 +1,8 @@ +DROP INDEX IF EXISTS idx_provisioner_daemons_name_owner_key; + +CREATE UNIQUE INDEX IF NOT EXISTS idx_provisioner_daemons_name_owner_key + ON provisioner_daemons + USING btree (name, LOWER(COALESCE(tags->>'owner', '')::text)); + +COMMENT ON INDEX idx_provisioner_daemons_name_owner_key + IS 'Allow unique provisioner daemon names by user'; diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3b0ff868bd..eb7009adec 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -318,6 +318,7 @@ type sqlcQuerier interface { UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error) UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error) + UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e0c4c7ce6d..cd7c115e39 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3057,6 +3057,26 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa return items, nil } +const updateProvisionerDaemonLastSeenAt = `-- name: UpdateProvisionerDaemonLastSeenAt :exec +UPDATE provisioner_daemons +SET + last_seen_at = $1 +WHERE + id = $2 +AND + last_seen_at <= $1 +` + +type UpdateProvisionerDaemonLastSeenAtParams struct { + LastSeenAt sql.NullTime `db:"last_seen_at" json:"last_seen_at"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error { + _, err := q.db.ExecContext(ctx, updateProvisionerDaemonLastSeenAt, arg.LastSeenAt, arg.ID) + return err +} + const upsertProvisionerDaemon = `-- name: UpsertProvisionerDaemon :one INSERT INTO provisioner_daemons ( @@ -3078,7 +3098,7 @@ VALUES ( $5, $6, $7 -) ON CONFLICT("name", lower((tags ->> 'owner'::text))) DO UPDATE SET +) ON CONFLICT("name", LOWER(COALESCE(tags ->> 'owner'::text, ''::text))) DO UPDATE SET provisioners = $3, tags = $4, last_seen_at = $5, diff --git a/coderd/database/queries/provisionerdaemons.sql b/coderd/database/queries/provisionerdaemons.sql index 3041efad94..47c92f7997 100644 --- a/coderd/database/queries/provisionerdaemons.sql +++ b/coderd/database/queries/provisionerdaemons.sql @@ -35,7 +35,7 @@ VALUES ( @last_seen_at, @version, @api_version -) ON CONFLICT("name", lower((tags ->> 'owner'::text))) DO UPDATE SET +) ON CONFLICT("name", LOWER(COALESCE(tags ->> 'owner'::text, ''::text))) DO UPDATE SET provisioners = @provisioners, tags = @tags, last_seen_at = @last_seen_at, @@ -45,3 +45,12 @@ WHERE -- Only ones with the same tags are allowed clobber provisioner_daemons.tags <@ @tags :: jsonb RETURNING *; + +-- name: UpdateProvisionerDaemonLastSeenAt :exec +UPDATE provisioner_daemons +SET + last_seen_at = @last_seen_at +WHERE + id = @id +AND + last_seen_at <= @last_seen_at; diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index e69ed46614..9e54f47652 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -65,7 +65,7 @@ const ( UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type); UniqueIndexOrganizationName UniqueConstraint = "idx_organization_name" // CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name); UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)); - UniqueIndexProvisionerDaemonsNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower((tags ->> 'owner'::text))); + UniqueIndexProvisionerDaemonsNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower(COALESCE((tags ->> 'owner'::text), ''::text))); UniqueIndexUsersEmail UniqueConstraint = "idx_users_email" // CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false); UniqueIndexUsersUsername UniqueConstraint = "idx_users_username" // CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false); UniqueTemplatesOrganizationIDNameIndex UniqueConstraint = "templates_organization_id_name_idx" // CREATE UNIQUE INDEX templates_organization_id_name_idx ON templates USING btree (organization_id, lower((name)::text)) WHERE (deleted = false); diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 6d035bf5ce..f1362b50ac 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -44,9 +44,15 @@ import ( sdkproto "github.com/coder/coder/v2/provisionersdk/proto" ) -// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before -// canceling and returning an empty job. -const DefaultAcquireJobLongPollDur = time.Second * 5 +const ( + // DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before + // canceling and returning an empty job. + DefaultAcquireJobLongPollDur = time.Second * 5 + + // DefaultHeartbeatInterval is the interval at which the provisioner daemon + // will update its last seen at timestamp in the database. + DefaultHeartbeatInterval = time.Minute +) type Options struct { OIDCConfig httpmw.OAuth2Config @@ -56,6 +62,16 @@ type Options struct { // AcquireJobLongPollDur is used in tests AcquireJobLongPollDur time.Duration + + // HeartbeatInterval is the interval at which the provisioner daemon + // will update its last seen at timestamp in the database. + HeartbeatInterval time.Duration + + // HeartbeatFn is the function that will be called at the interval + // specified by HeartbeatInterval. + // The default function just calls UpdateProvisionerDaemonLastSeenAt. + // This is mainly used for testing. + HeartbeatFn func(context.Context) error } type server struct { @@ -85,6 +101,9 @@ type server struct { TimeNowFn func() time.Time acquireJobLongPollDur time.Duration + + heartbeatInterval time.Duration + heartbeatFn func(ctx context.Context) error } // We use the null byte (0x00) in generating a canonical map key for tags, so @@ -161,7 +180,21 @@ func NewServer( if options.AcquireJobLongPollDur == 0 { options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur } - return &server{ + if options.HeartbeatInterval == 0 { + options.HeartbeatInterval = DefaultHeartbeatInterval + } + // Avoid a nil check in s.heartbeat. + if options.HeartbeatFn == nil { + options.HeartbeatFn = func(hbCtx context.Context) error { + //nolint:gocritic // This is specifically for updating the last seen at timestamp. + return db.UpdateProvisionerDaemonLastSeenAt(dbauthz.AsSystemRestricted(hbCtx), database.UpdateProvisionerDaemonLastSeenAtParams{ + ID: id, + LastSeenAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) + } + } + + s := &server{ lifecycleCtx: lifecycleCtx, AccessURL: accessURL, ID: id, @@ -182,7 +215,12 @@ func NewServer( OIDCConfig: options.OIDCConfig, TimeNowFn: options.TimeNowFn, acquireJobLongPollDur: options.AcquireJobLongPollDur, - }, nil + heartbeatInterval: options.HeartbeatInterval, + heartbeatFn: options.HeartbeatFn, + } + + go s.heartbeatLoop() + return s, nil } // timeNow should be used when trying to get the current time for math @@ -194,6 +232,50 @@ func (s *server) timeNow() time.Time { return dbtime.Now() } +// heartbeatLoop runs heartbeatOnce at the interval specified by HeartbeatInterval +// until the lifecycle context is canceled. +func (s *server) heartbeatLoop() { + tick := time.NewTicker(time.Nanosecond) + defer tick.Stop() + for { + select { + case <-s.lifecycleCtx.Done(): + s.Logger.Debug(s.lifecycleCtx, "heartbeat loop canceled") + return + case <-tick.C: + if s.lifecycleCtx.Err() != nil { + return + } + start := s.timeNow() + hbCtx, hbCancel := context.WithTimeout(s.lifecycleCtx, s.heartbeatInterval) + if err := s.heartbeat(hbCtx); err != nil { + if !xerrors.Is(err, context.DeadlineExceeded) && !xerrors.Is(err, context.Canceled) { + s.Logger.Error(hbCtx, "heartbeat failed", slog.Error(err)) + } + } + hbCancel() + elapsed := s.timeNow().Sub(start) + nextBeat := s.heartbeatInterval - elapsed + // avoid negative interval + if nextBeat <= 0 { + nextBeat = time.Nanosecond + } + tick.Reset(nextBeat) + } + } +} + +// heartbeat updates the last seen at timestamp in the database. +// If HeartbeatFn is set, it will be called instead. +func (s *server) heartbeat(ctx context.Context) error { + select { + case <-ctx.Done(): + return nil + default: + return s.heartbeatFn(ctx) + } +} + // AcquireJob queries the database to lock a job. // // Deprecated: This method is only available for back-level provisioner daemons. diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 3aff6d4d80..c2e8c6a836 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -66,7 +66,8 @@ func testUserQuietHoursScheduleStore() *atomic.Pointer[schedule.UserQuietHoursSc func TestAcquireJob_LongPoll(t *testing.T) { t.Parallel() - srv, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond}) + //nolint:dogsled // ૮・ᴥ・ა + srv, _, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond}) job, err := srv.AcquireJob(context.Background(), nil) require.NoError(t, err) require.Equal(t, &proto.AcquiredJob{}, job) @@ -74,7 +75,8 @@ func TestAcquireJob_LongPoll(t *testing.T) { func TestAcquireJobWithCancel_Cancel(t *testing.T) { t.Parallel() - srv, _, _ := setup(t, false, nil) + //nolint:dogsled // ૮ ˶′ﻌ ‵˶ ა + srv, _, _, _ := setup(t, false, nil) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() fs := newFakeStream(ctx) @@ -95,6 +97,46 @@ func TestAcquireJobWithCancel_Cancel(t *testing.T) { require.Equal(t, "", job.JobId) } +func TestHeartbeat(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + heartbeatChan := make(chan struct{}) + heartbeatFn := func(hbCtx context.Context) error { + t.Logf("heartbeat") + select { + case <-hbCtx.Done(): + return hbCtx.Err() + default: + heartbeatChan <- struct{}{} + return nil + } + } + //nolint:dogsled // 。:゚૮ ˶ˆ ﻌ ˆ˶ ა ゚:。 + _, _, _, _ = setup(t, false, &overrides{ + ctx: ctx, + heartbeatFn: heartbeatFn, + heartbeatInterval: testutil.IntervalFast, + }) + + _, ok := <-heartbeatChan + require.True(t, ok, "first heartbeat not received") + _, ok = <-heartbeatChan + require.True(t, ok, "second heartbeat not received") + cancel() + // Close the channel to ensure we don't receive any more heartbeats. + // The test will fail if we do. + defer func() { + if r := recover(); r != nil { + t.Fatalf("heartbeat received after cancel: %v", r) + } + }() + + close(heartbeatChan) + <-time.After(testutil.IntervalMedium) +} + func TestAcquireJob(t *testing.T) { t.Parallel() @@ -120,7 +162,7 @@ func TestAcquireJob(t *testing.T) { tc := tc t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) { t.Parallel() - srv, db, _ := setup(t, false, nil) + srv, db, _, _ := setup(t, false, nil) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() _, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ @@ -141,7 +183,7 @@ func TestAcquireJob(t *testing.T) { // deployment config. dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)} gitAuthProvider := "github" - srv, db, ps := setup(t, false, &overrides{ + srv, db, ps, _ := setup(t, false, &overrides{ deploymentValues: dv, externalAuthConfigs: []*externalauth.Config{{ ID: gitAuthProvider, @@ -359,7 +401,7 @@ func TestAcquireJob(t *testing.T) { t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) { t.Parallel() - srv, db, ps := setup(t, false, nil) + srv, db, ps, _ := setup(t, false, nil) ctx := context.Background() user := dbgen.User(t, db, database.User{}) @@ -396,7 +438,7 @@ func TestAcquireJob(t *testing.T) { }) t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) { t.Parallel() - srv, db, ps := setup(t, false, nil) + srv, db, ps, _ := setup(t, false, nil) ctx := context.Background() user := dbgen.User(t, db, database.User{}) @@ -427,7 +469,7 @@ func TestAcquireJob(t *testing.T) { }) t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) { t.Parallel() - srv, db, ps := setup(t, false, nil) + srv, db, ps, _ := setup(t, false, nil) user := dbgen.User(t, db, database.User{}) version := dbgen.TemplateVersion(t, db, database.TemplateVersion{}) @@ -476,7 +518,7 @@ func TestUpdateJob(t *testing.T) { ctx := context.Background() t.Run("NotFound", func(t *testing.T) { t.Parallel() - srv, _, _ := setup(t, false, nil) + srv, _, _, _ := setup(t, false, nil) _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: "hello", }) @@ -489,7 +531,7 @@ func TestUpdateJob(t *testing.T) { }) t.Run("NotRunning", func(t *testing.T) { t.Parallel() - srv, db, _ := setup(t, false, nil) + srv, db, _, _ := setup(t, false, nil) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, @@ -505,7 +547,7 @@ func TestUpdateJob(t *testing.T) { // This test prevents runners from updating jobs they don't own! t.Run("NotOwner", func(t *testing.T) { t.Parallel() - srv, db, _ := setup(t, false, nil) + srv, db, _, _ := setup(t, false, nil) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, @@ -548,9 +590,8 @@ func TestUpdateJob(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - srvID := uuid.New() - srv, db, _ := setup(t, false, &overrides{id: &srvID}) - job := setupJob(t, db, srvID) + srv, db, _, pd := setup(t, false, &overrides{}) + job := setupJob(t, db, pd.ID) _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ JobId: job.String(), }) @@ -559,9 +600,8 @@ func TestUpdateJob(t *testing.T) { t.Run("Logs", func(t *testing.T) { t.Parallel() - srvID := uuid.New() - srv, db, ps := setup(t, false, &overrides{id: &srvID}) - job := setupJob(t, db, srvID) + srv, db, ps, pd := setup(t, false, &overrides{}) + job := setupJob(t, db, pd.ID) published := make(chan struct{}) @@ -585,9 +625,8 @@ func TestUpdateJob(t *testing.T) { }) t.Run("Readme", func(t *testing.T) { t.Parallel() - srvID := uuid.New() - srv, db, _ := setup(t, false, &overrides{id: &srvID}) - job := setupJob(t, db, srvID) + srv, db, _, pd := setup(t, false, &overrides{}) + job := setupJob(t, db, pd.ID) versionID := uuid.New() err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, @@ -612,9 +651,8 @@ func TestUpdateJob(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - srvID := uuid.New() - srv, db, _ := setup(t, false, &overrides{id: &srvID}) - job := setupJob(t, db, srvID) + srv, db, _, pd := setup(t, false, &overrides{}) + job := setupJob(t, db, pd.ID) versionID := uuid.New() err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, @@ -660,9 +698,8 @@ func TestUpdateJob(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - srvID := uuid.New() - srv, db, _ := setup(t, false, &overrides{id: &srvID}) - job := setupJob(t, db, srvID) + srv, db, _, pd := setup(t, false, &overrides{}) + job := setupJob(t, db, pd.ID) versionID := uuid.New() err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ ID: versionID, @@ -707,7 +744,7 @@ func TestFailJob(t *testing.T) { ctx := context.Background() t.Run("NotFound", func(t *testing.T) { t.Parallel() - srv, _, _ := setup(t, false, nil) + srv, _, _, _ := setup(t, false, nil) _, err := srv.FailJob(ctx, &proto.FailedJob{ JobId: "hello", }) @@ -721,7 +758,7 @@ func TestFailJob(t *testing.T) { // This test prevents runners from updating jobs they don't own! t.Run("NotOwner", func(t *testing.T) { t.Parallel() - srv, db, _ := setup(t, false, nil) + srv, db, _, _ := setup(t, false, nil) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, @@ -744,8 +781,7 @@ func TestFailJob(t *testing.T) { }) t.Run("AlreadyCompleted", func(t *testing.T) { t.Parallel() - srvID := uuid.New() - srv, db, _ := setup(t, false, &overrides{id: &srvID}) + srv, db, _, pd := setup(t, false, &overrides{}) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, @@ -755,7 +791,7 @@ func TestFailJob(t *testing.T) { require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srvID, + UUID: pd.ID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -780,8 +816,7 @@ func TestFailJob(t *testing.T) { // // (*Server).FailJob audit log - get build {"error": "sql: no rows in result set"} ignoreLogErrors := true - srvID := uuid.New() - srv, db, ps := setup(t, ignoreLogErrors, &overrides{id: &srvID}) + srv, db, ps, pd := setup(t, ignoreLogErrors, &overrides{}) workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ ID: uuid.New(), AutomaticUpdates: database.AutomaticUpdatesNever, @@ -810,7 +845,7 @@ func TestFailJob(t *testing.T) { require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srvID, + UUID: pd.ID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -852,7 +887,7 @@ func TestCompleteJob(t *testing.T) { ctx := context.Background() t.Run("NotFound", func(t *testing.T) { t.Parallel() - srv, _, _ := setup(t, false, nil) + srv, _, _, _ := setup(t, false, nil) _, err := srv.CompleteJob(ctx, &proto.CompletedJob{ JobId: "hello", }) @@ -866,7 +901,7 @@ func TestCompleteJob(t *testing.T) { // This test prevents runners from updating jobs they don't own! t.Run("NotOwner", func(t *testing.T) { t.Parallel() - srv, db, _ := setup(t, false, nil) + srv, db, _, _ := setup(t, false, nil) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, @@ -890,8 +925,7 @@ func TestCompleteJob(t *testing.T) { t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) { t.Parallel() - srvID := uuid.New() - srv, db, _ := setup(t, false, &overrides{id: &srvID}) + srv, db, _, pd := setup(t, false, &overrides{}) jobID := uuid.New() versionID := uuid.New() err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ @@ -909,7 +943,7 @@ func TestCompleteJob(t *testing.T) { require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srvID, + UUID: pd.ID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -939,9 +973,7 @@ func TestCompleteJob(t *testing.T) { t.Run("TemplateImport_WithGitAuth", func(t *testing.T) { t.Parallel() - srvID := uuid.New() - srv, db, _ := setup(t, false, &overrides{ - id: &srvID, + srv, db, _, pd := setup(t, false, &overrides{ externalAuthConfigs: []*externalauth.Config{{ ID: "github", }}, @@ -963,7 +995,7 @@ func TestCompleteJob(t *testing.T) { require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srvID, + UUID: pd.ID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -1106,9 +1138,8 @@ func TestCompleteJob(t *testing.T) { t.Run(c.name, func(t *testing.T) { t.Parallel() - srvID := uuid.New() tss := &atomic.Pointer[schedule.TemplateScheduleStore]{} - srv, db, ps := setup(t, false, &overrides{id: &srvID, templateScheduleStore: tss}) + srv, db, ps, pd := setup(t, false, &overrides{templateScheduleStore: tss}) var store schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{ GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) { @@ -1123,10 +1154,19 @@ func TestCompleteJob(t *testing.T) { } tss.Store(&store) + org := dbgen.Organization(t, db, database.Organization{}) user := dbgen.User(t, db, database.User{}) template := dbgen.Template(t, db, database.Template{ - Name: "template", - Provisioner: database.ProvisionerTypeEcho, + Name: "template", + Provisioner: database.ProvisionerTypeEcho, + OrganizationID: org.ID, + }) + version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{ + UUID: template.ID, + Valid: true, + }, + JobID: uuid.New(), }) err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{ ID: template.ID, @@ -1148,13 +1188,6 @@ func TestCompleteJob(t *testing.T) { TemplateID: template.ID, Ttl: workspaceTTL, }) - version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{ - UUID: template.ID, - Valid: true, - }, - JobID: uuid.New(), - }) build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ WorkspaceID: workspace.ID, TemplateVersionID: version.ID, @@ -1170,7 +1203,7 @@ func TestCompleteJob(t *testing.T) { }) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srvID, + UUID: pd.ID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -1315,19 +1348,17 @@ func TestCompleteJob(t *testing.T) { t.Run(c.name, func(t *testing.T) { t.Parallel() - srvID := uuid.New() // Simulate the given time starting from now. require.False(t, c.now.IsZero()) start := time.Now() tss := &atomic.Pointer[schedule.TemplateScheduleStore]{} uqhss := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{} - srv, db, ps := setup(t, false, &overrides{ + srv, db, ps, pd := setup(t, false, &overrides{ timeNowFn: func() time.Time { return c.now.Add(time.Since(start)) }, templateScheduleStore: tss, userQuietHoursScheduleStore: uqhss, - id: &srvID, }) var templateScheduleStore schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{ @@ -1418,7 +1449,7 @@ func TestCompleteJob(t *testing.T) { }) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srvID, + UUID: pd.ID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -1484,8 +1515,7 @@ func TestCompleteJob(t *testing.T) { }) t.Run("TemplateDryRun", func(t *testing.T) { t.Parallel() - srvID := uuid.New() - srv, db, _ := setup(t, false, &overrides{id: &srvID}) + srv, db, _, pd := setup(t, false, &overrides{}) job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: uuid.New(), Provisioner: database.ProvisionerTypeEcho, @@ -1495,7 +1525,7 @@ func TestCompleteJob(t *testing.T) { require.NoError(t, err) _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ WorkerID: uuid.NullUUID{ - UUID: srvID, + UUID: pd.ID, Valid: true, }, Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, @@ -1686,73 +1716,89 @@ func TestInsertWorkspaceResource(t *testing.T) { } type overrides struct { + ctx context.Context deploymentValues *codersdk.DeploymentValues externalAuthConfigs []*externalauth.Config - id *uuid.UUID templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] timeNowFn func() time.Time acquireJobLongPollDuration time.Duration + heartbeatFn func(ctx context.Context) error + heartbeatInterval time.Duration } -func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) { +func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub, database.ProvisionerDaemon) { t.Helper() - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) db := dbmem.New() ps := pubsub.NewInMemory() deploymentValues := &codersdk.DeploymentValues{} var externalAuthConfigs []*externalauth.Config - srvID := uuid.New() tss := testTemplateScheduleStore() uqhss := testUserQuietHoursScheduleStore() var timeNowFn func() time.Time pollDur := time.Duration(0) - if ov != nil { - if ov.deploymentValues != nil { - deploymentValues = ov.deploymentValues - } - if ov.externalAuthConfigs != nil { - externalAuthConfigs = ov.externalAuthConfigs - } - if ov.id != nil { - srvID = *ov.id - } - if ov.templateScheduleStore != nil { - ttss := tss.Load() - // keep the initial test value if the override hasn't set the atomic pointer. - tss = ov.templateScheduleStore - if tss.Load() == nil { - swapped := tss.CompareAndSwap(nil, ttss) - require.True(t, swapped) - } - } - if ov.userQuietHoursScheduleStore != nil { - tuqhss := uqhss.Load() - // keep the initial test value if the override hasn't set the atomic pointer. - uqhss = ov.userQuietHoursScheduleStore - if uqhss.Load() == nil { - swapped := uqhss.CompareAndSwap(nil, tuqhss) - require.True(t, swapped) - } - } - if ov.timeNowFn != nil { - timeNowFn = ov.timeNowFn - } - pollDur = ov.acquireJobLongPollDuration + if ov == nil { + ov = &overrides{} } + if ov.ctx == nil { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + ov.ctx = ctx + } + if ov.heartbeatInterval == 0 { + ov.heartbeatInterval = testutil.IntervalMedium + } + if ov.deploymentValues != nil { + deploymentValues = ov.deploymentValues + } + if ov.externalAuthConfigs != nil { + externalAuthConfigs = ov.externalAuthConfigs + } + if ov.templateScheduleStore != nil { + ttss := tss.Load() + // keep the initial test value if the override hasn't set the atomic pointer. + tss = ov.templateScheduleStore + if tss.Load() == nil { + swapped := tss.CompareAndSwap(nil, ttss) + require.True(t, swapped) + } + } + if ov.userQuietHoursScheduleStore != nil { + tuqhss := uqhss.Load() + // keep the initial test value if the override hasn't set the atomic pointer. + uqhss = ov.userQuietHoursScheduleStore + if uqhss.Load() == nil { + swapped := uqhss.CompareAndSwap(nil, tuqhss) + require.True(t, swapped) + } + } + if ov.timeNowFn != nil { + timeNowFn = ov.timeNowFn + } + pollDur = ov.acquireJobLongPollDuration + + daemon, err := db.UpsertProvisionerDaemon(ov.ctx, database.UpsertProvisionerDaemonParams{ + Name: "test", + CreatedAt: dbtime.Now(), + Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho}, + Tags: database.StringMap{}, + LastSeenAt: sql.NullTime{}, + Version: "", + APIVersion: "1.0", + }) + require.NoError(t, err) srv, err := provisionerdserver.NewServer( - ctx, + ov.ctx, &url.URL{}, - srvID, + daemon.ID, slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}), []database.ProvisionerType{database.ProvisionerTypeEcho}, - provisionerdserver.Tags{}, + provisionerdserver.Tags(daemon.Tags), db, ps, - provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps), + provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps), telemetry.NewNoop(), trace.NewNoopTracerProvider().Tracer("noop"), &atomic.Pointer[proto.QuotaCommitter]{}, @@ -1765,10 +1811,12 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi TimeNowFn: timeNowFn, OIDCConfig: &oauth2.Config{}, AcquireJobLongPollDur: pollDur, + HeartbeatInterval: ov.heartbeatInterval, + HeartbeatFn: ov.heartbeatFn, }, ) require.NoError(t, err) - return srv, db, ps + return srv, db, ps, daemon } func must[T any](value T, err error) T { diff --git a/coderd/rbac/roles.go b/coderd/rbac/roles.go index 3a2a0d74ea..7f8e0b2759 100644 --- a/coderd/rbac/roles.go +++ b/coderd/rbac/roles.go @@ -155,6 +155,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) { // Users cannot do create/update/delete on themselves, but they // can read their own details. ResourceUser.Type: {ActionRead}, + // Users can create provisioner daemons scoped to themselves. + ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate}, })..., ), }.withCachedRegoValue() diff --git a/codersdk/organizations.go b/codersdk/organizations.go index 5efe728ad0..e26b406cc4 100644 --- a/codersdk/organizations.go +++ b/codersdk/organizations.go @@ -164,9 +164,6 @@ func (c *Client) Organization(ctx context.Context, id uuid.UUID) (Organization, } // ProvisionerDaemons returns provisioner daemons available. -// -// Deprecated: We no longer track provisioner daemons as they connect. This function may return historical data -// but new provisioner daemons will not appear. func (c *Client) ProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) { res, err := c.Request(ctx, http.MethodGet, // TODO: the organization path parameter is currently ignored. diff --git a/enterprise/cli/provisionerdaemons_test.go b/enterprise/cli/provisionerdaemons_test.go index 4f1e09ad43..6424801e53 100644 --- a/enterprise/cli/provisionerdaemons_test.go +++ b/enterprise/cli/provisionerdaemons_test.go @@ -4,13 +4,16 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" ) @@ -35,6 +38,17 @@ func TestProvisionerDaemon_PSK(t *testing.T) { clitest.Start(t, inv) pty.ExpectMatchContext(ctx, "starting provisioner daemon") pty.ExpectMatchContext(ctx, "matt-daemon") + + var daemons []codersdk.ProvisionerDaemon + require.Eventually(t, func() bool { + daemons, err = client.ProvisionerDaemons(ctx) + if err != nil { + return false + } + return len(daemons) == 1 + }, testutil.WaitShort, testutil.IntervalSlow) + require.Equal(t, "matt-daemon", daemons[0].Name) + require.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope]) } func TestProvisionerDaemon_SessionToken(t *testing.T) { @@ -49,14 +63,61 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) { }, }, }) - anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID) - inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user") + anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID) + inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user", "--name", "my-daemon") clitest.SetupConfig(t, anotherClient, conf) pty := ptytest.New(t).Attach(inv) ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong) defer cancel() clitest.Start(t, inv) pty.ExpectMatchContext(ctx, "starting provisioner daemon") + + var daemons []codersdk.ProvisionerDaemon + var err error + require.Eventually(t, func() bool { + daemons, err = client.ProvisionerDaemons(ctx) + if err != nil { + return false + } + return len(daemons) == 1 + }, testutil.WaitShort, testutil.IntervalSlow) + assert.Equal(t, "my-daemon", daemons[0].Name) + assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope]) + assert.Equal(t, anotherUser.ID.String(), daemons[0].Tags[provisionersdk.TagOwner]) + }) + + t.Run("ScopeAnotherUser", func(t *testing.T) { + t.Parallel() + client, admin := coderdenttest.New(t, &coderdenttest.Options{ + ProvisionerDaemonPSK: "provisionersftw", + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID) + inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user", "--tag", "owner="+admin.UserID.String(), "--name", "my-daemon") + clitest.SetupConfig(t, anotherClient, conf) + pty := ptytest.New(t).Attach(inv) + ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong) + defer cancel() + clitest.Start(t, inv) + pty.ExpectMatchContext(ctx, "starting provisioner daemon") + + var daemons []codersdk.ProvisionerDaemon + var err error + require.Eventually(t, func() bool { + daemons, err = client.ProvisionerDaemons(ctx) + if err != nil { + return false + } + return len(daemons) == 1 + }, testutil.WaitShort, testutil.IntervalSlow) + assert.Equal(t, "my-daemon", daemons[0].Name) + assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope]) + // This should get clobbered to the user who started the daemon. + assert.Equal(t, anotherUser.ID.String(), daemons[0].Tags[provisionersdk.TagOwner]) }) t.Run("ScopeOrg", func(t *testing.T) { @@ -69,13 +130,25 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) { }, }, }) - anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID) - inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=organization") + anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID, rbac.RoleTemplateAdmin()) + inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=organization", "--name", "org-daemon") clitest.SetupConfig(t, anotherClient, conf) pty := ptytest.New(t).Attach(inv) ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong) defer cancel() clitest.Start(t, inv) pty.ExpectMatchContext(ctx, "starting provisioner daemon") + + var daemons []codersdk.ProvisionerDaemon + var err error + require.Eventually(t, func() bool { + daemons, err = client.ProvisionerDaemons(ctx) + if err != nil { + return false + } + return len(daemons) == 1 + }, testutil.WaitShort, testutil.IntervalSlow) + assert.Equal(t, "org-daemon", daemons[0].Name) + assert.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope]) }) } diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 037dbe38bc..874c8cb501 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -26,6 +26,8 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/provisionerdserver" @@ -121,7 +123,7 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin psk := r.Header.Get(codersdk.ProvisionerDaemonPSK) if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 { // If using PSK auth, the daemon is, by definition, scoped to the organization. - tags[provisionersdk.TagScope] = provisionersdk.ScopeOrganization + tags = provisionersdk.MutateTags(uuid.Nil, tags) return tags, true } } @@ -191,7 +193,10 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) if !authorized { api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags)) httpapi.Write(ctx, rw, http.StatusForbidden, - codersdk.Response{Message: "You aren't allowed to create provisioner daemons"}) + codersdk.Response{ + Message: fmt.Sprintf("You aren't allowed to create provisioner daemons with scope %q", tags[provisionersdk.TagScope]), + }, + ) return } api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags)) @@ -221,6 +226,35 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) slog.F("tags", tags), ) + authCtx := ctx + if r.Header.Get(codersdk.ProvisionerDaemonPSK) != "" { + //nolint:gocritic // PSK auth means no actor in request, + // so use system restricted. + authCtx = dbauthz.AsSystemRestricted(ctx) + } + + // Create the daemon in the database. + now := dbtime.Now() + daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{ + Name: name, + Provisioners: provisioners, + Tags: tags, + CreatedAt: now, + LastSeenAt: sql.NullTime{Time: now, Valid: true}, + Version: "", // TODO: provisionerd needs to send version + APIVersion: "1.0", + }) + if err != nil { + if !xerrors.Is(err, context.Canceled) { + log.Error(ctx, "create provisioner daemon", slog.Error(err)) + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error creating provisioner daemon.", + Detail: err.Error(), + }) + } + return + } + api.AGPL.WebsocketWaitMutex.Lock() api.AGPL.WebsocketWaitGroup.Add(1) api.AGPL.WebsocketWaitMutex.Unlock() @@ -264,11 +298,13 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) } mux := drpcmux.New() logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name)) + srvCtx, srvCancel := context.WithCancel(ctx) + defer srvCancel() logger.Info(ctx, "starting external provisioner daemon") srv, err := provisionerdserver.NewServer( - api.ctx, + srvCtx, api.AccessURL, - id, + daemon.ID, logger, provisioners, tags, @@ -308,6 +344,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) }, }) err = server.Serve(ctx, session) + srvCancel() logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err)) if err != nil && !xerrors.Is(err, io.EOF) { _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index 10442f71f5..2e19aa3168 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "cdr.dev/slog" @@ -50,6 +51,10 @@ func TestProvisionerDaemonServe(t *testing.T) { }) require.NoError(t, err) srv.DRPCConn().Close() + + daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. + require.NoError(t, err) + require.Len(t, daemons, 1) }) t.Run("NoLicense", func(t *testing.T) { @@ -163,6 +168,15 @@ func TestProvisionerDaemonServe(t *testing.T) { file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, bytes.NewReader(data)) require.NoError(t, err) + require.Eventually(t, func() bool { + daemons, err := client.ProvisionerDaemons(context.Background()) + assert.NoError(t, err, "failed to get provisioner daemons") + return len(daemons) > 0 && + assert.Equal(t, t.Name(), daemons[0].Name) && + assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope]) && + assert.Equal(t, user.UserID.String(), daemons[0].Tags[provisionersdk.TagOwner]) + }, testutil.WaitShort, testutil.IntervalMedium) + version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{ Name: "example", StorageMethod: codersdk.ProvisionerStorageMethodFile, @@ -211,6 +225,12 @@ func TestProvisionerDaemonServe(t *testing.T) { require.NoError(t, err) err = srv.DRPCConn().Close() require.NoError(t, err) + + daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. + require.NoError(t, err) + if assert.Len(t, daemons, 1) { + assert.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope]) + } }) t.Run("PSK_daily_cost", func(t *testing.T) { @@ -346,6 +366,10 @@ func TestProvisionerDaemonServe(t *testing.T) { var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + + daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. + require.NoError(t, err) + require.Len(t, daemons, 0) }) t.Run("NoAuth", func(t *testing.T) { @@ -376,6 +400,10 @@ func TestProvisionerDaemonServe(t *testing.T) { var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + + daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. + require.NoError(t, err) + require.Len(t, daemons, 0) }) t.Run("NoPSK", func(t *testing.T) { @@ -406,5 +434,9 @@ func TestProvisionerDaemonServe(t *testing.T) { var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + + daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. + require.NoError(t, err) + require.Len(t, daemons, 0) }) } diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 9072085ff5..52414db4af 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "net/http" "reflect" "sync" "time" @@ -20,6 +21,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" @@ -199,6 +201,12 @@ connectLoop: if errors.Is(err, context.Canceled) { return } + var sdkErr *codersdk.Error + // If something is wrong with our auth, stop trying to connect. + if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusForbidden { + p.opts.Logger.Error(p.closeContext, "not authorized to dial coderd", slog.Error(err)) + return + } if p.isClosed() { return } diff --git a/provisionersdk/provisionertags.go b/provisionersdk/provisionertags.go index 970cf2094d..9dc9bd7392 100644 --- a/provisionersdk/provisionertags.go +++ b/provisionersdk/provisionertags.go @@ -14,7 +14,10 @@ const ( // If the scope is "user", the "owner" is changed to the user ID. // This is for user-scoped provisioner daemons, where users should // own their own operations. -// Otherwise, the "owner" tag is always empty. +// Otherwise, the "owner" tag is always an empty string. +// NOTE: "owner" must NEVER be nil. Otherwise it will end up being +// duplicated in the database, as idx_provisioner_daemons_name_owner_key +// is a partial unique index that includes a JSON field. func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string { if tags == nil { tags = map[string]string{} @@ -22,15 +25,16 @@ func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string { _, ok := tags[TagScope] if !ok { tags[TagScope] = ScopeOrganization - delete(tags, TagOwner) + tags[TagOwner] = "" } switch tags[TagScope] { case ScopeUser: tags[TagOwner] = userID.String() case ScopeOrganization: - delete(tags, TagOwner) + tags[TagOwner] = "" default: tags[TagScope] = ScopeOrganization + tags[TagOwner] = "" } return tags } diff --git a/provisionersdk/provisionertags_test.go b/provisionersdk/provisionertags_test.go index 26ecc1d12b..911de161c1 100644 --- a/provisionersdk/provisionertags_test.go +++ b/provisionersdk/provisionertags_test.go @@ -27,6 +27,7 @@ func TestMutateTags(t *testing.T) { tags: nil, want: map[string]string{ provisionersdk.TagScope: provisionersdk.ScopeOrganization, + provisionersdk.TagOwner: "", }, }, { @@ -35,6 +36,7 @@ func TestMutateTags(t *testing.T) { tags: map[string]string{}, want: map[string]string{ provisionersdk.TagScope: provisionersdk.ScopeOrganization, + provisionersdk.TagOwner: "", }, }, { @@ -52,6 +54,7 @@ func TestMutateTags(t *testing.T) { userID: testUserID, want: map[string]string{ provisionersdk.TagScope: provisionersdk.ScopeOrganization, + provisionersdk.TagOwner: "", }, }, { @@ -63,6 +66,7 @@ func TestMutateTags(t *testing.T) { userID: uuid.Nil, want: map[string]string{ provisionersdk.TagScope: provisionersdk.ScopeOrganization, + provisionersdk.TagOwner: "", }, }, { @@ -73,6 +77,7 @@ func TestMutateTags(t *testing.T) { userID: uuid.Nil, want: map[string]string{ provisionersdk.TagScope: provisionersdk.ScopeOrganization, + provisionersdk.TagOwner: "", }, }, { @@ -81,6 +86,7 @@ func TestMutateTags(t *testing.T) { userID: testUserID, want: map[string]string{ provisionersdk.TagScope: provisionersdk.ScopeOrganization, + provisionersdk.TagOwner: "", }, }, } {