fix(coderd): only allow untagged provisioners to pick up untagged jobs (#12269)

Alternative solution to #6442

Modifies the behaviour of AcquireProvisionerJob and adds a special case for 'un-tagged' jobs such that they can only be picked up by 'un-tagged' provisioners.

Also adds comprehensive test coverage for AcquireJob given various combinations of tags.
This commit is contained in:
Cian Johnston 2024-02-22 15:04:31 +00:00 committed by GitHub
parent aa7a12a5ec
commit 53e8f9c0f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 173 additions and 19 deletions

View File

@ -748,6 +748,25 @@ var deletedUserLinkError = &pq.Error{
Routine: "exec_stmt_raise",
}
// m1 and m2 are equal iff |m1| = |m2| ^ m2 ⊆ m1
func tagsEqual(m1, m2 map[string]string) bool {
return len(m1) == len(m2) && tagsSubset(m1, m2)
}
// m2 is a subset of m1 if each key in m1 exists in m2
// with the same value
func tagsSubset(m1, m2 map[string]string) bool {
for k, v1 := range m1 {
if v2, found := m2[k]; !found || v1 != v2 {
return false
}
}
return true
}
// default tags when no tag is specified for a provisioner or job
var tagsUntagged = provisionersdk.MutateTags(uuid.Nil, nil)
func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
return xerrors.New("AcquireLock must only be called within a transaction")
}
@ -783,19 +802,15 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
}
}
missing := false
for key, value := range provisionerJob.Tags {
provided, found := tags[key]
if !found {
missing = true
break
}
if provided != value {
missing = true
break
}
// Special case for untagged provisioners: only match untagged jobs.
// Ref: coderd/database/queries/provisionerjobs.sql:24-30
// CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
// THEN nested.tags :: jsonb = @tags :: jsonb
if tagsEqual(provisionerJob.Tags, tagsUntagged) && !tagsEqual(provisionerJob.Tags, tags) {
continue
}
if missing {
// ELSE nested.tags :: jsonb <@ @tags :: jsonb
if !tagsSubset(provisionerJob.Tags, tags) {
continue
}
provisionerJob.StartedAt = arg.StartedAt

View File

@ -3936,8 +3936,13 @@ WHERE
nested.started_at IS NULL
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY($3 :: provisioner_type [ ])
-- Ensure the caller satisfies all job tags.
AND nested.tags <@ $4 :: jsonb
AND CASE
-- Special case for untagged provisioners: only match untagged jobs.
WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
THEN nested.tags :: jsonb = $4 :: jsonb
-- Ensure the caller satisfies all job tags.
ELSE nested.tags :: jsonb <@ $4 :: jsonb
END
ORDER BY
nested.created_at
FOR UPDATE

View File

@ -21,8 +21,13 @@ WHERE
nested.started_at IS NULL
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
-- Ensure the caller satisfies all job tags.
AND nested.tags <@ @tags :: jsonb
AND CASE
-- Special case for untagged provisioners: only match untagged jobs.
WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
THEN nested.tags :: jsonb = @tags :: jsonb
-- Ensure the caller satisfies all job tags.
ELSE nested.tags :: jsonb <@ @tags :: jsonb
END
ORDER BY
nested.created_at
FOR UPDATE

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
@ -18,6 +19,8 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/provisionerdserver"
@ -315,6 +318,133 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) {
require.Equal(t, jobID, job.ID)
}
func TestAcquirer_MatchTags(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skipping this test due to -short")
}
someID := uuid.NewString()
someOtherID := uuid.NewString()
for _, tt := range []struct {
name string
provisionerJobTags map[string]string
acquireJobTags map[string]string
expectAcquire bool
}{
{
name: "untagged provisioner and untagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
expectAcquire: true,
},
{
name: "untagged provisioner and tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
expectAcquire: false,
},
{
name: "tagged provisioner and untagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
expectAcquire: false,
},
{
name: "tagged provisioner and tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
expectAcquire: true,
},
{
name: "tagged provisioner and double-tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
expectAcquire: false,
},
{
name: "double-tagged provisioner and tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
expectAcquire: true,
},
{
name: "double-tagged provisioner and double-tagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
acquireJobTags: map[string]string{"scope": "organization", "owner": "", "foo": "bar", "baz": "zap"},
expectAcquire: true,
},
{
name: "owner-scoped provisioner and untagged job",
provisionerJobTags: map[string]string{"scope": "organization", "owner": ""},
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
expectAcquire: false,
},
{
name: "owner-scoped provisioner and owner-scoped job",
provisionerJobTags: map[string]string{"scope": "owner", "owner": someID},
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
expectAcquire: true,
},
{
name: "owner-scoped provisioner and different owner-scoped job",
provisionerJobTags: map[string]string{"scope": "owner", "owner": someOtherID},
acquireJobTags: map[string]string{"scope": "owner", "owner": someID},
expectAcquire: false,
},
{
name: "org-scoped provisioner and owner-scoped job",
provisionerJobTags: map[string]string{"scope": "owner", "owner": someID},
acquireJobTags: map[string]string{"scope": "organization", "owner": ""},
expectAcquire: false,
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort/2)
// NOTE: explicitly not using fake store for this test.
db, ps := dbtestutil.NewDB(t)
log := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
org, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{
ID: uuid.New(),
Name: "test org",
Description: "the organization of testing",
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
})
require.NoError(t, err)
pj, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
OrganizationID: org.ID,
InitiatorID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: uuid.New(),
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
Tags: tt.provisionerJobTags,
TraceMetadata: pqtype.NullRawMessage{},
})
require.NoError(t, err)
ptypes := []database.ProvisionerType{database.ProvisionerTypeEcho}
acq := provisionerdserver.NewAcquirer(ctx, log, db, ps)
aj, err := acq.AcquireJob(ctx, uuid.New(), ptypes, tt.acquireJobTags)
if tt.expectAcquire {
assert.NoError(t, err)
assert.Equal(t, pj.ID, aj.ID)
} else {
assert.Empty(t, aj, "should not have acquired job")
assert.ErrorIs(t, err, context.DeadlineExceeded, "should have timed out")
}
})
}
}
func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) {
t.Helper()
msg, err := json.Marshal(provisionerjobs.JobPosting{

View File

@ -117,9 +117,8 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
defer closeLogger()
}
if len(tags) != 0 {
logger.Info(ctx, "note: tagged provisioners can currently pick up jobs from untagged templates")
logger.Info(ctx, "see https://github.com/coder/coder/issues/6442 for details")
if len(tags) == 0 {
logger.Info(ctx, "note: untagged provisioners can only pick up jobs from untagged templates")
}
// When authorizing with a PSK, we automatically scope the provisionerd