chore: enforce that provisioners can only acquire jobs in their own organization (#12600)

* chore: add org ID as optional param to AcquireJob
* chore: plumb through organization id to provisioner daemons
* add org id to provisioner domain key
* enforce org id argument
* dbgen provisioner jobs defaults to default org
This commit is contained in:
Steven Masley 2024-03-18 12:48:13 -05:00 committed by GitHub
parent 0e8ebb9b22
commit f0f9569d51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 204 additions and 126 deletions

View File

@ -1271,6 +1271,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(dialCtx context.Context, name st
api.ctx, // use the same ctx as the API
api.AccessURL,
daemon.ID,
defaultOrg.ID,
logger,
daemon.Provisioners,
provisionerdserver.Tags(daemon.Tags),

View File

@ -2093,7 +2093,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
StartedAt: sql.NullTime{Valid: false},
})
check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
check.Args(database.AcquireProvisionerJobParams{OrganizationID: j.OrganizationID, Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ )
}))
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {

View File

@ -187,6 +187,7 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
// import job as well
for {
j, err := b.db.AcquireProvisionerJob(ownerCtx, database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,

View File

@ -387,6 +387,12 @@ func GroupMember(t testing.TB, db database.Store, orig database.GroupMember) dat
func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig database.ProvisionerJob) database.ProvisionerJob {
t.Helper()
var defOrgID uuid.UUID
if orig.OrganizationID == uuid.Nil {
defOrg, _ := db.GetDefaultOrganization(genCtx)
defOrgID = defOrg.ID
}
jobID := takeFirst(orig.ID, uuid.New())
// Always set some tags to prevent Acquire from grabbing jobs it should not.
if !orig.StartedAt.Time.IsZero() {
@ -401,7 +407,7 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
ID: jobID,
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
OrganizationID: takeFirst(orig.OrganizationID, defOrgID, uuid.New()),
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
Provisioner: takeFirst(orig.Provisioner, database.ProvisionerTypeEcho),
StorageMethod: takeFirst(orig.StorageMethod, database.ProvisionerStorageMethodFile),
@ -418,10 +424,11 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
}
if !orig.StartedAt.Time.IsZero() {
job, err = db.AcquireProvisionerJob(genCtx, database.AcquireProvisionerJobParams{
StartedAt: orig.StartedAt,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
StartedAt: orig.StartedAt,
OrganizationID: job.OrganizationID,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
})
require.NoError(t, err)
// There is no easy way to make sure we acquire the correct job.

View File

@ -803,6 +803,9 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
defer q.mutex.Unlock()
for index, provisionerJob := range q.provisionerJobs {
if provisionerJob.OrganizationID != arg.OrganizationID {
continue
}
if provisionerJob.StartedAt.Valid {
continue
}
@ -7861,15 +7864,16 @@ func (q *FakeQuerier) UpsertProvisionerDaemon(_ context.Context, arg database.Up
}
}
d := database.ProvisionerDaemon{
ID: uuid.New(),
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: maps.Clone(arg.Tags),
ReplicaID: uuid.NullUUID{},
LastSeenAt: arg.LastSeenAt,
Version: arg.Version,
APIVersion: arg.APIVersion,
ID: uuid.New(),
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: maps.Clone(arg.Tags),
ReplicaID: uuid.NullUUID{},
LastSeenAt: arg.LastSeenAt,
Version: arg.Version,
APIVersion: arg.APIVersion,
OrganizationID: arg.OrganizationID,
}
q.provisionerDaemons = append(q.provisionerDaemons, d)
return d, nil

View File

@ -363,6 +363,7 @@ func TestQueuePosition(t *testing.T) {
}
job, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: org.ID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,

View File

@ -3941,14 +3941,15 @@ WHERE
provisioner_jobs AS nested
WHERE
nested.started_at IS NULL
AND nested.organization_id = $3
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY($3 :: provisioner_type [ ])
AND nested.provisioner = ANY($4 :: provisioner_type [ ])
AND CASE
-- Special case for untagged provisioners: only match untagged jobs.
WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
THEN nested.tags :: jsonb = $4 :: jsonb
THEN nested.tags :: jsonb = $5 :: jsonb
-- Ensure the caller satisfies all job tags.
ELSE nested.tags :: jsonb <@ $4 :: jsonb
ELSE nested.tags :: jsonb <@ $5 :: jsonb
END
ORDER BY
nested.created_at
@ -3960,10 +3961,11 @@ WHERE
`
type AcquireProvisionerJobParams struct {
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
Types []ProvisionerType `db:"types" json:"types"`
Tags json.RawMessage `db:"tags" json:"tags"`
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
Types []ProvisionerType `db:"types" json:"types"`
Tags json.RawMessage `db:"tags" json:"tags"`
}
// Acquires the lock for a single job that isn't started, completed,
@ -3976,6 +3978,7 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi
row := q.db.QueryRowContext(ctx, acquireProvisionerJob,
arg.StartedAt,
arg.WorkerID,
arg.OrganizationID,
pq.Array(arg.Types),
arg.Tags,
)

View File

@ -19,6 +19,7 @@ WHERE
provisioner_jobs AS nested
WHERE
nested.started_at IS NULL
AND nested.organization_id = @organization_id
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
AND CASE

View File

@ -134,6 +134,7 @@ func TestWorkspaces(t *testing.T) {
require.NoError(t, err)
// This marks the job as started.
_, err = db.AcquireProvisionerJob(context.Background(), database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,

View File

@ -89,16 +89,17 @@ func NewAcquirer(ctx context.Context, logger slog.Logger, store AcquirerStore, p
// done, or the database returns an error _other_ than that no jobs are available.
// If no jobs are available, this method handles retrying as appropriate.
func (a *Acquirer) AcquireJob(
ctx context.Context, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
ctx context.Context, organization uuid.UUID, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
) (
retJob database.ProvisionerJob, retErr error,
) {
logger := a.logger.With(
slog.F("organization_id", organization),
slog.F("worker_id", worker),
slog.F("provisioner_types", pt),
slog.F("tags", tags))
logger.Debug(ctx, "acquiring job")
dk := domainKey(pt, tags)
dk := domainKey(organization, pt, tags)
dbTags, err := tags.ToJSON()
if err != nil {
return database.ProvisionerJob{}, err
@ -106,7 +107,7 @@ func (a *Acquirer) AcquireJob(
// buffer of 1 so that cancel doesn't deadlock while writing to the channel
clearance := make(chan struct{}, 1)
for {
a.want(pt, tags, clearance)
a.want(organization, pt, tags, clearance)
select {
case <-ctx.Done():
err := ctx.Err()
@ -120,6 +121,7 @@ func (a *Acquirer) AcquireJob(
case <-clearance:
logger.Debug(ctx, "got clearance to call database")
job, err := a.store.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: organization,
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
@ -152,8 +154,8 @@ func (a *Acquirer) AcquireJob(
}
// want signals that an acquiree wants clearance to query for a job with the given dKey.
func (a *Acquirer) want(pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
dk := domainKey(pt, tags)
func (a *Acquirer) want(organization uuid.UUID, pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
dk := domainKey(organization, pt, tags)
a.mu.Lock()
defer a.mu.Unlock()
cleared := false
@ -404,13 +406,16 @@ type dKey string
// unprintable control character and won't show up in any "reasonable" set of
// string tags, even in non-Latin scripts. It is important that Tags are
// validated not to contain this control character prior to use.
func domainKey(pt []database.ProvisionerType, tags Tags) dKey {
func domainKey(orgID uuid.UUID, pt []database.ProvisionerType, tags Tags) dKey {
sb := strings.Builder{}
_, _ = sb.WriteString(orgID.String())
_ = sb.WriteByte(0x00)
// make a copy of pt before sorting, so that we don't mutate the original
// slice or underlying array.
pts := make([]database.ProvisionerType, len(pt))
copy(pts, pt)
slices.Sort(pts)
sb := strings.Builder{}
for _, t := range pts {
_, _ = sb.WriteString(string(t))
_ = sb.WriteByte(0x00)

View File

@ -53,12 +53,13 @@ func TestAcquirer_Single(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
orgID := uuid.New()
workerID := uuid.New()
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
tags := provisionerdserver.Tags{
"environment": "on-prem",
}
acquiree := newTestAcquiree(t, workerID, pt, tags)
acquiree := newTestAcquiree(t, orgID, workerID, pt, tags)
jobID := uuid.New()
err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
require.NoError(t, err)
@ -82,6 +83,7 @@ func TestAcquirer_MultipleSameDomain(t *testing.T) {
acquirees := make([]*testAcquiree, 0, 10)
jobIDs := make(map[uuid.UUID]bool)
workerIDs := make(map[uuid.UUID]bool)
orgID := uuid.New()
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
tags := provisionerdserver.Tags{
"environment": "on-prem",
@ -89,7 +91,7 @@ func TestAcquirer_MultipleSameDomain(t *testing.T) {
for i := 0; i < 10; i++ {
wID := uuid.New()
workerIDs[wID] = true
a := newTestAcquiree(t, wID, pt, tags)
a := newTestAcquiree(t, orgID, wID, pt, tags)
acquirees = append(acquirees, a)
a.startAcquire(ctx, uut)
}
@ -124,12 +126,13 @@ func TestAcquirer_WaitsOnNoJobs(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
orgID := uuid.New()
workerID := uuid.New()
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
tags := provisionerdserver.Tags{
"environment": "on-prem",
}
acquiree := newTestAcquiree(t, workerID, pt, tags)
acquiree := newTestAcquiree(t, orgID, workerID, pt, tags)
jobID := uuid.New()
err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows)
require.NoError(t, err)
@ -175,12 +178,13 @@ func TestAcquirer_RetriesPending(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
orgID := uuid.New()
workerID := uuid.New()
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
tags := provisionerdserver.Tags{
"environment": "on-prem",
}
acquiree := newTestAcquiree(t, workerID, pt, tags)
acquiree := newTestAcquiree(t, orgID, workerID, pt, tags)
jobID := uuid.New()
acquiree.startAcquire(ctx, uut)
@ -217,17 +221,18 @@ func TestAcquirer_DifferentDomains(t *testing.T) {
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
orgID := uuid.New()
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
worker0 := uuid.New()
tags0 := provisionerdserver.Tags{
"worker": "0",
}
acquiree0 := newTestAcquiree(t, worker0, pt, tags0)
acquiree0 := newTestAcquiree(t, orgID, worker0, pt, tags0)
worker1 := uuid.New()
tags1 := provisionerdserver.Tags{
"worker": "1",
}
acquiree1 := newTestAcquiree(t, worker1, pt, tags1)
acquiree1 := newTestAcquiree(t, orgID, worker1, pt, tags1)
jobID := uuid.New()
fs.jobs = []database.ProvisionerJob{
{ID: jobID, Provisioner: database.ProvisionerTypeEcho, Tags: database.StringMap{"worker": "1"}},
@ -268,11 +273,12 @@ func TestAcquirer_BackupPoll(t *testing.T) {
)
workerID := uuid.New()
orgID := uuid.New()
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
tags := provisionerdserver.Tags{
"environment": "on-prem",
}
acquiree := newTestAcquiree(t, workerID, pt, tags)
acquiree := newTestAcquiree(t, orgID, workerID, pt, tags)
jobID := uuid.New()
err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows)
require.NoError(t, err)
@ -294,13 +300,14 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
orgID := uuid.New()
worker0 := uuid.New()
tags := provisionerdserver.Tags{
"environment": "on-prem",
}
acquiree0 := newTestAcquiree(t, worker0, pt, tags)
acquiree0 := newTestAcquiree(t, orgID, worker0, pt, tags)
worker1 := uuid.New()
acquiree1 := newTestAcquiree(t, worker1, pt, tags)
acquiree1 := newTestAcquiree(t, orgID, worker1, pt, tags)
jobID := uuid.New()
uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)
@ -329,8 +336,10 @@ func TestAcquirer_MatchTags(t *testing.T) {
testCases := []struct {
name string
provisionerJobTags map[string]string
acquireJobTags map[string]string
expectAcquire bool
acquireJobTags map[string]string
unmatchedOrg bool // acquire will use a random org id
expectAcquire bool
}{
{
name: "untagged provisioner and untagged job",
@ -452,6 +461,13 @@ func TestAcquirer_MatchTags(t *testing.T) {
acquireJobTags: map[string]string{"scope": "user", "owner": "aaa", "environment": "on-prem", "datacenter": "chicago"},
expectAcquire: false,
},
{
name: "matching tags with unmatched org",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "environment": "on-prem"},
expectAcquire: false,
unmatchedOrg: true,
},
}
for _, tt := range testCases {
tt := tt
@ -486,7 +502,12 @@ func TestAcquirer_MatchTags(t *testing.T) {
require.NoError(t, err)
ptypes := []database.ProvisionerType{database.ProvisionerTypeEcho}
acq := provisionerdserver.NewAcquirer(ctx, log, db, ps)
aj, err := acq.AcquireJob(ctx, uuid.New(), ptypes, tt.acquireJobTags)
acquireOrgID := org.ID
if tt.unmatchedOrg {
acquireOrgID = uuid.New()
}
aj, err := acq.AcquireJob(ctx, acquireOrgID, uuid.New(), ptypes, tt.acquireJobTags)
if tt.expectAcquire {
assert.NoError(t, err)
assert.Equal(t, pj.ID, aj.ID)
@ -659,6 +680,7 @@ jobLoop:
// and asserting whether or not it returns, blocks, or is canceled.
type testAcquiree struct {
t *testing.T
orgID uuid.UUID
workerID uuid.UUID
pt []database.ProvisionerType
tags provisionerdserver.Tags
@ -666,9 +688,10 @@ type testAcquiree struct {
jc chan database.ProvisionerJob
}
func newTestAcquiree(t *testing.T, workerID uuid.UUID, pt []database.ProvisionerType, tags provisionerdserver.Tags) *testAcquiree {
func newTestAcquiree(t *testing.T, orgID uuid.UUID, workerID uuid.UUID, pt []database.ProvisionerType, tags provisionerdserver.Tags) *testAcquiree {
return &testAcquiree{
t: t,
orgID: orgID,
workerID: workerID,
pt: pt,
tags: tags,
@ -679,7 +702,7 @@ func newTestAcquiree(t *testing.T, workerID uuid.UUID, pt []database.Provisioner
func (a *testAcquiree) startAcquire(ctx context.Context, uut *provisionerdserver.Acquirer) {
go func() {
j, e := uut.AcquireJob(ctx, a.workerID, a.pt, a.tags)
j, e := uut.AcquireJob(ctx, a.orgID, a.workerID, a.pt, a.tags)
a.ec <- e
a.jc <- j
}()

View File

@ -81,6 +81,7 @@ type server struct {
lifecycleCtx context.Context
AccessURL *url.URL
ID uuid.UUID
OrganizationID uuid.UUID
Logger slog.Logger
Provisioners []database.ProvisionerType
ExternalAuthConfigs []*externalauth.Config
@ -134,6 +135,7 @@ func NewServer(
lifecycleCtx context.Context,
accessURL *url.URL,
id uuid.UUID,
organizationID uuid.UUID,
logger slog.Logger,
provisioners []database.ProvisionerType,
tags Tags,
@ -188,6 +190,7 @@ func NewServer(
lifecycleCtx: lifecycleCtx,
AccessURL: accessURL,
ID: id,
OrganizationID: organizationID,
Logger: logger,
Provisioners: provisioners,
ExternalAuthConfigs: options.ExternalAuthConfigs,
@ -287,7 +290,7 @@ func (s *server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Acquire
// database.
acqCtx, acqCancel := context.WithTimeout(ctx, s.acquireJobLongPollDur)
defer acqCancel()
job, err := s.Acquirer.AcquireJob(acqCtx, s.ID, s.Provisioners, s.Tags)
job, err := s.Acquirer.AcquireJob(acqCtx, s.OrganizationID, s.ID, s.Provisioners, s.Tags)
if xerrors.Is(err, context.DeadlineExceeded) {
s.Logger.Debug(ctx, "successful cancel")
return &proto.AcquiredJob{}, nil
@ -324,7 +327,7 @@ func (s *server) AcquireJobWithCancel(stream proto.DRPCProvisionerDaemon_Acquire
}()
jec := make(chan jobAndErr, 1)
go func() {
job, err := s.Acquirer.AcquireJob(acqCtx, s.ID, s.Provisioners, s.Tags)
job, err := s.Acquirer.AcquireJob(acqCtx, s.OrganizationID, s.ID, s.Provisioners, s.Tags)
jec <- jobAndErr{job: job, err: err}
}()
var recvErr error

View File

@ -152,15 +152,17 @@ func TestAcquireJob(t *testing.T) {
tc := tc
t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) {
t.Parallel()
srv, db, _, _ := setup(t, false, nil)
srv, db, _, pd := setup(t, false, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
_, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
ID: uuid.New(),
InitiatorID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
OrganizationID: pd.OrganizationID,
ID: uuid.New(),
InitiatorID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
})
require.NoError(t, err)
_, err = tc.acquire(ctx, srv)
@ -176,7 +178,7 @@ func TestAcquireJob(t *testing.T) {
Id: "github",
}
srv, db, ps, _ := setup(t, false, &overrides{
srv, db, ps, pd := setup(t, false, &overrides{
deploymentValues: dv,
externalAuthConfigs: []*externalauth.Config{{
ID: gitAuthProvider.Id,
@ -198,12 +200,14 @@ func TestAcquireJob(t *testing.T) {
UserID: user.ID,
})
template := dbgen.Template(t, db, database.Template{
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
OrganizationID: pd.OrganizationID,
})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
versionFile := dbgen.File(t, db, database.File{CreatedBy: user.ID})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: pd.OrganizationID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
@ -223,12 +227,13 @@ func TestAcquireJob(t *testing.T) {
require.NoError(t, err)
// Import version job
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
ID: version.JobID,
InitiatorID: user.ID,
FileID: versionFile.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport,
OrganizationID: pd.OrganizationID,
ID: version.JobID,
InitiatorID: user.ID,
FileID: versionFile.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
TemplateVersionID: version.ID,
UserVariableValues: []codersdk.VariableValue{
@ -252,8 +257,9 @@ func TestAcquireJob(t *testing.T) {
Sensitive: false,
})
workspace := dbgen.Workspace(t, db, database.Workspace{
TemplateID: template.ID,
OwnerID: user.ID,
TemplateID: template.ID,
OwnerID: user.ID,
OrganizationID: pd.OrganizationID,
})
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
@ -264,12 +270,13 @@ func TestAcquireJob(t *testing.T) {
Reason: database.BuildReasonInitiator,
})
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
ID: build.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
ID: build.ID,
OrganizationID: pd.OrganizationID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
})),
@ -900,15 +907,17 @@ func TestCompleteJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) {
t.Parallel()
srv, db, _, _ := setup(t, false, nil)
srv, db, _, pd := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
OrganizationID: pd.OrganizationID,
})
require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID,
WorkerID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
@ -928,19 +937,22 @@ func TestCompleteJob(t *testing.T) {
jobID := uuid.New()
versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID,
JobID: jobID,
ID: versionID,
JobID: jobID,
OrganizationID: pd.OrganizationID,
})
require.NoError(t, err)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: jobID,
Provisioner: database.ProvisionerTypeEcho,
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
ID: jobID,
Provisioner: database.ProvisionerTypeEcho,
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
OrganizationID: pd.OrganizationID,
})
require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID,
WorkerID: uuid.NullUUID{
UUID: pd.ID,
Valid: true,
@ -982,19 +994,22 @@ func TestCompleteJob(t *testing.T) {
jobID := uuid.New()
versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID,
JobID: jobID,
ID: versionID,
JobID: jobID,
OrganizationID: pd.OrganizationID,
})
require.NoError(t, err)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: jobID,
Provisioner: database.ProvisionerTypeEcho,
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
OrganizationID: pd.OrganizationID,
ID: jobID,
Provisioner: database.ProvisionerTypeEcho,
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID,
WorkerID: uuid.NullUUID{
UUID: pd.ID,
Valid: true,
@ -1155,14 +1170,14 @@ func TestCompleteJob(t *testing.T) {
}
tss.Store(&store)
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{})
template := dbgen.Template(t, db, database.Template{
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
OrganizationID: org.ID,
OrganizationID: pd.OrganizationID,
})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: pd.OrganizationID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
@ -1186,8 +1201,9 @@ func TestCompleteJob(t *testing.T) {
}
}
workspace := dbgen.Workspace(t, db, database.Workspace{
TemplateID: template.ID,
Ttl: workspaceTTL,
TemplateID: template.ID,
Ttl: workspaceTTL,
OrganizationID: pd.OrganizationID,
})
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
@ -1196,13 +1212,15 @@ func TestCompleteJob(t *testing.T) {
Reason: database.BuildReasonInitiator,
})
job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
OrganizationID: pd.OrganizationID,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
})),
})
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID,
WorkerID: uuid.NullUUID{
UUID: pd.ID,
Valid: true,
@ -1400,8 +1418,9 @@ func TestCompleteJob(t *testing.T) {
QuietHoursSchedule: c.userQuietHoursSchedule,
})
template := dbgen.Template(t, db, database.Template{
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
OrganizationID: pd.OrganizationID,
})
err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
ID: template.ID,
@ -1424,11 +1443,13 @@ func TestCompleteJob(t *testing.T) {
}
}
workspace := dbgen.Workspace(t, db, database.Workspace{
TemplateID: template.ID,
Ttl: workspaceTTL,
OwnerID: user.ID,
TemplateID: template.ID,
Ttl: workspaceTTL,
OwnerID: user.ID,
OrganizationID: pd.OrganizationID,
})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: pd.OrganizationID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
@ -1447,8 +1468,10 @@ func TestCompleteJob(t *testing.T) {
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
})),
OrganizationID: pd.OrganizationID,
})
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID,
WorkerID: uuid.NullUUID{
UUID: pd.ID,
Valid: true,
@ -1733,6 +1756,9 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
db := dbmem.New()
ps := pubsub.NewInMemory()
defOrg, err := db.GetDefaultOrganization(context.Background())
require.NoError(t, err, "default org not found")
deploymentValues := &codersdk.DeploymentValues{}
var externalAuthConfigs []*externalauth.Config
tss := testTemplateScheduleStore()
@ -1780,13 +1806,14 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
pollDur = ov.acquireJobLongPollDuration
daemon, err := db.UpsertProvisionerDaemon(ov.ctx, database.UpsertProvisionerDaemonParams{
Name: "test",
CreatedAt: dbtime.Now(),
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: database.StringMap{},
LastSeenAt: sql.NullTime{},
Version: buildinfo.Version(),
APIVersion: proto.CurrentVersion.String(),
Name: "test",
CreatedAt: dbtime.Now(),
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: database.StringMap{},
LastSeenAt: sql.NullTime{},
Version: buildinfo.Version(),
APIVersion: proto.CurrentVersion.String(),
OrganizationID: defOrg.ID,
})
require.NoError(t, err)
@ -1794,6 +1821,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
ov.ctx,
&url.URL{},
daemon.ID,
defOrg.ID,
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
[]database.ProvisionerType{database.ProvisionerTypeEcho},
provisionerdserver.Tags(daemon.Tags),

View File

@ -317,6 +317,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
srvCtx,
api.AccessURL,
daemon.ID,
organization.ID,
logger,
provisioners,
tags,

View File

@ -28,7 +28,6 @@ func TestTemplateUpdateBuildDeadlines(t *testing.T) {
db, _ := dbtestutil.NewDB(t)
var (
org = dbgen.Organization(t, db, database.Organization{})
quietUser = dbgen.User(t, db, database.User{
Username: "quiet",
})
@ -39,18 +38,18 @@ func TestTemplateUpdateBuildDeadlines(t *testing.T) {
CreatedBy: quietUser.ID,
})
templateJob = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
FileID: file.ID,
InitiatorID: quietUser.ID,
FileID: file.ID,
InitiatorID: quietUser.ID,
Tags: database.StringMap{
"foo": "bar",
},
})
templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
OrganizationID: templateJob.OrganizationID,
CreatedBy: quietUser.ID,
JobID: templateJob.ID,
})
organizationID = templateJob.OrganizationID
)
const userQuietHoursSchedule = "CRON_TZ=UTC 0 0 * * *" // midnight UTC
@ -204,17 +203,17 @@ func TestTemplateUpdateBuildDeadlines(t *testing.T) {
var (
template = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
OrganizationID: organizationID,
ActiveVersionID: templateVersion.ID,
CreatedBy: user.ID,
})
ws = dbgen.Workspace(t, db, database.Workspace{
OrganizationID: org.ID,
OrganizationID: organizationID,
OwnerID: user.ID,
TemplateID: template.ID,
})
job = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
OrganizationID: organizationID,
FileID: file.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
@ -236,6 +235,7 @@ func TestTemplateUpdateBuildDeadlines(t *testing.T) {
require.NotEmpty(t, wsBuild.ProvisionerState, "provisioner state must not be empty")
acquiredJob, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: buildTime,
Valid: true,
@ -324,41 +324,39 @@ func TestTemplateUpdateBuildDeadlinesSkip(t *testing.T) {
db, _ := dbtestutil.NewDB(t)
var (
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{
CreatedBy: user.ID,
})
templateJob = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
FileID: file.ID,
InitiatorID: user.ID,
FileID: file.ID,
InitiatorID: user.ID,
Tags: database.StringMap{
"foo": "bar",
},
})
templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
CreatedBy: user.ID,
JobID: templateJob.ID,
OrganizationID: templateJob.OrganizationID,
})
template = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
ActiveVersionID: templateVersion.ID,
CreatedBy: user.ID,
OrganizationID: templateJob.OrganizationID,
})
otherTemplate = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
ActiveVersionID: templateVersion.ID,
CreatedBy: user.ID,
OrganizationID: templateJob.OrganizationID,
})
)
// Create a workspace that will be shared by two builds.
ws := dbgen.Workspace(t, db, database.Workspace{
OrganizationID: org.ID,
OwnerID: user.ID,
TemplateID: template.ID,
OrganizationID: templateJob.OrganizationID,
})
const userQuietHoursSchedule = "CRON_TZ=UTC 0 0 * * *" // midnight UTC
@ -473,20 +471,20 @@ func TestTemplateUpdateBuildDeadlinesSkip(t *testing.T) {
wsID := b.workspaceID
if wsID == uuid.Nil {
ws := dbgen.Workspace(t, db, database.Workspace{
OrganizationID: org.ID,
OwnerID: user.ID,
TemplateID: b.templateID,
OrganizationID: templateJob.OrganizationID,
})
wsID = ws.ID
}
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
FileID: file.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
FileID: file.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
Tags: database.StringMap{
wsID.String(): "yeah",
},
OrganizationID: templateJob.OrganizationID,
})
wsBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: wsID,
@ -521,6 +519,7 @@ func TestTemplateUpdateBuildDeadlinesSkip(t *testing.T) {
}
acquiredJob, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: job.OrganizationID,
StartedAt: sql.NullTime{
Time: buildTime,
Valid: true,