mirror of https://github.com/coder/coder.git
feat(coderd): insert provisioner daemons (#11207)
* Adds UpdateProvisionerDaemonLastSeenAt * Adds heartbeat to provisioner daemons * Inserts provisioner daemons to database upon start * Ensures TagOwner is an empty string and not nil * Adds COALESCE() in idx_provisioner_daemons_name_owner_key
This commit is contained in:
parent
a6901ae2c5
commit
213b768785
|
@ -36,6 +36,7 @@
|
|||
"worker_id": "[workspace build worker ID]",
|
||||
"file_id": "[workspace build file ID]",
|
||||
"tags": {
|
||||
"owner": "",
|
||||
"scope": "organization"
|
||||
},
|
||||
"queue_position": 0,
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -49,6 +50,7 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/batchstats"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/gitsshkey"
|
||||
|
@ -1178,22 +1180,32 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string
|
|||
}
|
||||
}()
|
||||
|
||||
tags := provisionerdserver.Tags{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
//nolint:gocritic // in-memory provisioners are owned by system
|
||||
daemon, err := api.Database.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{
|
||||
Name: name,
|
||||
CreatedAt: dbtime.Now(),
|
||||
Provisioners: []database.ProvisionerType{
|
||||
database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform,
|
||||
},
|
||||
Tags: provisionersdk.MutateTags(uuid.Nil, nil),
|
||||
LastSeenAt: sql.NullTime{Time: dbtime.Now(), Valid: true},
|
||||
Version: buildinfo.Version(),
|
||||
APIVersion: "1.0",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to create in-memory provisioner daemon: %w", err)
|
||||
}
|
||||
|
||||
mux := drpcmux.New()
|
||||
api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name))
|
||||
logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name))
|
||||
srv, err := provisionerdserver.NewServer(
|
||||
api.ctx,
|
||||
api.ctx, // use the same ctx as the API
|
||||
api.AccessURL,
|
||||
uuid.New(),
|
||||
daemon.ID,
|
||||
logger,
|
||||
[]database.ProvisionerType{
|
||||
database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform,
|
||||
},
|
||||
tags,
|
||||
daemon.Provisioners,
|
||||
provisionerdserver.Tags(daemon.Tags),
|
||||
api.Database,
|
||||
api.Pubsub,
|
||||
api.Acquirer,
|
||||
|
|
|
@ -533,7 +533,7 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer {
|
|||
}()
|
||||
|
||||
daemon := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
|
||||
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, t.Name())
|
||||
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, "test")
|
||||
}, &provisionerd.Options{
|
||||
Logger: coderAPI.Logger.Named("provisionerd").Leveled(slog.LevelDebug),
|
||||
UpdateInterval: 250 * time.Millisecond,
|
||||
|
|
|
@ -232,6 +232,7 @@ var (
|
|||
rbac.ResourceOrganization.Type: {rbac.ActionCreate},
|
||||
rbac.ResourceOrganizationMember.Type: {rbac.ActionCreate},
|
||||
rbac.ResourceOrgRoleAssignment.Type: {rbac.ActionCreate},
|
||||
rbac.ResourceProvisionerDaemon.Type: {rbac.ActionCreate, rbac.ActionUpdate},
|
||||
rbac.ResourceUser.Type: {rbac.ActionCreate, rbac.ActionUpdate, rbac.ActionDelete},
|
||||
rbac.ResourceUserData.Type: {rbac.ActionCreate, rbac.ActionUpdate},
|
||||
rbac.ResourceWorkspace.Type: {rbac.ActionUpdate},
|
||||
|
@ -2499,6 +2500,13 @@ func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemb
|
|||
return q.db.UpdateMemberRoles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceProvisionerDaemon); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpdateProvisionerDaemonLastSeenAt(ctx, arg)
|
||||
}
|
||||
|
||||
// TODO: We need to create a ProvisionerJob resource type
|
||||
func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
|
||||
// if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
|
|
|
@ -1592,6 +1592,18 @@ func (s *MethodTestSuite) TestExtraMethods() {
|
|||
s.NoError(err, "insert provisioner daemon")
|
||||
check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete)
|
||||
}))
|
||||
s.Run("UpdateProvisionerDaemonLastSeenAt", s.Subtest(func(db database.Store, check *expects) {
|
||||
d, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{
|
||||
Tags: database.StringMap(map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
}),
|
||||
})
|
||||
s.NoError(err, "insert provisioner daemon")
|
||||
check.Args(database.UpdateProvisionerDaemonLastSeenAtParams{
|
||||
ID: d.ID,
|
||||
LastSeenAt: sql.NullTime{Time: dbtime.Now(), Valid: true},
|
||||
}).Asserts(rbac.ResourceProvisionerDaemon, rbac.ActionUpdate)
|
||||
}))
|
||||
}
|
||||
|
||||
// All functions in this method test suite are not implemented in dbmem, but
|
||||
|
|
|
@ -28,7 +28,7 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
)
|
||||
|
||||
var errMatchAny = errors.New("match any error")
|
||||
var errMatchAny = xerrors.New("match any error")
|
||||
|
||||
var skipMethods = map[string]string{
|
||||
"InTx": "Not relevant",
|
||||
|
|
|
@ -5945,6 +5945,28 @@ func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMe
|
|||
return database.OrganizationMember{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) UpdateProvisionerDaemonLastSeenAt(_ context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
|
||||
err := validateDatabaseType(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for idx := range q.provisionerDaemons {
|
||||
if q.provisionerDaemons[idx].ID != arg.ID {
|
||||
continue
|
||||
}
|
||||
if q.provisionerDaemons[idx].LastSeenAt.Time.After(arg.LastSeenAt.Time) {
|
||||
continue
|
||||
}
|
||||
q.provisionerDaemons[idx].LastSeenAt = arg.LastSeenAt
|
||||
return nil
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return err
|
||||
|
|
|
@ -1593,6 +1593,13 @@ func (m metricsStore) UpdateMemberRoles(ctx context.Context, arg database.Update
|
|||
return member, err
|
||||
}
|
||||
|
||||
func (m metricsStore) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateProvisionerDaemonLastSeenAt(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateProvisionerDaemonLastSeenAt").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
|
||||
start := time.Now()
|
||||
err := m.s.UpdateProvisionerJobByID(ctx, arg)
|
||||
|
|
|
@ -3362,6 +3362,20 @@ func (mr *MockStoreMockRecorder) UpdateMemberRoles(arg0, arg1 interface{}) *gomo
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMemberRoles", reflect.TypeOf((*MockStore)(nil).UpdateMemberRoles), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpdateProvisionerDaemonLastSeenAt mocks base method.
|
||||
func (m *MockStore) UpdateProvisionerDaemonLastSeenAt(arg0 context.Context, arg1 database.UpdateProvisionerDaemonLastSeenAtParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateProvisionerDaemonLastSeenAt", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateProvisionerDaemonLastSeenAt indicates an expected call of UpdateProvisionerDaemonLastSeenAt.
|
||||
func (mr *MockStoreMockRecorder) UpdateProvisionerDaemonLastSeenAt(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProvisionerDaemonLastSeenAt", reflect.TypeOf((*MockStore)(nil).UpdateProvisionerDaemonLastSeenAt), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpdateProvisionerJobByID mocks base method.
|
||||
func (m *MockStore) UpdateProvisionerJobByID(arg0 context.Context, arg1 database.UpdateProvisionerJobByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -1417,9 +1417,9 @@ CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name);
|
|||
|
||||
CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name));
|
||||
|
||||
CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower((tags ->> 'owner'::text)));
|
||||
CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower(COALESCE((tags ->> 'owner'::text), ''::text)));
|
||||
|
||||
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key IS 'Relax uniqueness constraint for provisioner daemon names';
|
||||
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key IS 'Allow unique provisioner daemon names by user';
|
||||
|
||||
CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id);
|
||||
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
DROP INDEX IF EXISTS idx_provisioner_daemons_name_owner_key;
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_provisioner_daemons_name_owner_key
|
||||
ON provisioner_daemons
|
||||
USING btree (name, lower((tags->>'owner')::text));
|
||||
|
||||
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key
|
||||
IS 'Relax uniqueness constraint for provisioner daemon names';
|
|
@ -0,0 +1,8 @@
|
|||
DROP INDEX IF EXISTS idx_provisioner_daemons_name_owner_key;
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_provisioner_daemons_name_owner_key
|
||||
ON provisioner_daemons
|
||||
USING btree (name, LOWER(COALESCE(tags->>'owner', '')::text));
|
||||
|
||||
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key
|
||||
IS 'Allow unique provisioner daemon names by user';
|
|
@ -318,6 +318,7 @@ type sqlcQuerier interface {
|
|||
UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error)
|
||||
UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error)
|
||||
UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error)
|
||||
UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error
|
||||
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
|
||||
UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
|
||||
UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error
|
||||
|
|
|
@ -3057,6 +3057,26 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
|
|||
return items, nil
|
||||
}
|
||||
|
||||
const updateProvisionerDaemonLastSeenAt = `-- name: UpdateProvisionerDaemonLastSeenAt :exec
|
||||
UPDATE provisioner_daemons
|
||||
SET
|
||||
last_seen_at = $1
|
||||
WHERE
|
||||
id = $2
|
||||
AND
|
||||
last_seen_at <= $1
|
||||
`
|
||||
|
||||
type UpdateProvisionerDaemonLastSeenAtParams struct {
|
||||
LastSeenAt sql.NullTime `db:"last_seen_at" json:"last_seen_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateProvisionerDaemonLastSeenAt, arg.LastSeenAt, arg.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertProvisionerDaemon = `-- name: UpsertProvisionerDaemon :one
|
||||
INSERT INTO
|
||||
provisioner_daemons (
|
||||
|
@ -3078,7 +3098,7 @@ VALUES (
|
|||
$5,
|
||||
$6,
|
||||
$7
|
||||
) ON CONFLICT("name", lower((tags ->> 'owner'::text))) DO UPDATE SET
|
||||
) ON CONFLICT("name", LOWER(COALESCE(tags ->> 'owner'::text, ''::text))) DO UPDATE SET
|
||||
provisioners = $3,
|
||||
tags = $4,
|
||||
last_seen_at = $5,
|
||||
|
|
|
@ -35,7 +35,7 @@ VALUES (
|
|||
@last_seen_at,
|
||||
@version,
|
||||
@api_version
|
||||
) ON CONFLICT("name", lower((tags ->> 'owner'::text))) DO UPDATE SET
|
||||
) ON CONFLICT("name", LOWER(COALESCE(tags ->> 'owner'::text, ''::text))) DO UPDATE SET
|
||||
provisioners = @provisioners,
|
||||
tags = @tags,
|
||||
last_seen_at = @last_seen_at,
|
||||
|
@ -45,3 +45,12 @@ WHERE
|
|||
-- Only ones with the same tags are allowed clobber
|
||||
provisioner_daemons.tags <@ @tags :: jsonb
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateProvisionerDaemonLastSeenAt :exec
|
||||
UPDATE provisioner_daemons
|
||||
SET
|
||||
last_seen_at = @last_seen_at
|
||||
WHERE
|
||||
id = @id
|
||||
AND
|
||||
last_seen_at <= @last_seen_at;
|
||||
|
|
|
@ -65,7 +65,7 @@ const (
|
|||
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
|
||||
UniqueIndexOrganizationName UniqueConstraint = "idx_organization_name" // CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name);
|
||||
UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name));
|
||||
UniqueIndexProvisionerDaemonsNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower((tags ->> 'owner'::text)));
|
||||
UniqueIndexProvisionerDaemonsNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower(COALESCE((tags ->> 'owner'::text), ''::text)));
|
||||
UniqueIndexUsersEmail UniqueConstraint = "idx_users_email" // CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false);
|
||||
UniqueIndexUsersUsername UniqueConstraint = "idx_users_username" // CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false);
|
||||
UniqueTemplatesOrganizationIDNameIndex UniqueConstraint = "templates_organization_id_name_idx" // CREATE UNIQUE INDEX templates_organization_id_name_idx ON templates USING btree (organization_id, lower((name)::text)) WHERE (deleted = false);
|
||||
|
|
|
@ -44,9 +44,15 @@ import (
|
|||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
)
|
||||
|
||||
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
|
||||
// canceling and returning an empty job.
|
||||
const DefaultAcquireJobLongPollDur = time.Second * 5
|
||||
const (
|
||||
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
|
||||
// canceling and returning an empty job.
|
||||
DefaultAcquireJobLongPollDur = time.Second * 5
|
||||
|
||||
// DefaultHeartbeatInterval is the interval at which the provisioner daemon
|
||||
// will update its last seen at timestamp in the database.
|
||||
DefaultHeartbeatInterval = time.Minute
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
OIDCConfig httpmw.OAuth2Config
|
||||
|
@ -56,6 +62,16 @@ type Options struct {
|
|||
|
||||
// AcquireJobLongPollDur is used in tests
|
||||
AcquireJobLongPollDur time.Duration
|
||||
|
||||
// HeartbeatInterval is the interval at which the provisioner daemon
|
||||
// will update its last seen at timestamp in the database.
|
||||
HeartbeatInterval time.Duration
|
||||
|
||||
// HeartbeatFn is the function that will be called at the interval
|
||||
// specified by HeartbeatInterval.
|
||||
// The default function just calls UpdateProvisionerDaemonLastSeenAt.
|
||||
// This is mainly used for testing.
|
||||
HeartbeatFn func(context.Context) error
|
||||
}
|
||||
|
||||
type server struct {
|
||||
|
@ -85,6 +101,9 @@ type server struct {
|
|||
TimeNowFn func() time.Time
|
||||
|
||||
acquireJobLongPollDur time.Duration
|
||||
|
||||
heartbeatInterval time.Duration
|
||||
heartbeatFn func(ctx context.Context) error
|
||||
}
|
||||
|
||||
// We use the null byte (0x00) in generating a canonical map key for tags, so
|
||||
|
@ -161,7 +180,21 @@ func NewServer(
|
|||
if options.AcquireJobLongPollDur == 0 {
|
||||
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
|
||||
}
|
||||
return &server{
|
||||
if options.HeartbeatInterval == 0 {
|
||||
options.HeartbeatInterval = DefaultHeartbeatInterval
|
||||
}
|
||||
// Avoid a nil check in s.heartbeat.
|
||||
if options.HeartbeatFn == nil {
|
||||
options.HeartbeatFn = func(hbCtx context.Context) error {
|
||||
//nolint:gocritic // This is specifically for updating the last seen at timestamp.
|
||||
return db.UpdateProvisionerDaemonLastSeenAt(dbauthz.AsSystemRestricted(hbCtx), database.UpdateProvisionerDaemonLastSeenAtParams{
|
||||
ID: id,
|
||||
LastSeenAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
s := &server{
|
||||
lifecycleCtx: lifecycleCtx,
|
||||
AccessURL: accessURL,
|
||||
ID: id,
|
||||
|
@ -182,7 +215,12 @@ func NewServer(
|
|||
OIDCConfig: options.OIDCConfig,
|
||||
TimeNowFn: options.TimeNowFn,
|
||||
acquireJobLongPollDur: options.AcquireJobLongPollDur,
|
||||
}, nil
|
||||
heartbeatInterval: options.HeartbeatInterval,
|
||||
heartbeatFn: options.HeartbeatFn,
|
||||
}
|
||||
|
||||
go s.heartbeatLoop()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// timeNow should be used when trying to get the current time for math
|
||||
|
@ -194,6 +232,50 @@ func (s *server) timeNow() time.Time {
|
|||
return dbtime.Now()
|
||||
}
|
||||
|
||||
// heartbeatLoop runs heartbeatOnce at the interval specified by HeartbeatInterval
|
||||
// until the lifecycle context is canceled.
|
||||
func (s *server) heartbeatLoop() {
|
||||
tick := time.NewTicker(time.Nanosecond)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.lifecycleCtx.Done():
|
||||
s.Logger.Debug(s.lifecycleCtx, "heartbeat loop canceled")
|
||||
return
|
||||
case <-tick.C:
|
||||
if s.lifecycleCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
start := s.timeNow()
|
||||
hbCtx, hbCancel := context.WithTimeout(s.lifecycleCtx, s.heartbeatInterval)
|
||||
if err := s.heartbeat(hbCtx); err != nil {
|
||||
if !xerrors.Is(err, context.DeadlineExceeded) && !xerrors.Is(err, context.Canceled) {
|
||||
s.Logger.Error(hbCtx, "heartbeat failed", slog.Error(err))
|
||||
}
|
||||
}
|
||||
hbCancel()
|
||||
elapsed := s.timeNow().Sub(start)
|
||||
nextBeat := s.heartbeatInterval - elapsed
|
||||
// avoid negative interval
|
||||
if nextBeat <= 0 {
|
||||
nextBeat = time.Nanosecond
|
||||
}
|
||||
tick.Reset(nextBeat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the last seen at timestamp in the database.
|
||||
// If HeartbeatFn is set, it will be called instead.
|
||||
func (s *server) heartbeat(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return s.heartbeatFn(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireJob queries the database to lock a job.
|
||||
//
|
||||
// Deprecated: This method is only available for back-level provisioner daemons.
|
||||
|
|
|
@ -66,7 +66,8 @@ func testUserQuietHoursScheduleStore() *atomic.Pointer[schedule.UserQuietHoursSc
|
|||
|
||||
func TestAcquireJob_LongPoll(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond})
|
||||
//nolint:dogsled // ૮・ᴥ・ა
|
||||
srv, _, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond})
|
||||
job, err := srv.AcquireJob(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &proto.AcquiredJob{}, job)
|
||||
|
@ -74,7 +75,8 @@ func TestAcquireJob_LongPoll(t *testing.T) {
|
|||
|
||||
func TestAcquireJobWithCancel_Cancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _, _ := setup(t, false, nil)
|
||||
//nolint:dogsled // ૮ ˶′ﻌ ‵˶ ა
|
||||
srv, _, _, _ := setup(t, false, nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
fs := newFakeStream(ctx)
|
||||
|
@ -95,6 +97,46 @@ func TestAcquireJobWithCancel_Cancel(t *testing.T) {
|
|||
require.Equal(t, "", job.JobId)
|
||||
}
|
||||
|
||||
func TestHeartbeat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
heartbeatChan := make(chan struct{})
|
||||
heartbeatFn := func(hbCtx context.Context) error {
|
||||
t.Logf("heartbeat")
|
||||
select {
|
||||
case <-hbCtx.Done():
|
||||
return hbCtx.Err()
|
||||
default:
|
||||
heartbeatChan <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
//nolint:dogsled // 。:゚૮ ˶ˆ ﻌ ˆ˶ ა ゚:。
|
||||
_, _, _, _ = setup(t, false, &overrides{
|
||||
ctx: ctx,
|
||||
heartbeatFn: heartbeatFn,
|
||||
heartbeatInterval: testutil.IntervalFast,
|
||||
})
|
||||
|
||||
_, ok := <-heartbeatChan
|
||||
require.True(t, ok, "first heartbeat not received")
|
||||
_, ok = <-heartbeatChan
|
||||
require.True(t, ok, "second heartbeat not received")
|
||||
cancel()
|
||||
// Close the channel to ensure we don't receive any more heartbeats.
|
||||
// The test will fail if we do.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("heartbeat received after cancel: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
close(heartbeatChan)
|
||||
<-time.After(testutil.IntervalMedium)
|
||||
}
|
||||
|
||||
func TestAcquireJob(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -120,7 +162,7 @@ func TestAcquireJob(t *testing.T) {
|
|||
tc := tc
|
||||
t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, _ := setup(t, false, nil)
|
||||
srv, db, _, _ := setup(t, false, nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
_, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
|
||||
|
@ -141,7 +183,7 @@ func TestAcquireJob(t *testing.T) {
|
|||
// deployment config.
|
||||
dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)}
|
||||
gitAuthProvider := "github"
|
||||
srv, db, ps := setup(t, false, &overrides{
|
||||
srv, db, ps, _ := setup(t, false, &overrides{
|
||||
deploymentValues: dv,
|
||||
externalAuthConfigs: []*externalauth.Config{{
|
||||
ID: gitAuthProvider,
|
||||
|
@ -359,7 +401,7 @@ func TestAcquireJob(t *testing.T) {
|
|||
|
||||
t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, ps := setup(t, false, nil)
|
||||
srv, db, ps, _ := setup(t, false, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
@ -396,7 +438,7 @@ func TestAcquireJob(t *testing.T) {
|
|||
})
|
||||
t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, ps := setup(t, false, nil)
|
||||
srv, db, ps, _ := setup(t, false, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
@ -427,7 +469,7 @@ func TestAcquireJob(t *testing.T) {
|
|||
})
|
||||
t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, ps := setup(t, false, nil)
|
||||
srv, db, ps, _ := setup(t, false, nil)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
|
||||
|
@ -476,7 +518,7 @@ func TestUpdateJob(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _, _ := setup(t, false, nil)
|
||||
srv, _, _, _ := setup(t, false, nil)
|
||||
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: "hello",
|
||||
})
|
||||
|
@ -489,7 +531,7 @@ func TestUpdateJob(t *testing.T) {
|
|||
})
|
||||
t.Run("NotRunning", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, _ := setup(t, false, nil)
|
||||
srv, db, _, _ := setup(t, false, nil)
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
|
@ -505,7 +547,7 @@ func TestUpdateJob(t *testing.T) {
|
|||
// This test prevents runners from updating jobs they don't own!
|
||||
t.Run("NotOwner", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, _ := setup(t, false, nil)
|
||||
srv, db, _, _ := setup(t, false, nil)
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
|
@ -548,9 +590,8 @@ func TestUpdateJob(t *testing.T) {
|
|||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srvID := uuid.New()
|
||||
srv, db, _ := setup(t, false, &overrides{id: &srvID})
|
||||
job := setupJob(t, db, srvID)
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
job := setupJob(t, db, pd.ID)
|
||||
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.String(),
|
||||
})
|
||||
|
@ -559,9 +600,8 @@ func TestUpdateJob(t *testing.T) {
|
|||
|
||||
t.Run("Logs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srvID := uuid.New()
|
||||
srv, db, ps := setup(t, false, &overrides{id: &srvID})
|
||||
job := setupJob(t, db, srvID)
|
||||
srv, db, ps, pd := setup(t, false, &overrides{})
|
||||
job := setupJob(t, db, pd.ID)
|
||||
|
||||
published := make(chan struct{})
|
||||
|
||||
|
@ -585,9 +625,8 @@ func TestUpdateJob(t *testing.T) {
|
|||
})
|
||||
t.Run("Readme", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srvID := uuid.New()
|
||||
srv, db, _ := setup(t, false, &overrides{id: &srvID})
|
||||
job := setupJob(t, db, srvID)
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
job := setupJob(t, db, pd.ID)
|
||||
versionID := uuid.New()
|
||||
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
||||
ID: versionID,
|
||||
|
@ -612,9 +651,8 @@ func TestUpdateJob(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
srvID := uuid.New()
|
||||
srv, db, _ := setup(t, false, &overrides{id: &srvID})
|
||||
job := setupJob(t, db, srvID)
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
job := setupJob(t, db, pd.ID)
|
||||
versionID := uuid.New()
|
||||
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
||||
ID: versionID,
|
||||
|
@ -660,9 +698,8 @@ func TestUpdateJob(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
srvID := uuid.New()
|
||||
srv, db, _ := setup(t, false, &overrides{id: &srvID})
|
||||
job := setupJob(t, db, srvID)
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
job := setupJob(t, db, pd.ID)
|
||||
versionID := uuid.New()
|
||||
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
||||
ID: versionID,
|
||||
|
@ -707,7 +744,7 @@ func TestFailJob(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _, _ := setup(t, false, nil)
|
||||
srv, _, _, _ := setup(t, false, nil)
|
||||
_, err := srv.FailJob(ctx, &proto.FailedJob{
|
||||
JobId: "hello",
|
||||
})
|
||||
|
@ -721,7 +758,7 @@ func TestFailJob(t *testing.T) {
|
|||
// This test prevents runners from updating jobs they don't own!
|
||||
t.Run("NotOwner", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, _ := setup(t, false, nil)
|
||||
srv, db, _, _ := setup(t, false, nil)
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
|
@ -744,8 +781,7 @@ func TestFailJob(t *testing.T) {
|
|||
})
|
||||
t.Run("AlreadyCompleted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srvID := uuid.New()
|
||||
srv, db, _ := setup(t, false, &overrides{id: &srvID})
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
|
@ -755,7 +791,7 @@ func TestFailJob(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: srvID,
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
|
@ -780,8 +816,7 @@ func TestFailJob(t *testing.T) {
|
|||
//
|
||||
// (*Server).FailJob audit log - get build {"error": "sql: no rows in result set"}
|
||||
ignoreLogErrors := true
|
||||
srvID := uuid.New()
|
||||
srv, db, ps := setup(t, ignoreLogErrors, &overrides{id: &srvID})
|
||||
srv, db, ps, pd := setup(t, ignoreLogErrors, &overrides{})
|
||||
workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{
|
||||
ID: uuid.New(),
|
||||
AutomaticUpdates: database.AutomaticUpdatesNever,
|
||||
|
@ -810,7 +845,7 @@ func TestFailJob(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: srvID,
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
|
@ -852,7 +887,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _, _ := setup(t, false, nil)
|
||||
srv, _, _, _ := setup(t, false, nil)
|
||||
_, err := srv.CompleteJob(ctx, &proto.CompletedJob{
|
||||
JobId: "hello",
|
||||
})
|
||||
|
@ -866,7 +901,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
// This test prevents runners from updating jobs they don't own!
|
||||
t.Run("NotOwner", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, _ := setup(t, false, nil)
|
||||
srv, db, _, _ := setup(t, false, nil)
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
|
@ -890,8 +925,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
|
||||
t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srvID := uuid.New()
|
||||
srv, db, _ := setup(t, false, &overrides{id: &srvID})
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
jobID := uuid.New()
|
||||
versionID := uuid.New()
|
||||
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
||||
|
@ -909,7 +943,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: srvID,
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
|
@ -939,9 +973,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
|
||||
t.Run("TemplateImport_WithGitAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srvID := uuid.New()
|
||||
srv, db, _ := setup(t, false, &overrides{
|
||||
id: &srvID,
|
||||
srv, db, _, pd := setup(t, false, &overrides{
|
||||
externalAuthConfigs: []*externalauth.Config{{
|
||||
ID: "github",
|
||||
}},
|
||||
|
@ -963,7 +995,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: srvID,
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
|
@ -1106,9 +1138,8 @@ func TestCompleteJob(t *testing.T) {
|
|||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srvID := uuid.New()
|
||||
tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
|
||||
srv, db, ps := setup(t, false, &overrides{id: &srvID, templateScheduleStore: tss})
|
||||
srv, db, ps, pd := setup(t, false, &overrides{templateScheduleStore: tss})
|
||||
|
||||
var store schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
|
||||
GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) {
|
||||
|
@ -1123,10 +1154,19 @@ func TestCompleteJob(t *testing.T) {
|
|||
}
|
||||
tss.Store(&store)
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
template := dbgen.Template(t, db, database.Template{
|
||||
Name: "template",
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
Name: "template",
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
TemplateID: uuid.NullUUID{
|
||||
UUID: template.ID,
|
||||
Valid: true,
|
||||
},
|
||||
JobID: uuid.New(),
|
||||
})
|
||||
err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
|
||||
ID: template.ID,
|
||||
|
@ -1148,13 +1188,6 @@ func TestCompleteJob(t *testing.T) {
|
|||
TemplateID: template.ID,
|
||||
Ttl: workspaceTTL,
|
||||
})
|
||||
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
TemplateID: uuid.NullUUID{
|
||||
UUID: template.ID,
|
||||
Valid: true,
|
||||
},
|
||||
JobID: uuid.New(),
|
||||
})
|
||||
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
TemplateVersionID: version.ID,
|
||||
|
@ -1170,7 +1203,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
})
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: srvID,
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
|
@ -1315,19 +1348,17 @@ func TestCompleteJob(t *testing.T) {
|
|||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srvID := uuid.New()
|
||||
// Simulate the given time starting from now.
|
||||
require.False(t, c.now.IsZero())
|
||||
start := time.Now()
|
||||
tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
|
||||
uqhss := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{}
|
||||
srv, db, ps := setup(t, false, &overrides{
|
||||
srv, db, ps, pd := setup(t, false, &overrides{
|
||||
timeNowFn: func() time.Time {
|
||||
return c.now.Add(time.Since(start))
|
||||
},
|
||||
templateScheduleStore: tss,
|
||||
userQuietHoursScheduleStore: uqhss,
|
||||
id: &srvID,
|
||||
})
|
||||
|
||||
var templateScheduleStore schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
|
||||
|
@ -1418,7 +1449,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
})
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: srvID,
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
|
@ -1484,8 +1515,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
})
|
||||
t.Run("TemplateDryRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srvID := uuid.New()
|
||||
srv, db, _ := setup(t, false, &overrides{id: &srvID})
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
|
@ -1495,7 +1525,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: srvID,
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
|
@ -1686,73 +1716,89 @@ func TestInsertWorkspaceResource(t *testing.T) {
|
|||
}
|
||||
|
||||
type overrides struct {
|
||||
ctx context.Context
|
||||
deploymentValues *codersdk.DeploymentValues
|
||||
externalAuthConfigs []*externalauth.Config
|
||||
id *uuid.UUID
|
||||
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
||||
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
|
||||
timeNowFn func() time.Time
|
||||
acquireJobLongPollDuration time.Duration
|
||||
heartbeatFn func(ctx context.Context) error
|
||||
heartbeatInterval time.Duration
|
||||
}
|
||||
|
||||
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) {
|
||||
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub, database.ProvisionerDaemon) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
db := dbmem.New()
|
||||
ps := pubsub.NewInMemory()
|
||||
deploymentValues := &codersdk.DeploymentValues{}
|
||||
var externalAuthConfigs []*externalauth.Config
|
||||
srvID := uuid.New()
|
||||
tss := testTemplateScheduleStore()
|
||||
uqhss := testUserQuietHoursScheduleStore()
|
||||
var timeNowFn func() time.Time
|
||||
pollDur := time.Duration(0)
|
||||
if ov != nil {
|
||||
if ov.deploymentValues != nil {
|
||||
deploymentValues = ov.deploymentValues
|
||||
}
|
||||
if ov.externalAuthConfigs != nil {
|
||||
externalAuthConfigs = ov.externalAuthConfigs
|
||||
}
|
||||
if ov.id != nil {
|
||||
srvID = *ov.id
|
||||
}
|
||||
if ov.templateScheduleStore != nil {
|
||||
ttss := tss.Load()
|
||||
// keep the initial test value if the override hasn't set the atomic pointer.
|
||||
tss = ov.templateScheduleStore
|
||||
if tss.Load() == nil {
|
||||
swapped := tss.CompareAndSwap(nil, ttss)
|
||||
require.True(t, swapped)
|
||||
}
|
||||
}
|
||||
if ov.userQuietHoursScheduleStore != nil {
|
||||
tuqhss := uqhss.Load()
|
||||
// keep the initial test value if the override hasn't set the atomic pointer.
|
||||
uqhss = ov.userQuietHoursScheduleStore
|
||||
if uqhss.Load() == nil {
|
||||
swapped := uqhss.CompareAndSwap(nil, tuqhss)
|
||||
require.True(t, swapped)
|
||||
}
|
||||
}
|
||||
if ov.timeNowFn != nil {
|
||||
timeNowFn = ov.timeNowFn
|
||||
}
|
||||
pollDur = ov.acquireJobLongPollDuration
|
||||
if ov == nil {
|
||||
ov = &overrides{}
|
||||
}
|
||||
if ov.ctx == nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
ov.ctx = ctx
|
||||
}
|
||||
if ov.heartbeatInterval == 0 {
|
||||
ov.heartbeatInterval = testutil.IntervalMedium
|
||||
}
|
||||
if ov.deploymentValues != nil {
|
||||
deploymentValues = ov.deploymentValues
|
||||
}
|
||||
if ov.externalAuthConfigs != nil {
|
||||
externalAuthConfigs = ov.externalAuthConfigs
|
||||
}
|
||||
if ov.templateScheduleStore != nil {
|
||||
ttss := tss.Load()
|
||||
// keep the initial test value if the override hasn't set the atomic pointer.
|
||||
tss = ov.templateScheduleStore
|
||||
if tss.Load() == nil {
|
||||
swapped := tss.CompareAndSwap(nil, ttss)
|
||||
require.True(t, swapped)
|
||||
}
|
||||
}
|
||||
if ov.userQuietHoursScheduleStore != nil {
|
||||
tuqhss := uqhss.Load()
|
||||
// keep the initial test value if the override hasn't set the atomic pointer.
|
||||
uqhss = ov.userQuietHoursScheduleStore
|
||||
if uqhss.Load() == nil {
|
||||
swapped := uqhss.CompareAndSwap(nil, tuqhss)
|
||||
require.True(t, swapped)
|
||||
}
|
||||
}
|
||||
if ov.timeNowFn != nil {
|
||||
timeNowFn = ov.timeNowFn
|
||||
}
|
||||
pollDur = ov.acquireJobLongPollDuration
|
||||
|
||||
daemon, err := db.UpsertProvisionerDaemon(ov.ctx, database.UpsertProvisionerDaemonParams{
|
||||
Name: "test",
|
||||
CreatedAt: dbtime.Now(),
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
Tags: database.StringMap{},
|
||||
LastSeenAt: sql.NullTime{},
|
||||
Version: "",
|
||||
APIVersion: "1.0",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
srv, err := provisionerdserver.NewServer(
|
||||
ctx,
|
||||
ov.ctx,
|
||||
&url.URL{},
|
||||
srvID,
|
||||
daemon.ID,
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
|
||||
[]database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
provisionerdserver.Tags{},
|
||||
provisionerdserver.Tags(daemon.Tags),
|
||||
db,
|
||||
ps,
|
||||
provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps),
|
||||
provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps),
|
||||
telemetry.NewNoop(),
|
||||
trace.NewNoopTracerProvider().Tracer("noop"),
|
||||
&atomic.Pointer[proto.QuotaCommitter]{},
|
||||
|
@ -1765,10 +1811,12 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
|
|||
TimeNowFn: timeNowFn,
|
||||
OIDCConfig: &oauth2.Config{},
|
||||
AcquireJobLongPollDur: pollDur,
|
||||
HeartbeatInterval: ov.heartbeatInterval,
|
||||
HeartbeatFn: ov.heartbeatFn,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return srv, db, ps
|
||||
return srv, db, ps, daemon
|
||||
}
|
||||
|
||||
func must[T any](value T, err error) T {
|
||||
|
|
|
@ -155,6 +155,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
|
|||
// Users cannot do create/update/delete on themselves, but they
|
||||
// can read their own details.
|
||||
ResourceUser.Type: {ActionRead},
|
||||
// Users can create provisioner daemons scoped to themselves.
|
||||
ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate},
|
||||
})...,
|
||||
),
|
||||
}.withCachedRegoValue()
|
||||
|
|
|
@ -164,9 +164,6 @@ func (c *Client) Organization(ctx context.Context, id uuid.UUID) (Organization,
|
|||
}
|
||||
|
||||
// ProvisionerDaemons returns provisioner daemons available.
|
||||
//
|
||||
// Deprecated: We no longer track provisioner daemons as they connect. This function may return historical data
|
||||
// but new provisioner daemons will not appear.
|
||||
func (c *Client) ProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet,
|
||||
// TODO: the organization path parameter is currently ignored.
|
||||
|
|
|
@ -4,13 +4,16 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
@ -35,6 +38,17 @@ func TestProvisionerDaemon_PSK(t *testing.T) {
|
|||
clitest.Start(t, inv)
|
||||
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
|
||||
pty.ExpectMatchContext(ctx, "matt-daemon")
|
||||
|
||||
var daemons []codersdk.ProvisionerDaemon
|
||||
require.Eventually(t, func() bool {
|
||||
daemons, err = client.ProvisionerDaemons(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return len(daemons) == 1
|
||||
}, testutil.WaitShort, testutil.IntervalSlow)
|
||||
require.Equal(t, "matt-daemon", daemons[0].Name)
|
||||
require.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
|
||||
}
|
||||
|
||||
func TestProvisionerDaemon_SessionToken(t *testing.T) {
|
||||
|
@ -49,14 +63,61 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) {
|
|||
},
|
||||
},
|
||||
})
|
||||
anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user")
|
||||
anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user", "--name", "my-daemon")
|
||||
clitest.SetupConfig(t, anotherClient, conf)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
clitest.Start(t, inv)
|
||||
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
|
||||
|
||||
var daemons []codersdk.ProvisionerDaemon
|
||||
var err error
|
||||
require.Eventually(t, func() bool {
|
||||
daemons, err = client.ProvisionerDaemons(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return len(daemons) == 1
|
||||
}, testutil.WaitShort, testutil.IntervalSlow)
|
||||
assert.Equal(t, "my-daemon", daemons[0].Name)
|
||||
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope])
|
||||
assert.Equal(t, anotherUser.ID.String(), daemons[0].Tags[provisionersdk.TagOwner])
|
||||
})
|
||||
|
||||
t.Run("ScopeAnotherUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, admin := coderdenttest.New(t, &coderdenttest.Options{
|
||||
ProvisionerDaemonPSK: "provisionersftw",
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user", "--tag", "owner="+admin.UserID.String(), "--name", "my-daemon")
|
||||
clitest.SetupConfig(t, anotherClient, conf)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
clitest.Start(t, inv)
|
||||
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
|
||||
|
||||
var daemons []codersdk.ProvisionerDaemon
|
||||
var err error
|
||||
require.Eventually(t, func() bool {
|
||||
daemons, err = client.ProvisionerDaemons(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return len(daemons) == 1
|
||||
}, testutil.WaitShort, testutil.IntervalSlow)
|
||||
assert.Equal(t, "my-daemon", daemons[0].Name)
|
||||
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope])
|
||||
// This should get clobbered to the user who started the daemon.
|
||||
assert.Equal(t, anotherUser.ID.String(), daemons[0].Tags[provisionersdk.TagOwner])
|
||||
})
|
||||
|
||||
t.Run("ScopeOrg", func(t *testing.T) {
|
||||
|
@ -69,13 +130,25 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) {
|
|||
},
|
||||
},
|
||||
})
|
||||
anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=organization")
|
||||
anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID, rbac.RoleTemplateAdmin())
|
||||
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=organization", "--name", "org-daemon")
|
||||
clitest.SetupConfig(t, anotherClient, conf)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
clitest.Start(t, inv)
|
||||
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
|
||||
|
||||
var daemons []codersdk.ProvisionerDaemon
|
||||
var err error
|
||||
require.Eventually(t, func() bool {
|
||||
daemons, err = client.ProvisionerDaemons(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return len(daemons) == 1
|
||||
}, testutil.WaitShort, testutil.IntervalSlow)
|
||||
assert.Equal(t, "org-daemon", daemons[0].Name)
|
||||
assert.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
|
||||
})
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ import (
|
|||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||||
|
@ -121,7 +123,7 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
|
|||
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
|
||||
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
|
||||
// If using PSK auth, the daemon is, by definition, scoped to the organization.
|
||||
tags[provisionersdk.TagScope] = provisionersdk.ScopeOrganization
|
||||
tags = provisionersdk.MutateTags(uuid.Nil, tags)
|
||||
return tags, true
|
||||
}
|
||||
}
|
||||
|
@ -191,7 +193,10 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||
if !authorized {
|
||||
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags))
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden,
|
||||
codersdk.Response{Message: "You aren't allowed to create provisioner daemons"})
|
||||
codersdk.Response{
|
||||
Message: fmt.Sprintf("You aren't allowed to create provisioner daemons with scope %q", tags[provisionersdk.TagScope]),
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags))
|
||||
|
@ -221,6 +226,35 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||
slog.F("tags", tags),
|
||||
)
|
||||
|
||||
authCtx := ctx
|
||||
if r.Header.Get(codersdk.ProvisionerDaemonPSK) != "" {
|
||||
//nolint:gocritic // PSK auth means no actor in request,
|
||||
// so use system restricted.
|
||||
authCtx = dbauthz.AsSystemRestricted(ctx)
|
||||
}
|
||||
|
||||
// Create the daemon in the database.
|
||||
now := dbtime.Now()
|
||||
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
|
||||
Name: name,
|
||||
Provisioners: provisioners,
|
||||
Tags: tags,
|
||||
CreatedAt: now,
|
||||
LastSeenAt: sql.NullTime{Time: now, Valid: true},
|
||||
Version: "", // TODO: provisionerd needs to send version
|
||||
APIVersion: "1.0",
|
||||
})
|
||||
if err != nil {
|
||||
if !xerrors.Is(err, context.Canceled) {
|
||||
log.Error(ctx, "create provisioner daemon", slog.Error(err))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error creating provisioner daemon.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
api.AGPL.WebsocketWaitMutex.Lock()
|
||||
api.AGPL.WebsocketWaitGroup.Add(1)
|
||||
api.AGPL.WebsocketWaitMutex.Unlock()
|
||||
|
@ -264,11 +298,13 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
mux := drpcmux.New()
|
||||
logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name))
|
||||
srvCtx, srvCancel := context.WithCancel(ctx)
|
||||
defer srvCancel()
|
||||
logger.Info(ctx, "starting external provisioner daemon")
|
||||
srv, err := provisionerdserver.NewServer(
|
||||
api.ctx,
|
||||
srvCtx,
|
||||
api.AccessURL,
|
||||
id,
|
||||
daemon.ID,
|
||||
logger,
|
||||
provisioners,
|
||||
tags,
|
||||
|
@ -308,6 +344,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||
},
|
||||
})
|
||||
err = server.Serve(ctx, session)
|
||||
srvCancel()
|
||||
logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err))
|
||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
@ -50,6 +51,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
srv.DRPCConn().Close()
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
require.Len(t, daemons, 1)
|
||||
})
|
||||
|
||||
t.Run("NoLicense", func(t *testing.T) {
|
||||
|
@ -163,6 +168,15 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
daemons, err := client.ProvisionerDaemons(context.Background())
|
||||
assert.NoError(t, err, "failed to get provisioner daemons")
|
||||
return len(daemons) > 0 &&
|
||||
assert.Equal(t, t.Name(), daemons[0].Name) &&
|
||||
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope]) &&
|
||||
assert.Equal(t, user.UserID.String(), daemons[0].Tags[provisionersdk.TagOwner])
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{
|
||||
Name: "example",
|
||||
StorageMethod: codersdk.ProvisionerStorageMethodFile,
|
||||
|
@ -211,6 +225,12 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
err = srv.DRPCConn().Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
if assert.Len(t, daemons, 1) {
|
||||
assert.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PSK_daily_cost", func(t *testing.T) {
|
||||
|
@ -346,6 +366,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
var apiError *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiError)
|
||||
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
require.Len(t, daemons, 0)
|
||||
})
|
||||
|
||||
t.Run("NoAuth", func(t *testing.T) {
|
||||
|
@ -376,6 +400,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
var apiError *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiError)
|
||||
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
require.Len(t, daemons, 0)
|
||||
})
|
||||
|
||||
t.Run("NoPSK", func(t *testing.T) {
|
||||
|
@ -406,5 +434,9 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
var apiError *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiError)
|
||||
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
require.Len(t, daemons, 0)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -20,6 +21,7 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionerd/runner"
|
||||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
|
@ -199,6 +201,12 @@ connectLoop:
|
|||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
// If something is wrong with our auth, stop trying to connect.
|
||||
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusForbidden {
|
||||
p.opts.Logger.Error(p.closeContext, "not authorized to dial coderd", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if p.isClosed() {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -14,7 +14,10 @@ const (
|
|||
// If the scope is "user", the "owner" is changed to the user ID.
|
||||
// This is for user-scoped provisioner daemons, where users should
|
||||
// own their own operations.
|
||||
// Otherwise, the "owner" tag is always empty.
|
||||
// Otherwise, the "owner" tag is always an empty string.
|
||||
// NOTE: "owner" must NEVER be nil. Otherwise it will end up being
|
||||
// duplicated in the database, as idx_provisioner_daemons_name_owner_key
|
||||
// is a partial unique index that includes a JSON field.
|
||||
func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
|
||||
if tags == nil {
|
||||
tags = map[string]string{}
|
||||
|
@ -22,15 +25,16 @@ func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
|
|||
_, ok := tags[TagScope]
|
||||
if !ok {
|
||||
tags[TagScope] = ScopeOrganization
|
||||
delete(tags, TagOwner)
|
||||
tags[TagOwner] = ""
|
||||
}
|
||||
switch tags[TagScope] {
|
||||
case ScopeUser:
|
||||
tags[TagOwner] = userID.String()
|
||||
case ScopeOrganization:
|
||||
delete(tags, TagOwner)
|
||||
tags[TagOwner] = ""
|
||||
default:
|
||||
tags[TagScope] = ScopeOrganization
|
||||
tags[TagOwner] = ""
|
||||
}
|
||||
return tags
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ func TestMutateTags(t *testing.T) {
|
|||
tags: nil,
|
||||
want: map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -35,6 +36,7 @@ func TestMutateTags(t *testing.T) {
|
|||
tags: map[string]string{},
|
||||
want: map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -52,6 +54,7 @@ func TestMutateTags(t *testing.T) {
|
|||
userID: testUserID,
|
||||
want: map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -63,6 +66,7 @@ func TestMutateTags(t *testing.T) {
|
|||
userID: uuid.Nil,
|
||||
want: map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -73,6 +77,7 @@ func TestMutateTags(t *testing.T) {
|
|||
userID: uuid.Nil,
|
||||
want: map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -81,6 +86,7 @@ func TestMutateTags(t *testing.T) {
|
|||
userID: testUserID,
|
||||
want: map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
provisionersdk.TagOwner: "",
|
||||
},
|
||||
},
|
||||
} {
|
||||
|
|
Loading…
Reference in New Issue