feat(coderd): insert provisioner daemons (#11207)

* Adds UpdateProvisionerDaemonLastSeenAt
* Adds heartbeat to provisioner daemons
* Inserts provisioner daemons to database upon start
* Ensures TagOwner is an empty string and not nil
* Adds COALESCE() in idx_provisioner_daemons_name_owner_key
This commit is contained in:
Cian Johnston 2023-12-18 16:44:52 +00:00 committed by GitHub
parent a6901ae2c5
commit 213b768785
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 548 additions and 137 deletions

View File

@ -36,6 +36,7 @@
"worker_id": "[workspace build worker ID]", "worker_id": "[workspace build worker ID]",
"file_id": "[workspace build file ID]", "file_id": "[workspace build file ID]",
"tags": { "tags": {
"owner": "",
"scope": "organization" "scope": "organization"
}, },
"queue_position": 0, "queue_position": 0,

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"database/sql"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -49,6 +50,7 @@ import (
"github.com/coder/coder/v2/coderd/batchstats" "github.com/coder/coder/v2/coderd/batchstats"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/gitsshkey" "github.com/coder/coder/v2/coderd/gitsshkey"
@ -1178,22 +1180,32 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string
} }
}() }()
tags := provisionerdserver.Tags{ //nolint:gocritic // in-memory provisioners are owned by system
provisionersdk.TagScope: provisionersdk.ScopeOrganization, daemon, err := api.Database.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{
Name: name,
CreatedAt: dbtime.Now(),
Provisioners: []database.ProvisionerType{
database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform,
},
Tags: provisionersdk.MutateTags(uuid.Nil, nil),
LastSeenAt: sql.NullTime{Time: dbtime.Now(), Valid: true},
Version: buildinfo.Version(),
APIVersion: "1.0",
})
if err != nil {
return nil, xerrors.Errorf("failed to create in-memory provisioner daemon: %w", err)
} }
mux := drpcmux.New() mux := drpcmux.New()
api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name)) api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name))
logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name)) logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name))
srv, err := provisionerdserver.NewServer( srv, err := provisionerdserver.NewServer(
api.ctx, api.ctx, // use the same ctx as the API
api.AccessURL, api.AccessURL,
uuid.New(), daemon.ID,
logger, logger,
[]database.ProvisionerType{ daemon.Provisioners,
database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform, provisionerdserver.Tags(daemon.Tags),
},
tags,
api.Database, api.Database,
api.Pubsub, api.Pubsub,
api.Acquirer, api.Acquirer,

View File

@ -533,7 +533,7 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer {
}() }()
daemon := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { daemon := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, t.Name()) return coderAPI.CreateInMemoryProvisionerDaemon(ctx, "test")
}, &provisionerd.Options{ }, &provisionerd.Options{
Logger: coderAPI.Logger.Named("provisionerd").Leveled(slog.LevelDebug), Logger: coderAPI.Logger.Named("provisionerd").Leveled(slog.LevelDebug),
UpdateInterval: 250 * time.Millisecond, UpdateInterval: 250 * time.Millisecond,

View File

@ -232,6 +232,7 @@ var (
rbac.ResourceOrganization.Type: {rbac.ActionCreate}, rbac.ResourceOrganization.Type: {rbac.ActionCreate},
rbac.ResourceOrganizationMember.Type: {rbac.ActionCreate}, rbac.ResourceOrganizationMember.Type: {rbac.ActionCreate},
rbac.ResourceOrgRoleAssignment.Type: {rbac.ActionCreate}, rbac.ResourceOrgRoleAssignment.Type: {rbac.ActionCreate},
rbac.ResourceProvisionerDaemon.Type: {rbac.ActionCreate, rbac.ActionUpdate},
rbac.ResourceUser.Type: {rbac.ActionCreate, rbac.ActionUpdate, rbac.ActionDelete}, rbac.ResourceUser.Type: {rbac.ActionCreate, rbac.ActionUpdate, rbac.ActionDelete},
rbac.ResourceUserData.Type: {rbac.ActionCreate, rbac.ActionUpdate}, rbac.ResourceUserData.Type: {rbac.ActionCreate, rbac.ActionUpdate},
rbac.ResourceWorkspace.Type: {rbac.ActionUpdate}, rbac.ResourceWorkspace.Type: {rbac.ActionUpdate},
@ -2499,6 +2500,13 @@ func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemb
return q.db.UpdateMemberRoles(ctx, arg) return q.db.UpdateMemberRoles(ctx, arg)
} }
func (q *querier) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceProvisionerDaemon); err != nil {
return err
}
return q.db.UpdateProvisionerDaemonLastSeenAt(ctx, arg)
}
// TODO: We need to create a ProvisionerJob resource type // TODO: We need to create a ProvisionerJob resource type
func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error { func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
// if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { // if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {

View File

@ -1592,6 +1592,18 @@ func (s *MethodTestSuite) TestExtraMethods() {
s.NoError(err, "insert provisioner daemon") s.NoError(err, "insert provisioner daemon")
check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete) check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete)
})) }))
s.Run("UpdateProvisionerDaemonLastSeenAt", s.Subtest(func(db database.Store, check *expects) {
d, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{
Tags: database.StringMap(map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
}),
})
s.NoError(err, "insert provisioner daemon")
check.Args(database.UpdateProvisionerDaemonLastSeenAtParams{
ID: d.ID,
LastSeenAt: sql.NullTime{Time: dbtime.Now(), Valid: true},
}).Asserts(rbac.ResourceProvisionerDaemon, rbac.ActionUpdate)
}))
} }
// All functions in this method test suite are not implemented in dbmem, but // All functions in this method test suite are not implemented in dbmem, but

View File

@ -28,7 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/util/slice"
) )
var errMatchAny = errors.New("match any error") var errMatchAny = xerrors.New("match any error")
var skipMethods = map[string]string{ var skipMethods = map[string]string{
"InTx": "Not relevant", "InTx": "Not relevant",

View File

@ -5945,6 +5945,28 @@ func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMe
return database.OrganizationMember{}, sql.ErrNoRows return database.OrganizationMember{}, sql.ErrNoRows
} }
func (q *FakeQuerier) UpdateProvisionerDaemonLastSeenAt(_ context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
err := validateDatabaseType(arg)
if err != nil {
return err
}
q.mutex.Lock()
defer q.mutex.Unlock()
for idx := range q.provisionerDaemons {
if q.provisionerDaemons[idx].ID != arg.ID {
continue
}
if q.provisionerDaemons[idx].LastSeenAt.Time.After(arg.LastSeenAt.Time) {
continue
}
q.provisionerDaemons[idx].LastSeenAt = arg.LastSeenAt
return nil
}
return sql.ErrNoRows
}
func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error {
if err := validateDatabaseType(arg); err != nil { if err := validateDatabaseType(arg); err != nil {
return err return err

View File

@ -1593,6 +1593,13 @@ func (m metricsStore) UpdateMemberRoles(ctx context.Context, arg database.Update
return member, err return member, err
} }
func (m metricsStore) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
start := time.Now()
r0 := m.s.UpdateProvisionerDaemonLastSeenAt(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateProvisionerDaemonLastSeenAt").Observe(time.Since(start).Seconds())
return r0
}
func (m metricsStore) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error { func (m metricsStore) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
start := time.Now() start := time.Now()
err := m.s.UpdateProvisionerJobByID(ctx, arg) err := m.s.UpdateProvisionerJobByID(ctx, arg)

View File

@ -3362,6 +3362,20 @@ func (mr *MockStoreMockRecorder) UpdateMemberRoles(arg0, arg1 interface{}) *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMemberRoles", reflect.TypeOf((*MockStore)(nil).UpdateMemberRoles), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMemberRoles", reflect.TypeOf((*MockStore)(nil).UpdateMemberRoles), arg0, arg1)
} }
// UpdateProvisionerDaemonLastSeenAt mocks base method.
func (m *MockStore) UpdateProvisionerDaemonLastSeenAt(arg0 context.Context, arg1 database.UpdateProvisionerDaemonLastSeenAtParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProvisionerDaemonLastSeenAt", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProvisionerDaemonLastSeenAt indicates an expected call of UpdateProvisionerDaemonLastSeenAt.
func (mr *MockStoreMockRecorder) UpdateProvisionerDaemonLastSeenAt(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProvisionerDaemonLastSeenAt", reflect.TypeOf((*MockStore)(nil).UpdateProvisionerDaemonLastSeenAt), arg0, arg1)
}
// UpdateProvisionerJobByID mocks base method. // UpdateProvisionerJobByID mocks base method.
func (m *MockStore) UpdateProvisionerJobByID(arg0 context.Context, arg1 database.UpdateProvisionerJobByIDParams) error { func (m *MockStore) UpdateProvisionerJobByID(arg0 context.Context, arg1 database.UpdateProvisionerJobByIDParams) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -1417,9 +1417,9 @@ CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name);
CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)); CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name));
CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower((tags ->> 'owner'::text))); CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower(COALESCE((tags ->> 'owner'::text), ''::text)));
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key IS 'Relax uniqueness constraint for provisioner daemon names'; COMMENT ON INDEX idx_provisioner_daemons_name_owner_key IS 'Allow unique provisioner daemon names by user';
CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id); CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id);

View File

@ -0,0 +1,8 @@
DROP INDEX IF EXISTS idx_provisioner_daemons_name_owner_key;
CREATE UNIQUE INDEX IF NOT EXISTS idx_provisioner_daemons_name_owner_key
ON provisioner_daemons
USING btree (name, lower((tags->>'owner')::text));
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key
IS 'Relax uniqueness constraint for provisioner daemon names';

View File

@ -0,0 +1,8 @@
DROP INDEX IF EXISTS idx_provisioner_daemons_name_owner_key;
CREATE UNIQUE INDEX IF NOT EXISTS idx_provisioner_daemons_name_owner_key
ON provisioner_daemons
USING btree (name, LOWER(COALESCE(tags->>'owner', '')::text));
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key
IS 'Allow unique provisioner daemon names by user';

View File

@ -318,6 +318,7 @@ type sqlcQuerier interface {
UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error)
UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error) UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error)
UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error) UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error)
UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error

View File

@ -3057,6 +3057,26 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
return items, nil return items, nil
} }
const updateProvisionerDaemonLastSeenAt = `-- name: UpdateProvisionerDaemonLastSeenAt :exec
UPDATE provisioner_daemons
SET
last_seen_at = $1
WHERE
id = $2
AND
last_seen_at <= $1
`
type UpdateProvisionerDaemonLastSeenAtParams struct {
LastSeenAt sql.NullTime `db:"last_seen_at" json:"last_seen_at"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error {
_, err := q.db.ExecContext(ctx, updateProvisionerDaemonLastSeenAt, arg.LastSeenAt, arg.ID)
return err
}
const upsertProvisionerDaemon = `-- name: UpsertProvisionerDaemon :one const upsertProvisionerDaemon = `-- name: UpsertProvisionerDaemon :one
INSERT INTO INSERT INTO
provisioner_daemons ( provisioner_daemons (
@ -3078,7 +3098,7 @@ VALUES (
$5, $5,
$6, $6,
$7 $7
) ON CONFLICT("name", lower((tags ->> 'owner'::text))) DO UPDATE SET ) ON CONFLICT("name", LOWER(COALESCE(tags ->> 'owner'::text, ''::text))) DO UPDATE SET
provisioners = $3, provisioners = $3,
tags = $4, tags = $4,
last_seen_at = $5, last_seen_at = $5,

View File

@ -35,7 +35,7 @@ VALUES (
@last_seen_at, @last_seen_at,
@version, @version,
@api_version @api_version
) ON CONFLICT("name", lower((tags ->> 'owner'::text))) DO UPDATE SET ) ON CONFLICT("name", LOWER(COALESCE(tags ->> 'owner'::text, ''::text))) DO UPDATE SET
provisioners = @provisioners, provisioners = @provisioners,
tags = @tags, tags = @tags,
last_seen_at = @last_seen_at, last_seen_at = @last_seen_at,
@ -45,3 +45,12 @@ WHERE
-- Only ones with the same tags are allowed clobber -- Only ones with the same tags are allowed clobber
provisioner_daemons.tags <@ @tags :: jsonb provisioner_daemons.tags <@ @tags :: jsonb
RETURNING *; RETURNING *;
-- name: UpdateProvisionerDaemonLastSeenAt :exec
UPDATE provisioner_daemons
SET
last_seen_at = @last_seen_at
WHERE
id = @id
AND
last_seen_at <= @last_seen_at;

View File

@ -65,7 +65,7 @@ const (
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type); UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
UniqueIndexOrganizationName UniqueConstraint = "idx_organization_name" // CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name); UniqueIndexOrganizationName UniqueConstraint = "idx_organization_name" // CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name);
UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)); UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name));
UniqueIndexProvisionerDaemonsNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower((tags ->> 'owner'::text))); UniqueIndexProvisionerDaemonsNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower(COALESCE((tags ->> 'owner'::text), ''::text)));
UniqueIndexUsersEmail UniqueConstraint = "idx_users_email" // CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false); UniqueIndexUsersEmail UniqueConstraint = "idx_users_email" // CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false);
UniqueIndexUsersUsername UniqueConstraint = "idx_users_username" // CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false); UniqueIndexUsersUsername UniqueConstraint = "idx_users_username" // CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false);
UniqueTemplatesOrganizationIDNameIndex UniqueConstraint = "templates_organization_id_name_idx" // CREATE UNIQUE INDEX templates_organization_id_name_idx ON templates USING btree (organization_id, lower((name)::text)) WHERE (deleted = false); UniqueTemplatesOrganizationIDNameIndex UniqueConstraint = "templates_organization_id_name_idx" // CREATE UNIQUE INDEX templates_organization_id_name_idx ON templates USING btree (organization_id, lower((name)::text)) WHERE (deleted = false);

View File

@ -44,9 +44,15 @@ import (
sdkproto "github.com/coder/coder/v2/provisionersdk/proto" sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
) )
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before const (
// canceling and returning an empty job. // DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
const DefaultAcquireJobLongPollDur = time.Second * 5 // canceling and returning an empty job.
DefaultAcquireJobLongPollDur = time.Second * 5
// DefaultHeartbeatInterval is the interval at which the provisioner daemon
// will update its last seen at timestamp in the database.
DefaultHeartbeatInterval = time.Minute
)
type Options struct { type Options struct {
OIDCConfig httpmw.OAuth2Config OIDCConfig httpmw.OAuth2Config
@ -56,6 +62,16 @@ type Options struct {
// AcquireJobLongPollDur is used in tests // AcquireJobLongPollDur is used in tests
AcquireJobLongPollDur time.Duration AcquireJobLongPollDur time.Duration
// HeartbeatInterval is the interval at which the provisioner daemon
// will update its last seen at timestamp in the database.
HeartbeatInterval time.Duration
// HeartbeatFn is the function that will be called at the interval
// specified by HeartbeatInterval.
// The default function just calls UpdateProvisionerDaemonLastSeenAt.
// This is mainly used for testing.
HeartbeatFn func(context.Context) error
} }
type server struct { type server struct {
@ -85,6 +101,9 @@ type server struct {
TimeNowFn func() time.Time TimeNowFn func() time.Time
acquireJobLongPollDur time.Duration acquireJobLongPollDur time.Duration
heartbeatInterval time.Duration
heartbeatFn func(ctx context.Context) error
} }
// We use the null byte (0x00) in generating a canonical map key for tags, so // We use the null byte (0x00) in generating a canonical map key for tags, so
@ -161,7 +180,21 @@ func NewServer(
if options.AcquireJobLongPollDur == 0 { if options.AcquireJobLongPollDur == 0 {
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
} }
return &server{ if options.HeartbeatInterval == 0 {
options.HeartbeatInterval = DefaultHeartbeatInterval
}
// Avoid a nil check in s.heartbeat.
if options.HeartbeatFn == nil {
options.HeartbeatFn = func(hbCtx context.Context) error {
//nolint:gocritic // This is specifically for updating the last seen at timestamp.
return db.UpdateProvisionerDaemonLastSeenAt(dbauthz.AsSystemRestricted(hbCtx), database.UpdateProvisionerDaemonLastSeenAtParams{
ID: id,
LastSeenAt: sql.NullTime{Time: time.Now(), Valid: true},
})
}
}
s := &server{
lifecycleCtx: lifecycleCtx, lifecycleCtx: lifecycleCtx,
AccessURL: accessURL, AccessURL: accessURL,
ID: id, ID: id,
@ -182,7 +215,12 @@ func NewServer(
OIDCConfig: options.OIDCConfig, OIDCConfig: options.OIDCConfig,
TimeNowFn: options.TimeNowFn, TimeNowFn: options.TimeNowFn,
acquireJobLongPollDur: options.AcquireJobLongPollDur, acquireJobLongPollDur: options.AcquireJobLongPollDur,
}, nil heartbeatInterval: options.HeartbeatInterval,
heartbeatFn: options.HeartbeatFn,
}
go s.heartbeatLoop()
return s, nil
} }
// timeNow should be used when trying to get the current time for math // timeNow should be used when trying to get the current time for math
@ -194,6 +232,50 @@ func (s *server) timeNow() time.Time {
return dbtime.Now() return dbtime.Now()
} }
// heartbeatLoop runs heartbeatOnce at the interval specified by HeartbeatInterval
// until the lifecycle context is canceled.
func (s *server) heartbeatLoop() {
tick := time.NewTicker(time.Nanosecond)
defer tick.Stop()
for {
select {
case <-s.lifecycleCtx.Done():
s.Logger.Debug(s.lifecycleCtx, "heartbeat loop canceled")
return
case <-tick.C:
if s.lifecycleCtx.Err() != nil {
return
}
start := s.timeNow()
hbCtx, hbCancel := context.WithTimeout(s.lifecycleCtx, s.heartbeatInterval)
if err := s.heartbeat(hbCtx); err != nil {
if !xerrors.Is(err, context.DeadlineExceeded) && !xerrors.Is(err, context.Canceled) {
s.Logger.Error(hbCtx, "heartbeat failed", slog.Error(err))
}
}
hbCancel()
elapsed := s.timeNow().Sub(start)
nextBeat := s.heartbeatInterval - elapsed
// avoid negative interval
if nextBeat <= 0 {
nextBeat = time.Nanosecond
}
tick.Reset(nextBeat)
}
}
}
// heartbeat updates the last seen at timestamp in the database.
// If HeartbeatFn is set, it will be called instead.
func (s *server) heartbeat(ctx context.Context) error {
select {
case <-ctx.Done():
return nil
default:
return s.heartbeatFn(ctx)
}
}
// AcquireJob queries the database to lock a job. // AcquireJob queries the database to lock a job.
// //
// Deprecated: This method is only available for back-level provisioner daemons. // Deprecated: This method is only available for back-level provisioner daemons.

View File

@ -66,7 +66,8 @@ func testUserQuietHoursScheduleStore() *atomic.Pointer[schedule.UserQuietHoursSc
func TestAcquireJob_LongPoll(t *testing.T) { func TestAcquireJob_LongPoll(t *testing.T) {
t.Parallel() t.Parallel()
srv, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond}) //nolint:dogsled // ૮・ᴥ・ა
srv, _, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond})
job, err := srv.AcquireJob(context.Background(), nil) job, err := srv.AcquireJob(context.Background(), nil)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &proto.AcquiredJob{}, job) require.Equal(t, &proto.AcquiredJob{}, job)
@ -74,7 +75,8 @@ func TestAcquireJob_LongPoll(t *testing.T) {
func TestAcquireJobWithCancel_Cancel(t *testing.T) { func TestAcquireJobWithCancel_Cancel(t *testing.T) {
t.Parallel() t.Parallel()
srv, _, _ := setup(t, false, nil) //nolint:dogsled // ૮ ˶′ﻌ ‵˶ ა
srv, _, _, _ := setup(t, false, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel() defer cancel()
fs := newFakeStream(ctx) fs := newFakeStream(ctx)
@ -95,6 +97,46 @@ func TestAcquireJobWithCancel_Cancel(t *testing.T) {
require.Equal(t, "", job.JobId) require.Equal(t, "", job.JobId)
} }
func TestHeartbeat(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
heartbeatChan := make(chan struct{})
heartbeatFn := func(hbCtx context.Context) error {
t.Logf("heartbeat")
select {
case <-hbCtx.Done():
return hbCtx.Err()
default:
heartbeatChan <- struct{}{}
return nil
}
}
//nolint:dogsled // 。:゚૮ ˶ˆ ﻌ ˆ˶ ა ゚:。
_, _, _, _ = setup(t, false, &overrides{
ctx: ctx,
heartbeatFn: heartbeatFn,
heartbeatInterval: testutil.IntervalFast,
})
_, ok := <-heartbeatChan
require.True(t, ok, "first heartbeat not received")
_, ok = <-heartbeatChan
require.True(t, ok, "second heartbeat not received")
cancel()
// Close the channel to ensure we don't receive any more heartbeats.
// The test will fail if we do.
defer func() {
if r := recover(); r != nil {
t.Fatalf("heartbeat received after cancel: %v", r)
}
}()
close(heartbeatChan)
<-time.After(testutil.IntervalMedium)
}
func TestAcquireJob(t *testing.T) { func TestAcquireJob(t *testing.T) {
t.Parallel() t.Parallel()
@ -120,7 +162,7 @@ func TestAcquireJob(t *testing.T) {
tc := tc tc := tc
t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) { t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _ := setup(t, false, nil) srv, db, _, _ := setup(t, false, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel() defer cancel()
_, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{ _, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
@ -141,7 +183,7 @@ func TestAcquireJob(t *testing.T) {
// deployment config. // deployment config.
dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)} dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)}
gitAuthProvider := "github" gitAuthProvider := "github"
srv, db, ps := setup(t, false, &overrides{ srv, db, ps, _ := setup(t, false, &overrides{
deploymentValues: dv, deploymentValues: dv,
externalAuthConfigs: []*externalauth.Config{{ externalAuthConfigs: []*externalauth.Config{{
ID: gitAuthProvider, ID: gitAuthProvider,
@ -359,7 +401,7 @@ func TestAcquireJob(t *testing.T) {
t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) { t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, ps := setup(t, false, nil) srv, db, ps, _ := setup(t, false, nil)
ctx := context.Background() ctx := context.Background()
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
@ -396,7 +438,7 @@ func TestAcquireJob(t *testing.T) {
}) })
t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) { t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, ps := setup(t, false, nil) srv, db, ps, _ := setup(t, false, nil)
ctx := context.Background() ctx := context.Background()
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
@ -427,7 +469,7 @@ func TestAcquireJob(t *testing.T) {
}) })
t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) { t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, ps := setup(t, false, nil) srv, db, ps, _ := setup(t, false, nil)
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{}) version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
@ -476,7 +518,7 @@ func TestUpdateJob(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, _, _ := setup(t, false, nil) srv, _, _, _ := setup(t, false, nil)
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: "hello", JobId: "hello",
}) })
@ -489,7 +531,7 @@ func TestUpdateJob(t *testing.T) {
}) })
t.Run("NotRunning", func(t *testing.T) { t.Run("NotRunning", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _ := setup(t, false, nil) srv, db, _, _ := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
@ -505,7 +547,7 @@ func TestUpdateJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own! // This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) { t.Run("NotOwner", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _ := setup(t, false, nil) srv, db, _, _ := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
@ -548,9 +590,8 @@ func TestUpdateJob(t *testing.T) {
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New() srv, db, _, pd := setup(t, false, &overrides{})
srv, db, _ := setup(t, false, &overrides{id: &srvID}) job := setupJob(t, db, pd.ID)
job := setupJob(t, db, srvID)
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.String(), JobId: job.String(),
}) })
@ -559,9 +600,8 @@ func TestUpdateJob(t *testing.T) {
t.Run("Logs", func(t *testing.T) { t.Run("Logs", func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New() srv, db, ps, pd := setup(t, false, &overrides{})
srv, db, ps := setup(t, false, &overrides{id: &srvID}) job := setupJob(t, db, pd.ID)
job := setupJob(t, db, srvID)
published := make(chan struct{}) published := make(chan struct{})
@ -585,9 +625,8 @@ func TestUpdateJob(t *testing.T) {
}) })
t.Run("Readme", func(t *testing.T) { t.Run("Readme", func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New() srv, db, _, pd := setup(t, false, &overrides{})
srv, db, _ := setup(t, false, &overrides{id: &srvID}) job := setupJob(t, db, pd.ID)
job := setupJob(t, db, srvID)
versionID := uuid.New() versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID, ID: versionID,
@ -612,9 +651,8 @@ func TestUpdateJob(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
srvID := uuid.New() srv, db, _, pd := setup(t, false, &overrides{})
srv, db, _ := setup(t, false, &overrides{id: &srvID}) job := setupJob(t, db, pd.ID)
job := setupJob(t, db, srvID)
versionID := uuid.New() versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID, ID: versionID,
@ -660,9 +698,8 @@ func TestUpdateJob(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
srvID := uuid.New() srv, db, _, pd := setup(t, false, &overrides{})
srv, db, _ := setup(t, false, &overrides{id: &srvID}) job := setupJob(t, db, pd.ID)
job := setupJob(t, db, srvID)
versionID := uuid.New() versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID, ID: versionID,
@ -707,7 +744,7 @@ func TestFailJob(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, _, _ := setup(t, false, nil) srv, _, _, _ := setup(t, false, nil)
_, err := srv.FailJob(ctx, &proto.FailedJob{ _, err := srv.FailJob(ctx, &proto.FailedJob{
JobId: "hello", JobId: "hello",
}) })
@ -721,7 +758,7 @@ func TestFailJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own! // This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) { t.Run("NotOwner", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _ := setup(t, false, nil) srv, db, _, _ := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
@ -744,8 +781,7 @@ func TestFailJob(t *testing.T) {
}) })
t.Run("AlreadyCompleted", func(t *testing.T) { t.Run("AlreadyCompleted", func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New() srv, db, _, pd := setup(t, false, &overrides{})
srv, db, _ := setup(t, false, &overrides{id: &srvID})
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
@ -755,7 +791,7 @@ func TestFailJob(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{ WorkerID: uuid.NullUUID{
UUID: srvID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -780,8 +816,7 @@ func TestFailJob(t *testing.T) {
// //
// (*Server).FailJob audit log - get build {"error": "sql: no rows in result set"} // (*Server).FailJob audit log - get build {"error": "sql: no rows in result set"}
ignoreLogErrors := true ignoreLogErrors := true
srvID := uuid.New() srv, db, ps, pd := setup(t, ignoreLogErrors, &overrides{})
srv, db, ps := setup(t, ignoreLogErrors, &overrides{id: &srvID})
workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{ workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{
ID: uuid.New(), ID: uuid.New(),
AutomaticUpdates: database.AutomaticUpdatesNever, AutomaticUpdates: database.AutomaticUpdatesNever,
@ -810,7 +845,7 @@ func TestFailJob(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{ WorkerID: uuid.NullUUID{
UUID: srvID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -852,7 +887,7 @@ func TestCompleteJob(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, _, _ := setup(t, false, nil) srv, _, _, _ := setup(t, false, nil)
_, err := srv.CompleteJob(ctx, &proto.CompletedJob{ _, err := srv.CompleteJob(ctx, &proto.CompletedJob{
JobId: "hello", JobId: "hello",
}) })
@ -866,7 +901,7 @@ func TestCompleteJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own! // This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) { t.Run("NotOwner", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _ := setup(t, false, nil) srv, db, _, _ := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
@ -890,8 +925,7 @@ func TestCompleteJob(t *testing.T) {
t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) { t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New() srv, db, _, pd := setup(t, false, &overrides{})
srv, db, _ := setup(t, false, &overrides{id: &srvID})
jobID := uuid.New() jobID := uuid.New()
versionID := uuid.New() versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
@ -909,7 +943,7 @@ func TestCompleteJob(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{ WorkerID: uuid.NullUUID{
UUID: srvID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -939,9 +973,7 @@ func TestCompleteJob(t *testing.T) {
t.Run("TemplateImport_WithGitAuth", func(t *testing.T) { t.Run("TemplateImport_WithGitAuth", func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New() srv, db, _, pd := setup(t, false, &overrides{
srv, db, _ := setup(t, false, &overrides{
id: &srvID,
externalAuthConfigs: []*externalauth.Config{{ externalAuthConfigs: []*externalauth.Config{{
ID: "github", ID: "github",
}}, }},
@ -963,7 +995,7 @@ func TestCompleteJob(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{ WorkerID: uuid.NullUUID{
UUID: srvID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -1106,9 +1138,8 @@ func TestCompleteJob(t *testing.T) {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New()
tss := &atomic.Pointer[schedule.TemplateScheduleStore]{} tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
srv, db, ps := setup(t, false, &overrides{id: &srvID, templateScheduleStore: tss}) srv, db, ps, pd := setup(t, false, &overrides{templateScheduleStore: tss})
var store schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{ var store schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) { GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) {
@ -1123,10 +1154,19 @@ func TestCompleteJob(t *testing.T) {
} }
tss.Store(&store) tss.Store(&store)
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
template := dbgen.Template(t, db, database.Template{ template := dbgen.Template(t, db, database.Template{
Name: "template", Name: "template",
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
OrganizationID: org.ID,
})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
JobID: uuid.New(),
}) })
err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{ err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
ID: template.ID, ID: template.ID,
@ -1148,13 +1188,6 @@ func TestCompleteJob(t *testing.T) {
TemplateID: template.ID, TemplateID: template.ID,
Ttl: workspaceTTL, Ttl: workspaceTTL,
}) })
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
JobID: uuid.New(),
})
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID, WorkspaceID: workspace.ID,
TemplateVersionID: version.ID, TemplateVersionID: version.ID,
@ -1170,7 +1203,7 @@ func TestCompleteJob(t *testing.T) {
}) })
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{ WorkerID: uuid.NullUUID{
UUID: srvID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -1315,19 +1348,17 @@ func TestCompleteJob(t *testing.T) {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New()
// Simulate the given time starting from now. // Simulate the given time starting from now.
require.False(t, c.now.IsZero()) require.False(t, c.now.IsZero())
start := time.Now() start := time.Now()
tss := &atomic.Pointer[schedule.TemplateScheduleStore]{} tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
uqhss := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{} uqhss := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{}
srv, db, ps := setup(t, false, &overrides{ srv, db, ps, pd := setup(t, false, &overrides{
timeNowFn: func() time.Time { timeNowFn: func() time.Time {
return c.now.Add(time.Since(start)) return c.now.Add(time.Since(start))
}, },
templateScheduleStore: tss, templateScheduleStore: tss,
userQuietHoursScheduleStore: uqhss, userQuietHoursScheduleStore: uqhss,
id: &srvID,
}) })
var templateScheduleStore schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{ var templateScheduleStore schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
@ -1418,7 +1449,7 @@ func TestCompleteJob(t *testing.T) {
}) })
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{ WorkerID: uuid.NullUUID{
UUID: srvID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -1484,8 +1515,7 @@ func TestCompleteJob(t *testing.T) {
}) })
t.Run("TemplateDryRun", func(t *testing.T) { t.Run("TemplateDryRun", func(t *testing.T) {
t.Parallel() t.Parallel()
srvID := uuid.New() srv, db, _, pd := setup(t, false, &overrides{})
srv, db, _ := setup(t, false, &overrides{id: &srvID})
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
@ -1495,7 +1525,7 @@ func TestCompleteJob(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{ WorkerID: uuid.NullUUID{
UUID: srvID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -1686,73 +1716,89 @@ func TestInsertWorkspaceResource(t *testing.T) {
} }
type overrides struct { type overrides struct {
ctx context.Context
deploymentValues *codersdk.DeploymentValues deploymentValues *codersdk.DeploymentValues
externalAuthConfigs []*externalauth.Config externalAuthConfigs []*externalauth.Config
id *uuid.UUID
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
timeNowFn func() time.Time timeNowFn func() time.Time
acquireJobLongPollDuration time.Duration acquireJobLongPollDuration time.Duration
heartbeatFn func(ctx context.Context) error
heartbeatInterval time.Duration
} }
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) { func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub, database.ProvisionerDaemon) {
t.Helper() t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
db := dbmem.New() db := dbmem.New()
ps := pubsub.NewInMemory() ps := pubsub.NewInMemory()
deploymentValues := &codersdk.DeploymentValues{} deploymentValues := &codersdk.DeploymentValues{}
var externalAuthConfigs []*externalauth.Config var externalAuthConfigs []*externalauth.Config
srvID := uuid.New()
tss := testTemplateScheduleStore() tss := testTemplateScheduleStore()
uqhss := testUserQuietHoursScheduleStore() uqhss := testUserQuietHoursScheduleStore()
var timeNowFn func() time.Time var timeNowFn func() time.Time
pollDur := time.Duration(0) pollDur := time.Duration(0)
if ov != nil { if ov == nil {
if ov.deploymentValues != nil { ov = &overrides{}
deploymentValues = ov.deploymentValues
}
if ov.externalAuthConfigs != nil {
externalAuthConfigs = ov.externalAuthConfigs
}
if ov.id != nil {
srvID = *ov.id
}
if ov.templateScheduleStore != nil {
ttss := tss.Load()
// keep the initial test value if the override hasn't set the atomic pointer.
tss = ov.templateScheduleStore
if tss.Load() == nil {
swapped := tss.CompareAndSwap(nil, ttss)
require.True(t, swapped)
}
}
if ov.userQuietHoursScheduleStore != nil {
tuqhss := uqhss.Load()
// keep the initial test value if the override hasn't set the atomic pointer.
uqhss = ov.userQuietHoursScheduleStore
if uqhss.Load() == nil {
swapped := uqhss.CompareAndSwap(nil, tuqhss)
require.True(t, swapped)
}
}
if ov.timeNowFn != nil {
timeNowFn = ov.timeNowFn
}
pollDur = ov.acquireJobLongPollDuration
} }
if ov.ctx == nil {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
ov.ctx = ctx
}
if ov.heartbeatInterval == 0 {
ov.heartbeatInterval = testutil.IntervalMedium
}
if ov.deploymentValues != nil {
deploymentValues = ov.deploymentValues
}
if ov.externalAuthConfigs != nil {
externalAuthConfigs = ov.externalAuthConfigs
}
if ov.templateScheduleStore != nil {
ttss := tss.Load()
// keep the initial test value if the override hasn't set the atomic pointer.
tss = ov.templateScheduleStore
if tss.Load() == nil {
swapped := tss.CompareAndSwap(nil, ttss)
require.True(t, swapped)
}
}
if ov.userQuietHoursScheduleStore != nil {
tuqhss := uqhss.Load()
// keep the initial test value if the override hasn't set the atomic pointer.
uqhss = ov.userQuietHoursScheduleStore
if uqhss.Load() == nil {
swapped := uqhss.CompareAndSwap(nil, tuqhss)
require.True(t, swapped)
}
}
if ov.timeNowFn != nil {
timeNowFn = ov.timeNowFn
}
pollDur = ov.acquireJobLongPollDuration
daemon, err := db.UpsertProvisionerDaemon(ov.ctx, database.UpsertProvisionerDaemonParams{
Name: "test",
CreatedAt: dbtime.Now(),
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: database.StringMap{},
LastSeenAt: sql.NullTime{},
Version: "",
APIVersion: "1.0",
})
require.NoError(t, err)
srv, err := provisionerdserver.NewServer( srv, err := provisionerdserver.NewServer(
ctx, ov.ctx,
&url.URL{}, &url.URL{},
srvID, daemon.ID,
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}), slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
[]database.ProvisionerType{database.ProvisionerTypeEcho}, []database.ProvisionerType{database.ProvisionerTypeEcho},
provisionerdserver.Tags{}, provisionerdserver.Tags(daemon.Tags),
db, db,
ps, ps,
provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps), provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps),
telemetry.NewNoop(), telemetry.NewNoop(),
trace.NewNoopTracerProvider().Tracer("noop"), trace.NewNoopTracerProvider().Tracer("noop"),
&atomic.Pointer[proto.QuotaCommitter]{}, &atomic.Pointer[proto.QuotaCommitter]{},
@ -1765,10 +1811,12 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
TimeNowFn: timeNowFn, TimeNowFn: timeNowFn,
OIDCConfig: &oauth2.Config{}, OIDCConfig: &oauth2.Config{},
AcquireJobLongPollDur: pollDur, AcquireJobLongPollDur: pollDur,
HeartbeatInterval: ov.heartbeatInterval,
HeartbeatFn: ov.heartbeatFn,
}, },
) )
require.NoError(t, err) require.NoError(t, err)
return srv, db, ps return srv, db, ps, daemon
} }
func must[T any](value T, err error) T { func must[T any](value T, err error) T {

View File

@ -155,6 +155,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
// Users cannot do create/update/delete on themselves, but they // Users cannot do create/update/delete on themselves, but they
// can read their own details. // can read their own details.
ResourceUser.Type: {ActionRead}, ResourceUser.Type: {ActionRead},
// Users can create provisioner daemons scoped to themselves.
ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate},
})..., })...,
), ),
}.withCachedRegoValue() }.withCachedRegoValue()

View File

@ -164,9 +164,6 @@ func (c *Client) Organization(ctx context.Context, id uuid.UUID) (Organization,
} }
// ProvisionerDaemons returns provisioner daemons available. // ProvisionerDaemons returns provisioner daemons available.
//
// Deprecated: We no longer track provisioner daemons as they connect. This function may return historical data
// but new provisioner daemons will not appear.
func (c *Client) ProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) { func (c *Client) ProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) {
res, err := c.Request(ctx, http.MethodGet, res, err := c.Request(ctx, http.MethodGet,
// TODO: the organization path parameter is currently ignored. // TODO: the organization path parameter is currently ignored.

View File

@ -4,13 +4,16 @@ import (
"context" "context"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -35,6 +38,17 @@ func TestProvisionerDaemon_PSK(t *testing.T) {
clitest.Start(t, inv) clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "starting provisioner daemon") pty.ExpectMatchContext(ctx, "starting provisioner daemon")
pty.ExpectMatchContext(ctx, "matt-daemon") pty.ExpectMatchContext(ctx, "matt-daemon")
var daemons []codersdk.ProvisionerDaemon
require.Eventually(t, func() bool {
daemons, err = client.ProvisionerDaemons(ctx)
if err != nil {
return false
}
return len(daemons) == 1
}, testutil.WaitShort, testutil.IntervalSlow)
require.Equal(t, "matt-daemon", daemons[0].Name)
require.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
} }
func TestProvisionerDaemon_SessionToken(t *testing.T) { func TestProvisionerDaemon_SessionToken(t *testing.T) {
@ -49,14 +63,61 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) {
}, },
}, },
}) })
anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID) anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user") inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user", "--name", "my-daemon")
clitest.SetupConfig(t, anotherClient, conf) clitest.SetupConfig(t, anotherClient, conf)
pty := ptytest.New(t).Attach(inv) pty := ptytest.New(t).Attach(inv)
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong) ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
defer cancel() defer cancel()
clitest.Start(t, inv) clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "starting provisioner daemon") pty.ExpectMatchContext(ctx, "starting provisioner daemon")
var daemons []codersdk.ProvisionerDaemon
var err error
require.Eventually(t, func() bool {
daemons, err = client.ProvisionerDaemons(ctx)
if err != nil {
return false
}
return len(daemons) == 1
}, testutil.WaitShort, testutil.IntervalSlow)
assert.Equal(t, "my-daemon", daemons[0].Name)
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope])
assert.Equal(t, anotherUser.ID.String(), daemons[0].Tags[provisionersdk.TagOwner])
})
t.Run("ScopeAnotherUser", func(t *testing.T) {
t.Parallel()
client, admin := coderdenttest.New(t, &coderdenttest.Options{
ProvisionerDaemonPSK: "provisionersftw",
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
},
},
})
anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user", "--tag", "owner="+admin.UserID.String(), "--name", "my-daemon")
clitest.SetupConfig(t, anotherClient, conf)
pty := ptytest.New(t).Attach(inv)
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
defer cancel()
clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
var daemons []codersdk.ProvisionerDaemon
var err error
require.Eventually(t, func() bool {
daemons, err = client.ProvisionerDaemons(ctx)
if err != nil {
return false
}
return len(daemons) == 1
}, testutil.WaitShort, testutil.IntervalSlow)
assert.Equal(t, "my-daemon", daemons[0].Name)
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope])
// This should get clobbered to the user who started the daemon.
assert.Equal(t, anotherUser.ID.String(), daemons[0].Tags[provisionersdk.TagOwner])
}) })
t.Run("ScopeOrg", func(t *testing.T) { t.Run("ScopeOrg", func(t *testing.T) {
@ -69,13 +130,25 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) {
}, },
}, },
}) })
anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID) anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID, rbac.RoleTemplateAdmin())
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=organization") inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=organization", "--name", "org-daemon")
clitest.SetupConfig(t, anotherClient, conf) clitest.SetupConfig(t, anotherClient, conf)
pty := ptytest.New(t).Attach(inv) pty := ptytest.New(t).Attach(inv)
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong) ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
defer cancel() defer cancel()
clitest.Start(t, inv) clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "starting provisioner daemon") pty.ExpectMatchContext(ctx, "starting provisioner daemon")
var daemons []codersdk.ProvisionerDaemon
var err error
require.Eventually(t, func() bool {
daemons, err = client.ProvisionerDaemons(ctx)
if err != nil {
return false
}
return len(daemons) == 1
}, testutil.WaitShort, testutil.IntervalSlow)
assert.Equal(t, "org-daemon", daemons[0].Name)
assert.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
}) })
} }

View File

@ -26,6 +26,8 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/provisionerdserver"
@ -121,7 +123,7 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK) psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 { if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
// If using PSK auth, the daemon is, by definition, scoped to the organization. // If using PSK auth, the daemon is, by definition, scoped to the organization.
tags[provisionersdk.TagScope] = provisionersdk.ScopeOrganization tags = provisionersdk.MutateTags(uuid.Nil, tags)
return tags, true return tags, true
} }
} }
@ -191,7 +193,10 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
if !authorized { if !authorized {
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags)) api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags))
httpapi.Write(ctx, rw, http.StatusForbidden, httpapi.Write(ctx, rw, http.StatusForbidden,
codersdk.Response{Message: "You aren't allowed to create provisioner daemons"}) codersdk.Response{
Message: fmt.Sprintf("You aren't allowed to create provisioner daemons with scope %q", tags[provisionersdk.TagScope]),
},
)
return return
} }
api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags)) api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags))
@ -221,6 +226,35 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
slog.F("tags", tags), slog.F("tags", tags),
) )
authCtx := ctx
if r.Header.Get(codersdk.ProvisionerDaemonPSK) != "" {
//nolint:gocritic // PSK auth means no actor in request,
// so use system restricted.
authCtx = dbauthz.AsSystemRestricted(ctx)
}
// Create the daemon in the database.
now := dbtime.Now()
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
Name: name,
Provisioners: provisioners,
Tags: tags,
CreatedAt: now,
LastSeenAt: sql.NullTime{Time: now, Valid: true},
Version: "", // TODO: provisionerd needs to send version
APIVersion: "1.0",
})
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "create provisioner daemon", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error creating provisioner daemon.",
Detail: err.Error(),
})
}
return
}
api.AGPL.WebsocketWaitMutex.Lock() api.AGPL.WebsocketWaitMutex.Lock()
api.AGPL.WebsocketWaitGroup.Add(1) api.AGPL.WebsocketWaitGroup.Add(1)
api.AGPL.WebsocketWaitMutex.Unlock() api.AGPL.WebsocketWaitMutex.Unlock()
@ -264,11 +298,13 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
} }
mux := drpcmux.New() mux := drpcmux.New()
logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name)) logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name))
srvCtx, srvCancel := context.WithCancel(ctx)
defer srvCancel()
logger.Info(ctx, "starting external provisioner daemon") logger.Info(ctx, "starting external provisioner daemon")
srv, err := provisionerdserver.NewServer( srv, err := provisionerdserver.NewServer(
api.ctx, srvCtx,
api.AccessURL, api.AccessURL,
id, daemon.ID,
logger, logger,
provisioners, provisioners,
tags, tags,
@ -308,6 +344,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}, },
}) })
err = server.Serve(ctx, session) err = server.Serve(ctx, session)
srvCancel()
logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err)) logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err))
if err != nil && !xerrors.Is(err, io.EOF) { if err != nil && !xerrors.Is(err, io.EOF) {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"cdr.dev/slog" "cdr.dev/slog"
@ -50,6 +51,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
srv.DRPCConn().Close() srv.DRPCConn().Close()
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
require.Len(t, daemons, 1)
}) })
t.Run("NoLicense", func(t *testing.T) { t.Run("NoLicense", func(t *testing.T) {
@ -163,6 +168,15 @@ func TestProvisionerDaemonServe(t *testing.T) {
file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, bytes.NewReader(data)) file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, bytes.NewReader(data))
require.NoError(t, err) require.NoError(t, err)
require.Eventually(t, func() bool {
daemons, err := client.ProvisionerDaemons(context.Background())
assert.NoError(t, err, "failed to get provisioner daemons")
return len(daemons) > 0 &&
assert.Equal(t, t.Name(), daemons[0].Name) &&
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope]) &&
assert.Equal(t, user.UserID.String(), daemons[0].Tags[provisionersdk.TagOwner])
}, testutil.WaitShort, testutil.IntervalMedium)
version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{ version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{
Name: "example", Name: "example",
StorageMethod: codersdk.ProvisionerStorageMethodFile, StorageMethod: codersdk.ProvisionerStorageMethodFile,
@ -211,6 +225,12 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
err = srv.DRPCConn().Close() err = srv.DRPCConn().Close()
require.NoError(t, err) require.NoError(t, err)
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
if assert.Len(t, daemons, 1) {
assert.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
}
}) })
t.Run("PSK_daily_cost", func(t *testing.T) { t.Run("PSK_daily_cost", func(t *testing.T) {
@ -346,6 +366,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
var apiError *codersdk.Error var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError) require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode()) require.Equal(t, http.StatusForbidden, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
require.Len(t, daemons, 0)
}) })
t.Run("NoAuth", func(t *testing.T) { t.Run("NoAuth", func(t *testing.T) {
@ -376,6 +400,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
var apiError *codersdk.Error var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError) require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode()) require.Equal(t, http.StatusForbidden, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
require.Len(t, daemons, 0)
}) })
t.Run("NoPSK", func(t *testing.T) { t.Run("NoPSK", func(t *testing.T) {
@ -406,5 +434,9 @@ func TestProvisionerDaemonServe(t *testing.T) {
var apiError *codersdk.Error var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError) require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode()) require.Equal(t, http.StatusForbidden, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
require.Len(t, daemons, 0)
}) })
} }

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http"
"reflect" "reflect"
"sync" "sync"
"time" "time"
@ -20,6 +21,7 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionerd/runner" "github.com/coder/coder/v2/provisionerd/runner"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto" sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
@ -199,6 +201,12 @@ connectLoop:
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return return
} }
var sdkErr *codersdk.Error
// If something is wrong with our auth, stop trying to connect.
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusForbidden {
p.opts.Logger.Error(p.closeContext, "not authorized to dial coderd", slog.Error(err))
return
}
if p.isClosed() { if p.isClosed() {
return return
} }

View File

@ -14,7 +14,10 @@ const (
// If the scope is "user", the "owner" is changed to the user ID. // If the scope is "user", the "owner" is changed to the user ID.
// This is for user-scoped provisioner daemons, where users should // This is for user-scoped provisioner daemons, where users should
// own their own operations. // own their own operations.
// Otherwise, the "owner" tag is always empty. // Otherwise, the "owner" tag is always an empty string.
// NOTE: "owner" must NEVER be nil. Otherwise it will end up being
// duplicated in the database, as idx_provisioner_daemons_name_owner_key
// is a partial unique index that includes a JSON field.
func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string { func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
if tags == nil { if tags == nil {
tags = map[string]string{} tags = map[string]string{}
@ -22,15 +25,16 @@ func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
_, ok := tags[TagScope] _, ok := tags[TagScope]
if !ok { if !ok {
tags[TagScope] = ScopeOrganization tags[TagScope] = ScopeOrganization
delete(tags, TagOwner) tags[TagOwner] = ""
} }
switch tags[TagScope] { switch tags[TagScope] {
case ScopeUser: case ScopeUser:
tags[TagOwner] = userID.String() tags[TagOwner] = userID.String()
case ScopeOrganization: case ScopeOrganization:
delete(tags, TagOwner) tags[TagOwner] = ""
default: default:
tags[TagScope] = ScopeOrganization tags[TagScope] = ScopeOrganization
tags[TagOwner] = ""
} }
return tags return tags
} }

View File

@ -27,6 +27,7 @@ func TestMutateTags(t *testing.T) {
tags: nil, tags: nil,
want: map[string]string{ want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization, provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
}, },
}, },
{ {
@ -35,6 +36,7 @@ func TestMutateTags(t *testing.T) {
tags: map[string]string{}, tags: map[string]string{},
want: map[string]string{ want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization, provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
}, },
}, },
{ {
@ -52,6 +54,7 @@ func TestMutateTags(t *testing.T) {
userID: testUserID, userID: testUserID,
want: map[string]string{ want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization, provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
}, },
}, },
{ {
@ -63,6 +66,7 @@ func TestMutateTags(t *testing.T) {
userID: uuid.Nil, userID: uuid.Nil,
want: map[string]string{ want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization, provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
}, },
}, },
{ {
@ -73,6 +77,7 @@ func TestMutateTags(t *testing.T) {
userID: uuid.Nil, userID: uuid.Nil,
want: map[string]string{ want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization, provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
}, },
}, },
{ {
@ -81,6 +86,7 @@ func TestMutateTags(t *testing.T) {
userID: testUserID, userID: testUserID,
want: map[string]string{ want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization, provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
}, },
}, },
} { } {