feat(coderd): insert provisioner daemons (#11207)

* Adds UpdateProvisionerDaemonLastSeenAt
* Adds heartbeat to provisioner daemons
* Inserts provisioner daemons to database upon start
* Ensures TagOwner is an empty string and not nil
* Adds COALESCE() in idx_provisioner_daemons_name_owner_key
This commit is contained in:
Cian Johnston 2023-12-18 16:44:52 +00:00 committed by GitHub
parent a6901ae2c5
commit 213b768785
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 548 additions and 137 deletions

View File

@ -36,6 +36,7 @@
"worker_id": "[workspace build worker ID]",
"file_id": "[workspace build file ID]",
"tags": {
"owner": "",
"scope": "organization"
},
"queue_position": 0,

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"flag"
"fmt"
"io"
@ -49,6 +50,7 @@ import (
"github.com/coder/coder/v2/coderd/batchstats"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/gitsshkey"
@ -1178,22 +1180,32 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string
}
}()
tags := provisionerdserver.Tags{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
//nolint:gocritic // in-memory provisioners are owned by system
daemon, err := api.Database.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{
Name: name,
CreatedAt: dbtime.Now(),
Provisioners: []database.ProvisionerType{
database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform,
},
Tags: provisionersdk.MutateTags(uuid.Nil, nil),
LastSeenAt: sql.NullTime{Time: dbtime.Now(), Valid: true},
Version: buildinfo.Version(),
APIVersion: "1.0",
})
if err != nil {
return nil, xerrors.Errorf("failed to create in-memory provisioner daemon: %w", err)
}
mux := drpcmux.New()
api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name))
logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name))
srv, err := provisionerdserver.NewServer(
api.ctx,
api.ctx, // use the same ctx as the API
api.AccessURL,
uuid.New(),
daemon.ID,
logger,
[]database.ProvisionerType{
database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform,
},
tags,
daemon.Provisioners,
provisionerdserver.Tags(daemon.Tags),
api.Database,
api.Pubsub,
api.Acquirer,

View File

@ -533,7 +533,7 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer {
}()
daemon := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, t.Name())
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, "test")
}, &provisionerd.Options{
Logger: coderAPI.Logger.Named("provisionerd").Leveled(slog.LevelDebug),
UpdateInterval: 250 * time.Millisecond,

View File

@ -232,6 +232,7 @@ var (
rbac.ResourceOrganization.Type: {rbac.ActionCreate},
rbac.ResourceOrganizationMember.Type: {rbac.ActionCreate},
rbac.ResourceOrgRoleAssignment.Type: {rbac.ActionCreate},
rbac.ResourceProvisionerDaemon.Type: {rbac.ActionCreate, rbac.ActionUpdate},
rbac.ResourceUser.Type: {rbac.ActionCreate, rbac.ActionUpdate, rbac.ActionDelete},
rbac.ResourceUserData.Type: {rbac.ActionCreate, rbac.ActionUpdate},
rbac.ResourceWorkspace.Type: {rbac.ActionUpdate},
@ -2499,6 +2500,13 @@ func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemb
return q.db.UpdateMemberRoles(ctx, arg)
}
func (q *querier) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceProvisionerDaemon); err != nil {
return err
}
return q.db.UpdateProvisionerDaemonLastSeenAt(ctx, arg)
}
// TODO: We need to create a ProvisionerJob resource type
func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
// if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {

View File

@ -1592,6 +1592,18 @@ func (s *MethodTestSuite) TestExtraMethods() {
s.NoError(err, "insert provisioner daemon")
check.Args().Asserts(rbac.ResourceSystem, rbac.ActionDelete)
}))
s.Run("UpdateProvisionerDaemonLastSeenAt", s.Subtest(func(db database.Store, check *expects) {
d, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{
Tags: database.StringMap(map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
}),
})
s.NoError(err, "insert provisioner daemon")
check.Args(database.UpdateProvisionerDaemonLastSeenAtParams{
ID: d.ID,
LastSeenAt: sql.NullTime{Time: dbtime.Now(), Valid: true},
}).Asserts(rbac.ResourceProvisionerDaemon, rbac.ActionUpdate)
}))
}
// All functions in this method test suite are not implemented in dbmem, but

View File

@ -28,7 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/util/slice"
)
var errMatchAny = errors.New("match any error")
var errMatchAny = xerrors.New("match any error")
var skipMethods = map[string]string{
"InTx": "Not relevant",

View File

@ -5945,6 +5945,28 @@ func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMe
return database.OrganizationMember{}, sql.ErrNoRows
}
func (q *FakeQuerier) UpdateProvisionerDaemonLastSeenAt(_ context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
err := validateDatabaseType(arg)
if err != nil {
return err
}
q.mutex.Lock()
defer q.mutex.Unlock()
for idx := range q.provisionerDaemons {
if q.provisionerDaemons[idx].ID != arg.ID {
continue
}
if q.provisionerDaemons[idx].LastSeenAt.Time.After(arg.LastSeenAt.Time) {
continue
}
q.provisionerDaemons[idx].LastSeenAt = arg.LastSeenAt
return nil
}
return sql.ErrNoRows
}
func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error {
if err := validateDatabaseType(arg); err != nil {
return err

View File

@ -1593,6 +1593,13 @@ func (m metricsStore) UpdateMemberRoles(ctx context.Context, arg database.Update
return member, err
}
func (m metricsStore) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg database.UpdateProvisionerDaemonLastSeenAtParams) error {
start := time.Now()
r0 := m.s.UpdateProvisionerDaemonLastSeenAt(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateProvisionerDaemonLastSeenAt").Observe(time.Since(start).Seconds())
return r0
}
func (m metricsStore) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
start := time.Now()
err := m.s.UpdateProvisionerJobByID(ctx, arg)

View File

@ -3362,6 +3362,20 @@ func (mr *MockStoreMockRecorder) UpdateMemberRoles(arg0, arg1 interface{}) *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMemberRoles", reflect.TypeOf((*MockStore)(nil).UpdateMemberRoles), arg0, arg1)
}
// UpdateProvisionerDaemonLastSeenAt mocks base method.
func (m *MockStore) UpdateProvisionerDaemonLastSeenAt(arg0 context.Context, arg1 database.UpdateProvisionerDaemonLastSeenAtParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProvisionerDaemonLastSeenAt", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProvisionerDaemonLastSeenAt indicates an expected call of UpdateProvisionerDaemonLastSeenAt.
func (mr *MockStoreMockRecorder) UpdateProvisionerDaemonLastSeenAt(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProvisionerDaemonLastSeenAt", reflect.TypeOf((*MockStore)(nil).UpdateProvisionerDaemonLastSeenAt), arg0, arg1)
}
// UpdateProvisionerJobByID mocks base method.
func (m *MockStore) UpdateProvisionerJobByID(arg0 context.Context, arg1 database.UpdateProvisionerJobByIDParams) error {
m.ctrl.T.Helper()

View File

@ -1417,9 +1417,9 @@ CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name);
CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name));
CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower((tags ->> 'owner'::text)));
CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower(COALESCE((tags ->> 'owner'::text), ''::text)));
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key IS 'Relax uniqueness constraint for provisioner daemon names';
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key IS 'Allow unique provisioner daemon names by user';
CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id);

View File

@ -0,0 +1,8 @@
DROP INDEX IF EXISTS idx_provisioner_daemons_name_owner_key;
CREATE UNIQUE INDEX IF NOT EXISTS idx_provisioner_daemons_name_owner_key
ON provisioner_daemons
USING btree (name, lower((tags->>'owner')::text));
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key
IS 'Relax uniqueness constraint for provisioner daemon names';

View File

@ -0,0 +1,8 @@
DROP INDEX IF EXISTS idx_provisioner_daemons_name_owner_key;
CREATE UNIQUE INDEX IF NOT EXISTS idx_provisioner_daemons_name_owner_key
ON provisioner_daemons
USING btree (name, LOWER(COALESCE(tags->>'owner', '')::text));
COMMENT ON INDEX idx_provisioner_daemons_name_owner_key
IS 'Allow unique provisioner daemon names by user';

View File

@ -318,6 +318,7 @@ type sqlcQuerier interface {
UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error)
UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error)
UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error)
UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error

View File

@ -3057,6 +3057,26 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
return items, nil
}
const updateProvisionerDaemonLastSeenAt = `-- name: UpdateProvisionerDaemonLastSeenAt :exec
UPDATE provisioner_daemons
SET
last_seen_at = $1
WHERE
id = $2
AND
last_seen_at <= $1
`
type UpdateProvisionerDaemonLastSeenAtParams struct {
LastSeenAt sql.NullTime `db:"last_seen_at" json:"last_seen_at"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateProvisionerDaemonLastSeenAt(ctx context.Context, arg UpdateProvisionerDaemonLastSeenAtParams) error {
_, err := q.db.ExecContext(ctx, updateProvisionerDaemonLastSeenAt, arg.LastSeenAt, arg.ID)
return err
}
const upsertProvisionerDaemon = `-- name: UpsertProvisionerDaemon :one
INSERT INTO
provisioner_daemons (
@ -3078,7 +3098,7 @@ VALUES (
$5,
$6,
$7
) ON CONFLICT("name", lower((tags ->> 'owner'::text))) DO UPDATE SET
) ON CONFLICT("name", LOWER(COALESCE(tags ->> 'owner'::text, ''::text))) DO UPDATE SET
provisioners = $3,
tags = $4,
last_seen_at = $5,

View File

@ -35,7 +35,7 @@ VALUES (
@last_seen_at,
@version,
@api_version
) ON CONFLICT("name", lower((tags ->> 'owner'::text))) DO UPDATE SET
) ON CONFLICT("name", LOWER(COALESCE(tags ->> 'owner'::text, ''::text))) DO UPDATE SET
provisioners = @provisioners,
tags = @tags,
last_seen_at = @last_seen_at,
@ -45,3 +45,12 @@ WHERE
-- Only ones with the same tags are allowed clobber
provisioner_daemons.tags <@ @tags :: jsonb
RETURNING *;
-- name: UpdateProvisionerDaemonLastSeenAt :exec
UPDATE provisioner_daemons
SET
last_seen_at = @last_seen_at
WHERE
id = @id
AND
last_seen_at <= @last_seen_at;

View File

@ -65,7 +65,7 @@ const (
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
UniqueIndexOrganizationName UniqueConstraint = "idx_organization_name" // CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name);
UniqueIndexOrganizationNameLower UniqueConstraint = "idx_organization_name_lower" // CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name));
UniqueIndexProvisionerDaemonsNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower((tags ->> 'owner'::text)));
UniqueIndexProvisionerDaemonsNameOwnerKey UniqueConstraint = "idx_provisioner_daemons_name_owner_key" // CREATE UNIQUE INDEX idx_provisioner_daemons_name_owner_key ON provisioner_daemons USING btree (name, lower(COALESCE((tags ->> 'owner'::text), ''::text)));
UniqueIndexUsersEmail UniqueConstraint = "idx_users_email" // CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false);
UniqueIndexUsersUsername UniqueConstraint = "idx_users_username" // CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false);
UniqueTemplatesOrganizationIDNameIndex UniqueConstraint = "templates_organization_id_name_idx" // CREATE UNIQUE INDEX templates_organization_id_name_idx ON templates USING btree (organization_id, lower((name)::text)) WHERE (deleted = false);

View File

@ -44,9 +44,15 @@ import (
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
// canceling and returning an empty job.
const DefaultAcquireJobLongPollDur = time.Second * 5
const (
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
// canceling and returning an empty job.
DefaultAcquireJobLongPollDur = time.Second * 5
// DefaultHeartbeatInterval is the interval at which the provisioner daemon
// will update its last seen at timestamp in the database.
DefaultHeartbeatInterval = time.Minute
)
type Options struct {
OIDCConfig httpmw.OAuth2Config
@ -56,6 +62,16 @@ type Options struct {
// AcquireJobLongPollDur is used in tests
AcquireJobLongPollDur time.Duration
// HeartbeatInterval is the interval at which the provisioner daemon
// will update its last seen at timestamp in the database.
HeartbeatInterval time.Duration
// HeartbeatFn is the function that will be called at the interval
// specified by HeartbeatInterval.
// The default function just calls UpdateProvisionerDaemonLastSeenAt.
// This is mainly used for testing.
HeartbeatFn func(context.Context) error
}
type server struct {
@ -85,6 +101,9 @@ type server struct {
TimeNowFn func() time.Time
acquireJobLongPollDur time.Duration
heartbeatInterval time.Duration
heartbeatFn func(ctx context.Context) error
}
// We use the null byte (0x00) in generating a canonical map key for tags, so
@ -161,7 +180,21 @@ func NewServer(
if options.AcquireJobLongPollDur == 0 {
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
}
return &server{
if options.HeartbeatInterval == 0 {
options.HeartbeatInterval = DefaultHeartbeatInterval
}
// Avoid a nil check in s.heartbeat.
if options.HeartbeatFn == nil {
options.HeartbeatFn = func(hbCtx context.Context) error {
//nolint:gocritic // This is specifically for updating the last seen at timestamp.
return db.UpdateProvisionerDaemonLastSeenAt(dbauthz.AsSystemRestricted(hbCtx), database.UpdateProvisionerDaemonLastSeenAtParams{
ID: id,
LastSeenAt: sql.NullTime{Time: time.Now(), Valid: true},
})
}
}
s := &server{
lifecycleCtx: lifecycleCtx,
AccessURL: accessURL,
ID: id,
@ -182,7 +215,12 @@ func NewServer(
OIDCConfig: options.OIDCConfig,
TimeNowFn: options.TimeNowFn,
acquireJobLongPollDur: options.AcquireJobLongPollDur,
}, nil
heartbeatInterval: options.HeartbeatInterval,
heartbeatFn: options.HeartbeatFn,
}
go s.heartbeatLoop()
return s, nil
}
// timeNow should be used when trying to get the current time for math
@ -194,6 +232,50 @@ func (s *server) timeNow() time.Time {
return dbtime.Now()
}
// heartbeatLoop runs heartbeatOnce at the interval specified by HeartbeatInterval
// until the lifecycle context is canceled.
func (s *server) heartbeatLoop() {
tick := time.NewTicker(time.Nanosecond)
defer tick.Stop()
for {
select {
case <-s.lifecycleCtx.Done():
s.Logger.Debug(s.lifecycleCtx, "heartbeat loop canceled")
return
case <-tick.C:
if s.lifecycleCtx.Err() != nil {
return
}
start := s.timeNow()
hbCtx, hbCancel := context.WithTimeout(s.lifecycleCtx, s.heartbeatInterval)
if err := s.heartbeat(hbCtx); err != nil {
if !xerrors.Is(err, context.DeadlineExceeded) && !xerrors.Is(err, context.Canceled) {
s.Logger.Error(hbCtx, "heartbeat failed", slog.Error(err))
}
}
hbCancel()
elapsed := s.timeNow().Sub(start)
nextBeat := s.heartbeatInterval - elapsed
// avoid negative interval
if nextBeat <= 0 {
nextBeat = time.Nanosecond
}
tick.Reset(nextBeat)
}
}
}
// heartbeat updates the last seen at timestamp in the database.
// If HeartbeatFn is set, it will be called instead.
func (s *server) heartbeat(ctx context.Context) error {
select {
case <-ctx.Done():
return nil
default:
return s.heartbeatFn(ctx)
}
}
// AcquireJob queries the database to lock a job.
//
// Deprecated: This method is only available for back-level provisioner daemons.

View File

@ -66,7 +66,8 @@ func testUserQuietHoursScheduleStore() *atomic.Pointer[schedule.UserQuietHoursSc
func TestAcquireJob_LongPoll(t *testing.T) {
t.Parallel()
srv, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond})
//nolint:dogsled // ૮・ᴥ・ა
srv, _, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond})
job, err := srv.AcquireJob(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, &proto.AcquiredJob{}, job)
@ -74,7 +75,8 @@ func TestAcquireJob_LongPoll(t *testing.T) {
func TestAcquireJobWithCancel_Cancel(t *testing.T) {
t.Parallel()
srv, _, _ := setup(t, false, nil)
//nolint:dogsled // ૮ ˶′ﻌ ‵˶ ა
srv, _, _, _ := setup(t, false, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
fs := newFakeStream(ctx)
@ -95,6 +97,46 @@ func TestAcquireJobWithCancel_Cancel(t *testing.T) {
require.Equal(t, "", job.JobId)
}
func TestHeartbeat(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
heartbeatChan := make(chan struct{})
heartbeatFn := func(hbCtx context.Context) error {
t.Logf("heartbeat")
select {
case <-hbCtx.Done():
return hbCtx.Err()
default:
heartbeatChan <- struct{}{}
return nil
}
}
//nolint:dogsled // 。:゚૮ ˶ˆ ﻌ ˆ˶ ა ゚:。
_, _, _, _ = setup(t, false, &overrides{
ctx: ctx,
heartbeatFn: heartbeatFn,
heartbeatInterval: testutil.IntervalFast,
})
_, ok := <-heartbeatChan
require.True(t, ok, "first heartbeat not received")
_, ok = <-heartbeatChan
require.True(t, ok, "second heartbeat not received")
cancel()
// Close the channel to ensure we don't receive any more heartbeats.
// The test will fail if we do.
defer func() {
if r := recover(); r != nil {
t.Fatalf("heartbeat received after cancel: %v", r)
}
}()
close(heartbeatChan)
<-time.After(testutil.IntervalMedium)
}
func TestAcquireJob(t *testing.T) {
t.Parallel()
@ -120,7 +162,7 @@ func TestAcquireJob(t *testing.T) {
tc := tc
t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
srv, db, _, _ := setup(t, false, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
_, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
@ -141,7 +183,7 @@ func TestAcquireJob(t *testing.T) {
// deployment config.
dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)}
gitAuthProvider := "github"
srv, db, ps := setup(t, false, &overrides{
srv, db, ps, _ := setup(t, false, &overrides{
deploymentValues: dv,
externalAuthConfigs: []*externalauth.Config{{
ID: gitAuthProvider,
@ -359,7 +401,7 @@ func TestAcquireJob(t *testing.T) {
t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) {
t.Parallel()
srv, db, ps := setup(t, false, nil)
srv, db, ps, _ := setup(t, false, nil)
ctx := context.Background()
user := dbgen.User(t, db, database.User{})
@ -396,7 +438,7 @@ func TestAcquireJob(t *testing.T) {
})
t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) {
t.Parallel()
srv, db, ps := setup(t, false, nil)
srv, db, ps, _ := setup(t, false, nil)
ctx := context.Background()
user := dbgen.User(t, db, database.User{})
@ -427,7 +469,7 @@ func TestAcquireJob(t *testing.T) {
})
t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) {
t.Parallel()
srv, db, ps := setup(t, false, nil)
srv, db, ps, _ := setup(t, false, nil)
user := dbgen.User(t, db, database.User{})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
@ -476,7 +518,7 @@ func TestUpdateJob(t *testing.T) {
ctx := context.Background()
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
srv, _, _ := setup(t, false, nil)
srv, _, _, _ := setup(t, false, nil)
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: "hello",
})
@ -489,7 +531,7 @@ func TestUpdateJob(t *testing.T) {
})
t.Run("NotRunning", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
srv, db, _, _ := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
@ -505,7 +547,7 @@ func TestUpdateJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
srv, db, _, _ := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
@ -548,9 +590,8 @@ func TestUpdateJob(t *testing.T) {
t.Run("Success", func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
srv, db, _ := setup(t, false, &overrides{id: &srvID})
job := setupJob(t, db, srvID)
srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID)
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.String(),
})
@ -559,9 +600,8 @@ func TestUpdateJob(t *testing.T) {
t.Run("Logs", func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
srv, db, ps := setup(t, false, &overrides{id: &srvID})
job := setupJob(t, db, srvID)
srv, db, ps, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID)
published := make(chan struct{})
@ -585,9 +625,8 @@ func TestUpdateJob(t *testing.T) {
})
t.Run("Readme", func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
srv, db, _ := setup(t, false, &overrides{id: &srvID})
job := setupJob(t, db, srvID)
srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID)
versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID,
@ -612,9 +651,8 @@ func TestUpdateJob(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
srvID := uuid.New()
srv, db, _ := setup(t, false, &overrides{id: &srvID})
job := setupJob(t, db, srvID)
srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID)
versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID,
@ -660,9 +698,8 @@ func TestUpdateJob(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
srvID := uuid.New()
srv, db, _ := setup(t, false, &overrides{id: &srvID})
job := setupJob(t, db, srvID)
srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID)
versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID,
@ -707,7 +744,7 @@ func TestFailJob(t *testing.T) {
ctx := context.Background()
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
srv, _, _ := setup(t, false, nil)
srv, _, _, _ := setup(t, false, nil)
_, err := srv.FailJob(ctx, &proto.FailedJob{
JobId: "hello",
})
@ -721,7 +758,7 @@ func TestFailJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
srv, db, _, _ := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
@ -744,8 +781,7 @@ func TestFailJob(t *testing.T) {
})
t.Run("AlreadyCompleted", func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
srv, db, _ := setup(t, false, &overrides{id: &srvID})
srv, db, _, pd := setup(t, false, &overrides{})
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
@ -755,7 +791,7 @@ func TestFailJob(t *testing.T) {
require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: srvID,
UUID: pd.ID,
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -780,8 +816,7 @@ func TestFailJob(t *testing.T) {
//
// (*Server).FailJob audit log - get build {"error": "sql: no rows in result set"}
ignoreLogErrors := true
srvID := uuid.New()
srv, db, ps := setup(t, ignoreLogErrors, &overrides{id: &srvID})
srv, db, ps, pd := setup(t, ignoreLogErrors, &overrides{})
workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{
ID: uuid.New(),
AutomaticUpdates: database.AutomaticUpdatesNever,
@ -810,7 +845,7 @@ func TestFailJob(t *testing.T) {
require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: srvID,
UUID: pd.ID,
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -852,7 +887,7 @@ func TestCompleteJob(t *testing.T) {
ctx := context.Background()
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
srv, _, _ := setup(t, false, nil)
srv, _, _, _ := setup(t, false, nil)
_, err := srv.CompleteJob(ctx, &proto.CompletedJob{
JobId: "hello",
})
@ -866,7 +901,7 @@ func TestCompleteJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
srv, db, _, _ := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
@ -890,8 +925,7 @@ func TestCompleteJob(t *testing.T) {
t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
srv, db, _ := setup(t, false, &overrides{id: &srvID})
srv, db, _, pd := setup(t, false, &overrides{})
jobID := uuid.New()
versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
@ -909,7 +943,7 @@ func TestCompleteJob(t *testing.T) {
require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: srvID,
UUID: pd.ID,
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -939,9 +973,7 @@ func TestCompleteJob(t *testing.T) {
t.Run("TemplateImport_WithGitAuth", func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
srv, db, _ := setup(t, false, &overrides{
id: &srvID,
srv, db, _, pd := setup(t, false, &overrides{
externalAuthConfigs: []*externalauth.Config{{
ID: "github",
}},
@ -963,7 +995,7 @@ func TestCompleteJob(t *testing.T) {
require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: srvID,
UUID: pd.ID,
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -1106,9 +1138,8 @@ func TestCompleteJob(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
srv, db, ps := setup(t, false, &overrides{id: &srvID, templateScheduleStore: tss})
srv, db, ps, pd := setup(t, false, &overrides{templateScheduleStore: tss})
var store schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) {
@ -1123,10 +1154,19 @@ func TestCompleteJob(t *testing.T) {
}
tss.Store(&store)
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{})
template := dbgen.Template(t, db, database.Template{
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
OrganizationID: org.ID,
})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
JobID: uuid.New(),
})
err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
ID: template.ID,
@ -1148,13 +1188,6 @@ func TestCompleteJob(t *testing.T) {
TemplateID: template.ID,
Ttl: workspaceTTL,
})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
JobID: uuid.New(),
})
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: version.ID,
@ -1170,7 +1203,7 @@ func TestCompleteJob(t *testing.T) {
})
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: srvID,
UUID: pd.ID,
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -1315,19 +1348,17 @@ func TestCompleteJob(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
// Simulate the given time starting from now.
require.False(t, c.now.IsZero())
start := time.Now()
tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
uqhss := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{}
srv, db, ps := setup(t, false, &overrides{
srv, db, ps, pd := setup(t, false, &overrides{
timeNowFn: func() time.Time {
return c.now.Add(time.Since(start))
},
templateScheduleStore: tss,
userQuietHoursScheduleStore: uqhss,
id: &srvID,
})
var templateScheduleStore schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
@ -1418,7 +1449,7 @@ func TestCompleteJob(t *testing.T) {
})
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: srvID,
UUID: pd.ID,
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -1484,8 +1515,7 @@ func TestCompleteJob(t *testing.T) {
})
t.Run("TemplateDryRun", func(t *testing.T) {
t.Parallel()
srvID := uuid.New()
srv, db, _ := setup(t, false, &overrides{id: &srvID})
srv, db, _, pd := setup(t, false, &overrides{})
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
@ -1495,7 +1525,7 @@ func TestCompleteJob(t *testing.T) {
require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: srvID,
UUID: pd.ID,
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
@ -1686,73 +1716,89 @@ func TestInsertWorkspaceResource(t *testing.T) {
}
type overrides struct {
ctx context.Context
deploymentValues *codersdk.DeploymentValues
externalAuthConfigs []*externalauth.Config
id *uuid.UUID
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
timeNowFn func() time.Time
acquireJobLongPollDuration time.Duration
heartbeatFn func(ctx context.Context) error
heartbeatInterval time.Duration
}
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) {
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub, database.ProvisionerDaemon) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
db := dbmem.New()
ps := pubsub.NewInMemory()
deploymentValues := &codersdk.DeploymentValues{}
var externalAuthConfigs []*externalauth.Config
srvID := uuid.New()
tss := testTemplateScheduleStore()
uqhss := testUserQuietHoursScheduleStore()
var timeNowFn func() time.Time
pollDur := time.Duration(0)
if ov != nil {
if ov.deploymentValues != nil {
deploymentValues = ov.deploymentValues
}
if ov.externalAuthConfigs != nil {
externalAuthConfigs = ov.externalAuthConfigs
}
if ov.id != nil {
srvID = *ov.id
}
if ov.templateScheduleStore != nil {
ttss := tss.Load()
// keep the initial test value if the override hasn't set the atomic pointer.
tss = ov.templateScheduleStore
if tss.Load() == nil {
swapped := tss.CompareAndSwap(nil, ttss)
require.True(t, swapped)
}
}
if ov.userQuietHoursScheduleStore != nil {
tuqhss := uqhss.Load()
// keep the initial test value if the override hasn't set the atomic pointer.
uqhss = ov.userQuietHoursScheduleStore
if uqhss.Load() == nil {
swapped := uqhss.CompareAndSwap(nil, tuqhss)
require.True(t, swapped)
}
}
if ov.timeNowFn != nil {
timeNowFn = ov.timeNowFn
}
pollDur = ov.acquireJobLongPollDuration
if ov == nil {
ov = &overrides{}
}
if ov.ctx == nil {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
ov.ctx = ctx
}
if ov.heartbeatInterval == 0 {
ov.heartbeatInterval = testutil.IntervalMedium
}
if ov.deploymentValues != nil {
deploymentValues = ov.deploymentValues
}
if ov.externalAuthConfigs != nil {
externalAuthConfigs = ov.externalAuthConfigs
}
if ov.templateScheduleStore != nil {
ttss := tss.Load()
// keep the initial test value if the override hasn't set the atomic pointer.
tss = ov.templateScheduleStore
if tss.Load() == nil {
swapped := tss.CompareAndSwap(nil, ttss)
require.True(t, swapped)
}
}
if ov.userQuietHoursScheduleStore != nil {
tuqhss := uqhss.Load()
// keep the initial test value if the override hasn't set the atomic pointer.
uqhss = ov.userQuietHoursScheduleStore
if uqhss.Load() == nil {
swapped := uqhss.CompareAndSwap(nil, tuqhss)
require.True(t, swapped)
}
}
if ov.timeNowFn != nil {
timeNowFn = ov.timeNowFn
}
pollDur = ov.acquireJobLongPollDuration
daemon, err := db.UpsertProvisionerDaemon(ov.ctx, database.UpsertProvisionerDaemonParams{
Name: "test",
CreatedAt: dbtime.Now(),
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: database.StringMap{},
LastSeenAt: sql.NullTime{},
Version: "",
APIVersion: "1.0",
})
require.NoError(t, err)
srv, err := provisionerdserver.NewServer(
ctx,
ov.ctx,
&url.URL{},
srvID,
daemon.ID,
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
[]database.ProvisionerType{database.ProvisionerTypeEcho},
provisionerdserver.Tags{},
provisionerdserver.Tags(daemon.Tags),
db,
ps,
provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps),
provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps),
telemetry.NewNoop(),
trace.NewNoopTracerProvider().Tracer("noop"),
&atomic.Pointer[proto.QuotaCommitter]{},
@ -1765,10 +1811,12 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
TimeNowFn: timeNowFn,
OIDCConfig: &oauth2.Config{},
AcquireJobLongPollDur: pollDur,
HeartbeatInterval: ov.heartbeatInterval,
HeartbeatFn: ov.heartbeatFn,
},
)
require.NoError(t, err)
return srv, db, ps
return srv, db, ps, daemon
}
func must[T any](value T, err error) T {

View File

@ -155,6 +155,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) {
// Users cannot do create/update/delete on themselves, but they
// can read their own details.
ResourceUser.Type: {ActionRead},
// Users can create provisioner daemons scoped to themselves.
ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate},
})...,
),
}.withCachedRegoValue()

View File

@ -164,9 +164,6 @@ func (c *Client) Organization(ctx context.Context, id uuid.UUID) (Organization,
}
// ProvisionerDaemons returns provisioner daemons available.
//
// Deprecated: We no longer track provisioner daemons as they connect. This function may return historical data
// but new provisioner daemons will not appear.
func (c *Client) ProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) {
res, err := c.Request(ctx, http.MethodGet,
// TODO: the organization path parameter is currently ignored.

View File

@ -4,13 +4,16 @@ import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@ -35,6 +38,17 @@ func TestProvisionerDaemon_PSK(t *testing.T) {
clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
pty.ExpectMatchContext(ctx, "matt-daemon")
var daemons []codersdk.ProvisionerDaemon
require.Eventually(t, func() bool {
daemons, err = client.ProvisionerDaemons(ctx)
if err != nil {
return false
}
return len(daemons) == 1
}, testutil.WaitShort, testutil.IntervalSlow)
require.Equal(t, "matt-daemon", daemons[0].Name)
require.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
}
func TestProvisionerDaemon_SessionToken(t *testing.T) {
@ -49,14 +63,61 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) {
},
},
})
anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user")
anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user", "--name", "my-daemon")
clitest.SetupConfig(t, anotherClient, conf)
pty := ptytest.New(t).Attach(inv)
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
defer cancel()
clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
var daemons []codersdk.ProvisionerDaemon
var err error
require.Eventually(t, func() bool {
daemons, err = client.ProvisionerDaemons(ctx)
if err != nil {
return false
}
return len(daemons) == 1
}, testutil.WaitShort, testutil.IntervalSlow)
assert.Equal(t, "my-daemon", daemons[0].Name)
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope])
assert.Equal(t, anotherUser.ID.String(), daemons[0].Tags[provisionersdk.TagOwner])
})
t.Run("ScopeAnotherUser", func(t *testing.T) {
t.Parallel()
client, admin := coderdenttest.New(t, &coderdenttest.Options{
ProvisionerDaemonPSK: "provisionersftw",
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
},
},
})
anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=user", "--tag", "owner="+admin.UserID.String(), "--name", "my-daemon")
clitest.SetupConfig(t, anotherClient, conf)
pty := ptytest.New(t).Attach(inv)
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
defer cancel()
clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
var daemons []codersdk.ProvisionerDaemon
var err error
require.Eventually(t, func() bool {
daemons, err = client.ProvisionerDaemons(ctx)
if err != nil {
return false
}
return len(daemons) == 1
}, testutil.WaitShort, testutil.IntervalSlow)
assert.Equal(t, "my-daemon", daemons[0].Name)
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope])
// This should get clobbered to the user who started the daemon.
assert.Equal(t, anotherUser.ID.String(), daemons[0].Tags[provisionersdk.TagOwner])
})
t.Run("ScopeOrg", func(t *testing.T) {
@ -69,13 +130,25 @@ func TestProvisionerDaemon_SessionToken(t *testing.T) {
},
},
})
anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=organization")
anotherClient, _ := coderdtest.CreateAnotherUser(t, client, admin.OrganizationID, rbac.RoleTemplateAdmin())
inv, conf := newCLI(t, "provisionerd", "start", "--tag", "scope=organization", "--name", "org-daemon")
clitest.SetupConfig(t, anotherClient, conf)
pty := ptytest.New(t).Attach(inv)
ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong)
defer cancel()
clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "starting provisioner daemon")
var daemons []codersdk.ProvisionerDaemon
var err error
require.Eventually(t, func() bool {
daemons, err = client.ProvisionerDaemons(ctx)
if err != nil {
return false
}
return len(daemons) == 1
}, testutil.WaitShort, testutil.IntervalSlow)
assert.Equal(t, "org-daemon", daemons[0].Name)
assert.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
})
}

View File

@ -26,6 +26,8 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/provisionerdserver"
@ -121,7 +123,7 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
// If using PSK auth, the daemon is, by definition, scoped to the organization.
tags[provisionersdk.TagScope] = provisionersdk.ScopeOrganization
tags = provisionersdk.MutateTags(uuid.Nil, tags)
return tags, true
}
}
@ -191,7 +193,10 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
if !authorized {
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags))
httpapi.Write(ctx, rw, http.StatusForbidden,
codersdk.Response{Message: "You aren't allowed to create provisioner daemons"})
codersdk.Response{
Message: fmt.Sprintf("You aren't allowed to create provisioner daemons with scope %q", tags[provisionersdk.TagScope]),
},
)
return
}
api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags))
@ -221,6 +226,35 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
slog.F("tags", tags),
)
authCtx := ctx
if r.Header.Get(codersdk.ProvisionerDaemonPSK) != "" {
//nolint:gocritic // PSK auth means no actor in request,
// so use system restricted.
authCtx = dbauthz.AsSystemRestricted(ctx)
}
// Create the daemon in the database.
now := dbtime.Now()
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
Name: name,
Provisioners: provisioners,
Tags: tags,
CreatedAt: now,
LastSeenAt: sql.NullTime{Time: now, Valid: true},
Version: "", // TODO: provisionerd needs to send version
APIVersion: "1.0",
})
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "create provisioner daemon", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error creating provisioner daemon.",
Detail: err.Error(),
})
}
return
}
api.AGPL.WebsocketWaitMutex.Lock()
api.AGPL.WebsocketWaitGroup.Add(1)
api.AGPL.WebsocketWaitMutex.Unlock()
@ -264,11 +298,13 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}
mux := drpcmux.New()
logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name))
srvCtx, srvCancel := context.WithCancel(ctx)
defer srvCancel()
logger.Info(ctx, "starting external provisioner daemon")
srv, err := provisionerdserver.NewServer(
api.ctx,
srvCtx,
api.AccessURL,
id,
daemon.ID,
logger,
provisioners,
tags,
@ -308,6 +344,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
},
})
err = server.Serve(ctx, session)
srvCancel()
logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err))
if err != nil && !xerrors.Is(err, io.EOF) {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))

View File

@ -7,6 +7,7 @@ import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
@ -50,6 +51,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
})
require.NoError(t, err)
srv.DRPCConn().Close()
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
require.Len(t, daemons, 1)
})
t.Run("NoLicense", func(t *testing.T) {
@ -163,6 +168,15 @@ func TestProvisionerDaemonServe(t *testing.T) {
file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, bytes.NewReader(data))
require.NoError(t, err)
require.Eventually(t, func() bool {
daemons, err := client.ProvisionerDaemons(context.Background())
assert.NoError(t, err, "failed to get provisioner daemons")
return len(daemons) > 0 &&
assert.Equal(t, t.Name(), daemons[0].Name) &&
assert.Equal(t, provisionersdk.ScopeUser, daemons[0].Tags[provisionersdk.TagScope]) &&
assert.Equal(t, user.UserID.String(), daemons[0].Tags[provisionersdk.TagOwner])
}, testutil.WaitShort, testutil.IntervalMedium)
version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{
Name: "example",
StorageMethod: codersdk.ProvisionerStorageMethodFile,
@ -211,6 +225,12 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.NoError(t, err)
err = srv.DRPCConn().Close()
require.NoError(t, err)
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
if assert.Len(t, daemons, 1) {
assert.Equal(t, provisionersdk.ScopeOrganization, daemons[0].Tags[provisionersdk.TagScope])
}
})
t.Run("PSK_daily_cost", func(t *testing.T) {
@ -346,6 +366,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
require.Len(t, daemons, 0)
})
t.Run("NoAuth", func(t *testing.T) {
@ -376,6 +400,10 @@ func TestProvisionerDaemonServe(t *testing.T) {
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
require.Len(t, daemons, 0)
})
t.Run("NoPSK", func(t *testing.T) {
@ -406,5 +434,9 @@ func TestProvisionerDaemonServe(t *testing.T) {
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
require.Len(t, daemons, 0)
})
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"reflect"
"sync"
"time"
@ -20,6 +21,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionerd/runner"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
@ -199,6 +201,12 @@ connectLoop:
if errors.Is(err, context.Canceled) {
return
}
var sdkErr *codersdk.Error
// If something is wrong with our auth, stop trying to connect.
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusForbidden {
p.opts.Logger.Error(p.closeContext, "not authorized to dial coderd", slog.Error(err))
return
}
if p.isClosed() {
return
}

View File

@ -14,7 +14,10 @@ const (
// If the scope is "user", the "owner" is changed to the user ID.
// This is for user-scoped provisioner daemons, where users should
// own their own operations.
// Otherwise, the "owner" tag is always empty.
// Otherwise, the "owner" tag is always an empty string.
// NOTE: "owner" must NEVER be nil. Otherwise it will end up being
// duplicated in the database, as idx_provisioner_daemons_name_owner_key
// is a partial unique index that includes a JSON field.
func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
if tags == nil {
tags = map[string]string{}
@ -22,15 +25,16 @@ func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
_, ok := tags[TagScope]
if !ok {
tags[TagScope] = ScopeOrganization
delete(tags, TagOwner)
tags[TagOwner] = ""
}
switch tags[TagScope] {
case ScopeUser:
tags[TagOwner] = userID.String()
case ScopeOrganization:
delete(tags, TagOwner)
tags[TagOwner] = ""
default:
tags[TagScope] = ScopeOrganization
tags[TagOwner] = ""
}
return tags
}

View File

@ -27,6 +27,7 @@ func TestMutateTags(t *testing.T) {
tags: nil,
want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
},
},
{
@ -35,6 +36,7 @@ func TestMutateTags(t *testing.T) {
tags: map[string]string{},
want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
},
},
{
@ -52,6 +54,7 @@ func TestMutateTags(t *testing.T) {
userID: testUserID,
want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
},
},
{
@ -63,6 +66,7 @@ func TestMutateTags(t *testing.T) {
userID: uuid.Nil,
want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
},
},
{
@ -73,6 +77,7 @@ func TestMutateTags(t *testing.T) {
userID: uuid.Nil,
want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
},
},
{
@ -81,6 +86,7 @@ func TestMutateTags(t *testing.T) {
userID: testUserID,
want: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
provisionersdk.TagOwner: "",
},
},
} {