mirror of https://github.com/coder/coder.git
1743 lines
55 KiB
Go
1743 lines
55 KiB
Go
package provisionerdserver_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"io"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/xerrors"
|
|
"storj.io/drpc"
|
|
|
|
"cdr.dev/slog"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/oauth2"
|
|
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/buildinfo"
|
|
"github.com/coder/coder/v2/coderd/audit"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbmem"
|
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
|
"github.com/coder/coder/v2/coderd/externalauth"
|
|
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
|
"github.com/coder/coder/v2/coderd/schedule"
|
|
"github.com/coder/coder/v2/coderd/schedule/cron"
|
|
"github.com/coder/coder/v2/coderd/telemetry"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/provisionerd/proto"
|
|
"github.com/coder/coder/v2/provisionersdk"
|
|
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/serpent"
|
|
)
|
|
|
|
func testTemplateScheduleStore() *atomic.Pointer[schedule.TemplateScheduleStore] {
|
|
ptr := &atomic.Pointer[schedule.TemplateScheduleStore]{}
|
|
store := schedule.NewAGPLTemplateScheduleStore()
|
|
ptr.Store(&store)
|
|
return ptr
|
|
}
|
|
|
|
func testUserQuietHoursScheduleStore() *atomic.Pointer[schedule.UserQuietHoursScheduleStore] {
|
|
ptr := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{}
|
|
store := schedule.NewAGPLUserQuietHoursScheduleStore()
|
|
ptr.Store(&store)
|
|
return ptr
|
|
}
|
|
|
|
func TestAcquireJob_LongPoll(t *testing.T) {
|
|
t.Parallel()
|
|
//nolint:dogsled
|
|
srv, _, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond})
|
|
job, err := srv.AcquireJob(context.Background(), nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, &proto.AcquiredJob{}, job)
|
|
}
|
|
|
|
func TestAcquireJobWithCancel_Cancel(t *testing.T) {
|
|
t.Parallel()
|
|
//nolint:dogsled
|
|
srv, _, _, _ := setup(t, false, nil)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
fs := newFakeStream(ctx)
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- srv.AcquireJobWithCancel(fs)
|
|
}()
|
|
fs.cancel()
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for AcquireJobWithCancel")
|
|
case err := <-errCh:
|
|
require.NoError(t, err)
|
|
}
|
|
job, err := fs.waitForJob()
|
|
require.NoError(t, err)
|
|
require.NotNil(t, job)
|
|
require.Equal(t, "", job.JobId)
|
|
}
|
|
|
|
func TestHeartbeat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
numBeats := 3
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
heartbeatChan := make(chan struct{})
|
|
heartbeatFn := func(hbCtx context.Context) error {
|
|
t.Logf("heartbeat")
|
|
select {
|
|
case <-hbCtx.Done():
|
|
return hbCtx.Err()
|
|
default:
|
|
heartbeatChan <- struct{}{}
|
|
return nil
|
|
}
|
|
}
|
|
//nolint:dogsled
|
|
_, _, _, _ = setup(t, false, &overrides{
|
|
ctx: ctx,
|
|
heartbeatFn: heartbeatFn,
|
|
heartbeatInterval: testutil.IntervalFast,
|
|
})
|
|
|
|
for i := 0; i < numBeats; i++ {
|
|
testutil.RequireRecvCtx(ctx, t, heartbeatChan)
|
|
}
|
|
// goleak.VerifyTestMain ensures that the heartbeat goroutine does not leak
|
|
}
|
|
|
|
func TestAcquireJob(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// These test acquiring a single job without canceling, and tests both AcquireJob (deprecated) and
|
|
// AcquireJobWithCancel as the way to get the job.
|
|
cases := []struct {
|
|
name string
|
|
acquire func(context.Context, proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error)
|
|
}{
|
|
{name: "Deprecated", acquire: func(ctx context.Context, srv proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error) {
|
|
return srv.AcquireJob(ctx, nil)
|
|
}},
|
|
{name: "WithCancel", acquire: func(ctx context.Context, srv proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error) {
|
|
fs := newFakeStream(ctx)
|
|
err := srv.AcquireJobWithCancel(fs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return fs.waitForJob()
|
|
}},
|
|
}
|
|
for _, tc := range cases {
|
|
tc := tc
|
|
t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, pd := setup(t, false, nil)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
|
|
_, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
|
|
OrganizationID: pd.OrganizationID,
|
|
ID: uuid.New(),
|
|
InitiatorID: uuid.New(),
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = tc.acquire(ctx, srv)
|
|
require.ErrorContains(t, err, "sql: no rows in result set")
|
|
})
|
|
t.Run(tc.name+"_WorkspaceBuildJob", func(t *testing.T) {
|
|
t.Parallel()
|
|
// Set the max session token lifetime so we can assert we
|
|
// create an API key with an expiration within the bounds of the
|
|
// deployment config.
|
|
dv := &codersdk.DeploymentValues{
|
|
Sessions: codersdk.SessionLifetime{
|
|
MaximumTokenDuration: serpent.Duration(time.Hour),
|
|
},
|
|
}
|
|
gitAuthProvider := &sdkproto.ExternalAuthProviderResource{
|
|
Id: "github",
|
|
}
|
|
|
|
srv, db, ps, pd := setup(t, false, &overrides{
|
|
deploymentValues: dv,
|
|
externalAuthConfigs: []*externalauth.Config{{
|
|
ID: gitAuthProvider.Id,
|
|
InstrumentedOAuth2Config: &testutil.OAuth2Config{},
|
|
}},
|
|
})
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
group1 := dbgen.Group(t, db, database.Group{
|
|
Name: "group1",
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
err := db.InsertGroupMember(ctx, database.InsertGroupMemberParams{
|
|
UserID: user.ID,
|
|
GroupID: group1.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
link := dbgen.UserLink(t, db, database.UserLink{
|
|
LoginType: database.LoginTypeOIDC,
|
|
UserID: user.ID,
|
|
OAuthExpiry: dbtime.Now().Add(time.Hour),
|
|
OAuthAccessToken: "access-token",
|
|
})
|
|
dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
|
|
ProviderID: gitAuthProvider.Id,
|
|
UserID: user.ID,
|
|
})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
Name: "template",
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
|
|
versionFile := dbgen.File(t, db, database.File{CreatedBy: user.ID})
|
|
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: pd.OrganizationID,
|
|
TemplateID: uuid.NullUUID{
|
|
UUID: template.ID,
|
|
Valid: true,
|
|
},
|
|
JobID: uuid.New(),
|
|
})
|
|
externalAuthProviders, err := json.Marshal([]database.ExternalAuthProvider{{
|
|
ID: gitAuthProvider.Id,
|
|
Optional: gitAuthProvider.Optional,
|
|
}})
|
|
require.NoError(t, err)
|
|
err = db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
|
|
JobID: version.JobID,
|
|
ExternalAuthProviders: json.RawMessage(externalAuthProviders),
|
|
UpdatedAt: dbtime.Now(),
|
|
})
|
|
require.NoError(t, err)
|
|
// Import version job
|
|
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
OrganizationID: pd.OrganizationID,
|
|
ID: version.JobID,
|
|
InitiatorID: user.ID,
|
|
FileID: versionFile.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
|
|
TemplateVersionID: version.ID,
|
|
UserVariableValues: []codersdk.VariableValue{
|
|
{Name: "second", Value: "bah"},
|
|
},
|
|
})),
|
|
})
|
|
_ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{
|
|
TemplateVersionID: version.ID,
|
|
Name: "first",
|
|
Value: "first_value",
|
|
DefaultValue: "default_value",
|
|
Sensitive: true,
|
|
})
|
|
_ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{
|
|
TemplateVersionID: version.ID,
|
|
Name: "second",
|
|
Value: "second_value",
|
|
DefaultValue: "default_value",
|
|
Required: true,
|
|
Sensitive: false,
|
|
})
|
|
workspace := dbgen.Workspace(t, db, database.Workspace{
|
|
TemplateID: template.ID,
|
|
OwnerID: user.ID,
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
BuildNumber: 1,
|
|
JobID: uuid.New(),
|
|
TemplateVersionID: version.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
Reason: database.BuildReasonInitiator,
|
|
})
|
|
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
ID: build.ID,
|
|
OrganizationID: pd.OrganizationID,
|
|
InitiatorID: user.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
FileID: file.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
|
WorkspaceBuildID: build.ID,
|
|
})),
|
|
})
|
|
|
|
startPublished := make(chan struct{})
|
|
var closed bool
|
|
closeStartSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
|
|
if !closed {
|
|
close(startPublished)
|
|
closed = true
|
|
}
|
|
})
|
|
require.NoError(t, err)
|
|
defer closeStartSubscribe()
|
|
|
|
var job *proto.AcquiredJob
|
|
|
|
for {
|
|
// Grab jobs until we find the workspace build job. There is also
|
|
// an import version job that we need to ignore.
|
|
job, err = tc.acquire(ctx, srv)
|
|
require.NoError(t, err)
|
|
if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok {
|
|
break
|
|
}
|
|
}
|
|
|
|
<-startPublished
|
|
|
|
got, err := json.Marshal(job.Type)
|
|
require.NoError(t, err)
|
|
|
|
// Validate that a session token is generated during the job.
|
|
sessionToken := job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
|
|
require.NotEmpty(t, sessionToken)
|
|
toks := strings.Split(sessionToken, "-")
|
|
require.Len(t, toks, 2, "invalid api key")
|
|
key, err := db.GetAPIKeyByID(ctx, toks[0])
|
|
require.NoError(t, err)
|
|
require.Equal(t, int64(dv.Sessions.MaximumTokenDuration.Value().Seconds()), key.LifetimeSeconds)
|
|
require.WithinDuration(t, time.Now().Add(dv.Sessions.MaximumTokenDuration.Value()), key.ExpiresAt, time.Minute)
|
|
|
|
want, err := json.Marshal(&proto.AcquiredJob_WorkspaceBuild_{
|
|
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
|
|
WorkspaceBuildId: build.ID.String(),
|
|
WorkspaceName: workspace.Name,
|
|
VariableValues: []*sdkproto.VariableValue{
|
|
{
|
|
Name: "first",
|
|
Value: "first_value",
|
|
Sensitive: true,
|
|
},
|
|
{
|
|
Name: "second",
|
|
Value: "second_value",
|
|
},
|
|
},
|
|
ExternalAuthProviders: []*sdkproto.ExternalAuthProvider{{
|
|
Id: gitAuthProvider.Id,
|
|
AccessToken: "access_token",
|
|
}},
|
|
Metadata: &sdkproto.Metadata{
|
|
CoderUrl: (&url.URL{}).String(),
|
|
WorkspaceTransition: sdkproto.WorkspaceTransition_START,
|
|
WorkspaceName: workspace.Name,
|
|
WorkspaceOwner: user.Username,
|
|
WorkspaceOwnerEmail: user.Email,
|
|
WorkspaceOwnerName: user.Name,
|
|
WorkspaceOwnerOidcAccessToken: link.OAuthAccessToken,
|
|
WorkspaceOwnerGroups: []string{group1.Name},
|
|
WorkspaceId: workspace.ID.String(),
|
|
WorkspaceOwnerId: user.ID.String(),
|
|
TemplateId: template.ID.String(),
|
|
TemplateName: template.Name,
|
|
TemplateVersion: version.Name,
|
|
WorkspaceOwnerSessionToken: sessionToken,
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
require.JSONEq(t, string(want), string(got))
|
|
|
|
// Assert that we delete the session token whenever
|
|
// a stop is issued.
|
|
stopbuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
BuildNumber: 2,
|
|
JobID: uuid.New(),
|
|
TemplateVersionID: version.ID,
|
|
Transition: database.WorkspaceTransitionStop,
|
|
Reason: database.BuildReasonInitiator,
|
|
})
|
|
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
ID: stopbuild.ID,
|
|
InitiatorID: user.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
FileID: file.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
|
WorkspaceBuildID: stopbuild.ID,
|
|
})),
|
|
})
|
|
|
|
stopPublished := make(chan struct{})
|
|
closeStopSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
|
|
close(stopPublished)
|
|
})
|
|
require.NoError(t, err)
|
|
defer closeStopSubscribe()
|
|
|
|
// Grab jobs until we find the workspace build job. There is also
|
|
// an import version job that we need to ignore.
|
|
job, err = tc.acquire(ctx, srv)
|
|
require.NoError(t, err)
|
|
_, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_)
|
|
require.True(t, ok, "acquired job not a workspace build?")
|
|
|
|
<-stopPublished
|
|
|
|
// Validate that a session token is deleted during a stop job.
|
|
sessionToken = job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
|
|
require.Empty(t, sessionToken)
|
|
_, err = db.GetAPIKeyByID(ctx, key.ID)
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
})
|
|
|
|
t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, ps, _ := setup(t, false, nil)
|
|
ctx := context.Background()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
|
|
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
|
|
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
InitiatorID: user.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
FileID: file.ID,
|
|
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
|
|
Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{
|
|
TemplateVersionID: version.ID,
|
|
WorkspaceName: "testing",
|
|
})),
|
|
})
|
|
|
|
job, err := tc.acquire(ctx, srv)
|
|
require.NoError(t, err)
|
|
|
|
got, err := json.Marshal(job.Type)
|
|
require.NoError(t, err)
|
|
|
|
want, err := json.Marshal(&proto.AcquiredJob_TemplateDryRun_{
|
|
TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{
|
|
Metadata: &sdkproto.Metadata{
|
|
CoderUrl: (&url.URL{}).String(),
|
|
WorkspaceName: "testing",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.JSONEq(t, string(want), string(got))
|
|
})
|
|
t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, ps, _ := setup(t, false, nil)
|
|
ctx := context.Background()
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
|
|
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
FileID: file.ID,
|
|
InitiatorID: user.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
})
|
|
|
|
job, err := tc.acquire(ctx, srv)
|
|
require.NoError(t, err)
|
|
|
|
got, err := json.Marshal(job.Type)
|
|
require.NoError(t, err)
|
|
|
|
want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{
|
|
TemplateImport: &proto.AcquiredJob_TemplateImport{
|
|
Metadata: &sdkproto.Metadata{
|
|
CoderUrl: (&url.URL{}).String(),
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.JSONEq(t, string(want), string(got))
|
|
})
|
|
t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, ps, _ := setup(t, false, nil)
|
|
|
|
user := dbgen.User(t, db, database.User{})
|
|
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
|
|
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
|
|
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
FileID: file.ID,
|
|
InitiatorID: user.ID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
|
|
TemplateVersionID: version.ID,
|
|
UserVariableValues: []codersdk.VariableValue{
|
|
{Name: "first", Value: "first_value"},
|
|
},
|
|
})),
|
|
})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
|
|
job, err := tc.acquire(ctx, srv)
|
|
require.NoError(t, err)
|
|
|
|
got, err := json.Marshal(job.Type)
|
|
require.NoError(t, err)
|
|
|
|
want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{
|
|
TemplateImport: &proto.AcquiredJob_TemplateImport{
|
|
UserVariableValues: []*sdkproto.VariableValue{
|
|
{Name: "first", Sensitive: true, Value: "first_value"},
|
|
},
|
|
Metadata: &sdkproto.Metadata{
|
|
CoderUrl: (&url.URL{}).String(),
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.JSONEq(t, string(want), string(got))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateJob(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
t.Run("NotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, _, _, _ := setup(t, false, nil)
|
|
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: "hello",
|
|
})
|
|
require.ErrorContains(t, err, "invalid UUID")
|
|
|
|
_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: uuid.NewString(),
|
|
})
|
|
require.ErrorContains(t, err, "no rows in result set")
|
|
})
|
|
t.Run("NotRunning", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, _ := setup(t, false, nil)
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: job.ID.String(),
|
|
})
|
|
require.ErrorContains(t, err, "job isn't running yet")
|
|
})
|
|
// This test prevents runners from updating jobs they don't own!
|
|
t.Run("NotOwner", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, _ := setup(t, false, nil)
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: job.ID.String(),
|
|
})
|
|
require.ErrorContains(t, err, "you don't own this job")
|
|
})
|
|
|
|
setupJob := func(t *testing.T, db database.Store, srvID uuid.UUID) uuid.UUID {
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: srvID,
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
return job.ID
|
|
}
|
|
|
|
t.Run("Success", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, pd := setup(t, false, &overrides{})
|
|
job := setupJob(t, db, pd.ID)
|
|
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: job.String(),
|
|
})
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("Logs", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, ps, pd := setup(t, false, &overrides{})
|
|
job := setupJob(t, db, pd.ID)
|
|
|
|
published := make(chan struct{})
|
|
|
|
closeListener, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) {
|
|
close(published)
|
|
})
|
|
require.NoError(t, err)
|
|
defer closeListener()
|
|
|
|
_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: job.String(),
|
|
Logs: []*proto.Log{{
|
|
Source: proto.LogSource_PROVISIONER,
|
|
Level: sdkproto.LogLevel_INFO,
|
|
Output: "hi",
|
|
}},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
<-published
|
|
})
|
|
t.Run("Readme", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, pd := setup(t, false, &overrides{})
|
|
job := setupJob(t, db, pd.ID)
|
|
versionID := uuid.New()
|
|
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
|
ID: versionID,
|
|
JobID: job,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: job.String(),
|
|
Readme: []byte("# hello world"),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
version, err := db.GetTemplateVersionByID(ctx, versionID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "# hello world", version.Readme)
|
|
})
|
|
|
|
t.Run("TemplateVariables", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Valid", func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
|
|
srv, db, _, pd := setup(t, false, &overrides{})
|
|
job := setupJob(t, db, pd.ID)
|
|
versionID := uuid.New()
|
|
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
|
ID: versionID,
|
|
JobID: job,
|
|
})
|
|
require.NoError(t, err)
|
|
firstTemplateVariable := &sdkproto.TemplateVariable{
|
|
Name: "first",
|
|
Type: "string",
|
|
DefaultValue: "default_value",
|
|
Sensitive: true,
|
|
}
|
|
secondTemplateVariable := &sdkproto.TemplateVariable{
|
|
Name: "second",
|
|
Type: "string",
|
|
Required: true,
|
|
Sensitive: true,
|
|
}
|
|
response, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: job.String(),
|
|
TemplateVariables: []*sdkproto.TemplateVariable{
|
|
firstTemplateVariable,
|
|
secondTemplateVariable,
|
|
},
|
|
UserVariableValues: []*sdkproto.VariableValue{
|
|
{
|
|
Name: "second",
|
|
Value: "foobar",
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, response.VariableValues, 2)
|
|
|
|
templateVariables, err := db.GetTemplateVersionVariables(ctx, versionID)
|
|
require.NoError(t, err)
|
|
require.Len(t, templateVariables, 2)
|
|
require.Equal(t, templateVariables[0].Value, firstTemplateVariable.DefaultValue)
|
|
require.Equal(t, templateVariables[1].Value, "foobar")
|
|
})
|
|
|
|
t.Run("Missing required value", func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
|
|
srv, db, _, pd := setup(t, false, &overrides{})
|
|
job := setupJob(t, db, pd.ID)
|
|
versionID := uuid.New()
|
|
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
|
ID: versionID,
|
|
JobID: job,
|
|
})
|
|
require.NoError(t, err)
|
|
firstTemplateVariable := &sdkproto.TemplateVariable{
|
|
Name: "first",
|
|
Type: "string",
|
|
DefaultValue: "default_value",
|
|
Sensitive: true,
|
|
}
|
|
secondTemplateVariable := &sdkproto.TemplateVariable{
|
|
Name: "second",
|
|
Type: "string",
|
|
Required: true,
|
|
Sensitive: true,
|
|
}
|
|
response, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: job.String(),
|
|
TemplateVariables: []*sdkproto.TemplateVariable{
|
|
firstTemplateVariable,
|
|
secondTemplateVariable,
|
|
},
|
|
})
|
|
require.Error(t, err) // required template variables need values
|
|
require.Nil(t, response)
|
|
|
|
// Even though there is an error returned, variables are stored in the database
|
|
// to show the schema in the site UI.
|
|
templateVariables, err := db.GetTemplateVersionVariables(ctx, versionID)
|
|
require.NoError(t, err)
|
|
require.Len(t, templateVariables, 2)
|
|
require.Equal(t, templateVariables[0].Value, firstTemplateVariable.DefaultValue)
|
|
require.Equal(t, templateVariables[1].Value, "")
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestFailJob(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
t.Run("NotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, _, _, _ := setup(t, false, nil)
|
|
_, err := srv.FailJob(ctx, &proto.FailedJob{
|
|
JobId: "hello",
|
|
})
|
|
require.ErrorContains(t, err, "invalid UUID")
|
|
|
|
_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
|
|
JobId: uuid.NewString(),
|
|
})
|
|
require.ErrorContains(t, err, "no rows in result set")
|
|
})
|
|
// This test prevents runners from updating jobs they don't own!
|
|
t.Run("NotOwner", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, _ := setup(t, false, nil)
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = srv.FailJob(ctx, &proto.FailedJob{
|
|
JobId: job.ID.String(),
|
|
})
|
|
require.ErrorContains(t, err, "you don't own this job")
|
|
})
|
|
t.Run("AlreadyCompleted", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, pd := setup(t, false, &overrides{})
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: pd.ID,
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
|
ID: job.ID,
|
|
CompletedAt: sql.NullTime{
|
|
Time: dbtime.Now(),
|
|
Valid: true,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = srv.FailJob(ctx, &proto.FailedJob{
|
|
JobId: job.ID.String(),
|
|
})
|
|
require.ErrorContains(t, err, "job already completed")
|
|
})
|
|
t.Run("WorkspaceBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
// Ignore log errors because we get:
|
|
//
|
|
// (*Server).FailJob audit log - get build {"error": "sql: no rows in result set"}
|
|
ignoreLogErrors := true
|
|
auditor := audit.NewMock()
|
|
srv, db, ps, pd := setup(t, ignoreLogErrors, &overrides{
|
|
auditor: auditor,
|
|
})
|
|
org := dbgen.Organization(t, db, database.Organization{})
|
|
workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{
|
|
ID: uuid.New(),
|
|
AutomaticUpdates: database.AutomaticUpdatesNever,
|
|
OrganizationID: org.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
buildID := uuid.New()
|
|
input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
|
WorkspaceBuildID: buildID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
Input: input,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
})
|
|
require.NoError(t, err)
|
|
err = db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
|
|
ID: buildID,
|
|
WorkspaceID: workspace.ID,
|
|
Transition: database.WorkspaceTransitionStart,
|
|
Reason: database.BuildReasonInitiator,
|
|
JobID: job.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: pd.ID,
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
publishedWorkspace := make(chan struct{})
|
|
closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
|
|
close(publishedWorkspace)
|
|
})
|
|
require.NoError(t, err)
|
|
defer closeWorkspaceSubscribe()
|
|
publishedLogs := make(chan struct{})
|
|
closeLogsSubscribe, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
|
|
close(publishedLogs)
|
|
})
|
|
require.NoError(t, err)
|
|
defer closeLogsSubscribe()
|
|
|
|
auditor.ResetLogs()
|
|
_, err = srv.FailJob(ctx, &proto.FailedJob{
|
|
JobId: job.ID.String(),
|
|
Type: &proto.FailedJob_WorkspaceBuild_{
|
|
WorkspaceBuild: &proto.FailedJob_WorkspaceBuild{
|
|
State: []byte("some state"),
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
<-publishedWorkspace
|
|
<-publishedLogs
|
|
build, err := db.GetWorkspaceBuildByID(ctx, buildID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "some state", string(build.ProvisionerState))
|
|
require.Len(t, auditor.AuditLogs(), 1)
|
|
|
|
// Assert that the workspace_id field get populated
|
|
var additionalFields audit.AdditionalFields
|
|
err = json.Unmarshal(auditor.AuditLogs()[0].AdditionalFields, &additionalFields)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspace.ID, additionalFields.WorkspaceID)
|
|
})
|
|
}
|
|
|
|
func TestCompleteJob(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
t.Run("NotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, _, _, _ := setup(t, false, nil)
|
|
_, err := srv.CompleteJob(ctx, &proto.CompletedJob{
|
|
JobId: "hello",
|
|
})
|
|
require.ErrorContains(t, err, "invalid UUID")
|
|
|
|
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
|
JobId: uuid.NewString(),
|
|
})
|
|
require.ErrorContains(t, err, "no rows in result set")
|
|
})
|
|
// This test prevents runners from updating jobs they don't own!
|
|
t.Run("NotOwner", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, pd := setup(t, false, nil)
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: pd.OrganizationID,
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
|
JobId: job.ID.String(),
|
|
})
|
|
require.ErrorContains(t, err, "you don't own this job")
|
|
})
|
|
|
|
t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, pd := setup(t, false, &overrides{})
|
|
jobID := uuid.New()
|
|
versionID := uuid.New()
|
|
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
|
ID: versionID,
|
|
JobID: jobID,
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
require.NoError(t, err)
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: jobID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: pd.OrganizationID,
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: pd.ID,
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
completeJob := func() {
|
|
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
|
JobId: job.ID.String(),
|
|
Type: &proto.CompletedJob_TemplateImport_{
|
|
TemplateImport: &proto.CompletedJob_TemplateImport{
|
|
StartResources: []*sdkproto.Resource{{
|
|
Name: "hello",
|
|
Type: "aws_instance",
|
|
}},
|
|
StopResources: []*sdkproto.Resource{},
|
|
ExternalAuthProviders: []*sdkproto.ExternalAuthProviderResource{{
|
|
Id: "github",
|
|
}},
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
completeJob()
|
|
job, err = db.GetProvisionerJobByID(ctx, job.ID)
|
|
require.NoError(t, err)
|
|
require.Contains(t, job.Error.String, `external auth provider "github" is not configured`)
|
|
})
|
|
|
|
t.Run("TemplateImport_WithGitAuth", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, pd := setup(t, false, &overrides{
|
|
externalAuthConfigs: []*externalauth.Config{{
|
|
ID: "github",
|
|
}},
|
|
})
|
|
jobID := uuid.New()
|
|
versionID := uuid.New()
|
|
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
|
ID: versionID,
|
|
JobID: jobID,
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
require.NoError(t, err)
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
OrganizationID: pd.OrganizationID,
|
|
ID: jobID,
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: pd.OrganizationID,
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: pd.ID,
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
completeJob := func() {
|
|
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
|
JobId: job.ID.String(),
|
|
Type: &proto.CompletedJob_TemplateImport_{
|
|
TemplateImport: &proto.CompletedJob_TemplateImport{
|
|
StartResources: []*sdkproto.Resource{{
|
|
Name: "hello",
|
|
Type: "aws_instance",
|
|
}},
|
|
StopResources: []*sdkproto.Resource{},
|
|
ExternalAuthProviders: []*sdkproto.ExternalAuthProviderResource{{Id: "github"}},
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
completeJob()
|
|
job, err = db.GetProvisionerJobByID(ctx, job.ID)
|
|
require.NoError(t, err)
|
|
require.False(t, job.Error.Valid)
|
|
})
|
|
|
|
t.Run("WorkspaceBuild", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := time.Now()
|
|
|
|
// NOTE: if you're looking for more in-depth deadline/max_deadline
|
|
// calculation testing, see the schedule package. The provsiionerdserver
|
|
// package calls `schedule.CalculateAutostop()` to generate the deadline
|
|
// and max_deadline.
|
|
|
|
// Wednesday the 8th of February 2023 at midnight. This date was
|
|
// specifically chosen as it doesn't fall on a applicable week for both
|
|
// fortnightly and triweekly autostop requirements.
|
|
wednesdayMidnightUTC := time.Date(2023, 2, 8, 0, 0, 0, 0, time.UTC)
|
|
|
|
sydneyQuietHours := "CRON_TZ=Australia/Sydney 0 0 * * *"
|
|
sydneyLoc, err := time.LoadLocation("Australia/Sydney")
|
|
require.NoError(t, err)
|
|
// 12am on Saturday the 11th of February 2023 in Sydney.
|
|
saturdayMidnightSydney := time.Date(2023, 2, 11, 0, 0, 0, 0, sydneyLoc)
|
|
|
|
t.Log("now", now)
|
|
t.Log("wednesdayMidnightUTC", wednesdayMidnightUTC)
|
|
t.Log("saturdayMidnightSydney", saturdayMidnightSydney)
|
|
|
|
cases := []struct {
|
|
name string
|
|
now time.Time
|
|
workspaceTTL time.Duration
|
|
transition database.WorkspaceTransition
|
|
|
|
// These fields are only used when testing max deadline.
|
|
userQuietHoursSchedule string
|
|
templateAutostopRequirement schedule.TemplateAutostopRequirement
|
|
|
|
expectedDeadline time.Time
|
|
expectedMaxDeadline time.Time
|
|
}{
|
|
{
|
|
name: "OK",
|
|
now: now,
|
|
templateAutostopRequirement: schedule.TemplateAutostopRequirement{},
|
|
workspaceTTL: 0,
|
|
transition: database.WorkspaceTransitionStart,
|
|
expectedDeadline: time.Time{},
|
|
expectedMaxDeadline: time.Time{},
|
|
},
|
|
{
|
|
name: "Delete",
|
|
now: now,
|
|
templateAutostopRequirement: schedule.TemplateAutostopRequirement{},
|
|
workspaceTTL: 0,
|
|
transition: database.WorkspaceTransitionDelete,
|
|
expectedDeadline: time.Time{},
|
|
expectedMaxDeadline: time.Time{},
|
|
},
|
|
{
|
|
name: "WorkspaceTTL",
|
|
now: now,
|
|
templateAutostopRequirement: schedule.TemplateAutostopRequirement{},
|
|
workspaceTTL: time.Hour,
|
|
transition: database.WorkspaceTransitionStart,
|
|
expectedDeadline: now.Add(time.Hour),
|
|
expectedMaxDeadline: time.Time{},
|
|
},
|
|
{
|
|
name: "TemplateAutostopRequirement",
|
|
now: wednesdayMidnightUTC,
|
|
userQuietHoursSchedule: sydneyQuietHours,
|
|
templateAutostopRequirement: schedule.TemplateAutostopRequirement{
|
|
DaysOfWeek: 0b00100000, // Saturday
|
|
Weeks: 0, // weekly
|
|
},
|
|
workspaceTTL: 0,
|
|
transition: database.WorkspaceTransitionStart,
|
|
// expectedDeadline is copied from expectedMaxDeadline.
|
|
expectedMaxDeadline: saturdayMidnightSydney.In(time.UTC),
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
c := c
|
|
|
|
t.Run(c.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Simulate the given time starting from now.
|
|
require.False(t, c.now.IsZero())
|
|
start := time.Now()
|
|
tss := &atomic.Pointer[schedule.TemplateScheduleStore]{}
|
|
uqhss := &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{}
|
|
auditor := audit.NewMock()
|
|
srv, db, ps, pd := setup(t, false, &overrides{
|
|
timeNowFn: func() time.Time {
|
|
return c.now.Add(time.Since(start))
|
|
},
|
|
templateScheduleStore: tss,
|
|
userQuietHoursScheduleStore: uqhss,
|
|
auditor: auditor,
|
|
})
|
|
|
|
var templateScheduleStore schedule.TemplateScheduleStore = schedule.MockTemplateScheduleStore{
|
|
GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.TemplateScheduleOptions, error) {
|
|
return schedule.TemplateScheduleOptions{
|
|
UserAutostartEnabled: false,
|
|
UserAutostopEnabled: true,
|
|
DefaultTTL: 0,
|
|
AutostopRequirement: c.templateAutostopRequirement,
|
|
}, nil
|
|
},
|
|
}
|
|
tss.Store(&templateScheduleStore)
|
|
|
|
var userQuietHoursScheduleStore schedule.UserQuietHoursScheduleStore = schedule.MockUserQuietHoursScheduleStore{
|
|
GetFn: func(_ context.Context, _ database.Store, _ uuid.UUID) (schedule.UserQuietHoursScheduleOptions, error) {
|
|
if c.userQuietHoursSchedule == "" {
|
|
return schedule.UserQuietHoursScheduleOptions{
|
|
Schedule: nil,
|
|
}, nil
|
|
}
|
|
|
|
sched, err := cron.Daily(c.userQuietHoursSchedule)
|
|
if !assert.NoError(t, err) {
|
|
return schedule.UserQuietHoursScheduleOptions{}, err
|
|
}
|
|
|
|
return schedule.UserQuietHoursScheduleOptions{
|
|
Schedule: sched,
|
|
UserSet: false,
|
|
}, nil
|
|
},
|
|
}
|
|
uqhss.Store(&userQuietHoursScheduleStore)
|
|
|
|
user := dbgen.User(t, db, database.User{
|
|
QuietHoursSchedule: c.userQuietHoursSchedule,
|
|
})
|
|
template := dbgen.Template(t, db, database.Template{
|
|
Name: "template",
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
|
|
ID: template.ID,
|
|
UpdatedAt: dbtime.Now(),
|
|
AllowUserAutostart: false,
|
|
AllowUserAutostop: true,
|
|
DefaultTTL: 0,
|
|
AutostopRequirementDaysOfWeek: int16(c.templateAutostopRequirement.DaysOfWeek),
|
|
AutostopRequirementWeeks: c.templateAutostopRequirement.Weeks,
|
|
})
|
|
require.NoError(t, err)
|
|
template, err = db.GetTemplateByID(ctx, template.ID)
|
|
require.NoError(t, err)
|
|
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
|
|
workspaceTTL := sql.NullInt64{}
|
|
if c.workspaceTTL != 0 {
|
|
workspaceTTL = sql.NullInt64{
|
|
Int64: int64(c.workspaceTTL),
|
|
Valid: true,
|
|
}
|
|
}
|
|
workspace := dbgen.Workspace(t, db, database.Workspace{
|
|
TemplateID: template.ID,
|
|
Ttl: workspaceTTL,
|
|
OwnerID: user.ID,
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
|
OrganizationID: pd.OrganizationID,
|
|
TemplateID: uuid.NullUUID{
|
|
UUID: template.ID,
|
|
Valid: true,
|
|
},
|
|
JobID: uuid.New(),
|
|
})
|
|
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
|
WorkspaceID: workspace.ID,
|
|
TemplateVersionID: version.ID,
|
|
Transition: c.transition,
|
|
Reason: database.BuildReasonInitiator,
|
|
})
|
|
job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
|
FileID: file.ID,
|
|
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
|
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
|
WorkspaceBuildID: build.ID,
|
|
})),
|
|
OrganizationID: pd.OrganizationID,
|
|
})
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
OrganizationID: pd.OrganizationID,
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: pd.ID,
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
publishedWorkspace := make(chan struct{})
|
|
closeWorkspaceSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), func(_ context.Context, _ []byte) {
|
|
close(publishedWorkspace)
|
|
})
|
|
require.NoError(t, err)
|
|
defer closeWorkspaceSubscribe()
|
|
publishedLogs := make(chan struct{})
|
|
closeLogsSubscribe, err := ps.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
|
|
close(publishedLogs)
|
|
})
|
|
require.NoError(t, err)
|
|
defer closeLogsSubscribe()
|
|
|
|
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
|
JobId: job.ID.String(),
|
|
Type: &proto.CompletedJob_WorkspaceBuild_{
|
|
WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{
|
|
State: []byte{},
|
|
Resources: []*sdkproto.Resource{{
|
|
Name: "example",
|
|
Type: "aws_instance",
|
|
}},
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
<-publishedWorkspace
|
|
<-publishedLogs
|
|
|
|
workspace, err = db.GetWorkspaceByID(ctx, workspace.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, c.transition == database.WorkspaceTransitionDelete, workspace.Deleted)
|
|
|
|
workspaceBuild, err := db.GetWorkspaceBuildByID(ctx, build.ID)
|
|
require.NoError(t, err)
|
|
|
|
// If the max deadline is set, the deadline should also be set.
|
|
// Default to the max deadline if the deadline is not set.
|
|
if c.expectedDeadline.IsZero() {
|
|
c.expectedDeadline = c.expectedMaxDeadline
|
|
}
|
|
|
|
if c.expectedDeadline.IsZero() {
|
|
require.True(t, workspaceBuild.Deadline.IsZero())
|
|
} else {
|
|
require.WithinDuration(t, c.expectedDeadline, workspaceBuild.Deadline, 15*time.Second, "deadline does not match expected")
|
|
}
|
|
if c.expectedMaxDeadline.IsZero() {
|
|
require.True(t, workspaceBuild.MaxDeadline.IsZero())
|
|
} else {
|
|
require.WithinDuration(t, c.expectedMaxDeadline, workspaceBuild.MaxDeadline, 15*time.Second, "max deadline does not match expected")
|
|
require.GreaterOrEqual(t, workspaceBuild.MaxDeadline.Unix(), workspaceBuild.Deadline.Unix(), "max deadline is smaller than deadline")
|
|
}
|
|
|
|
require.Len(t, auditor.AuditLogs(), 1)
|
|
var additionalFields audit.AdditionalFields
|
|
err = json.Unmarshal(auditor.AuditLogs()[0].AdditionalFields, &additionalFields)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspace.ID, additionalFields.WorkspaceID)
|
|
})
|
|
}
|
|
})
|
|
t.Run("TemplateDryRun", func(t *testing.T) {
|
|
t.Parallel()
|
|
srv, db, _, pd := setup(t, false, &overrides{})
|
|
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
|
ID: uuid.New(),
|
|
Provisioner: database.ProvisionerTypeEcho,
|
|
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
|
|
StorageMethod: database.ProvisionerStorageMethodFile,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
|
WorkerID: uuid.NullUUID{
|
|
UUID: pd.ID,
|
|
Valid: true,
|
|
},
|
|
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
|
JobId: job.ID.String(),
|
|
Type: &proto.CompletedJob_TemplateDryRun_{
|
|
TemplateDryRun: &proto.CompletedJob_TemplateDryRun{
|
|
Resources: []*sdkproto.Resource{{
|
|
Name: "something",
|
|
Type: "aws_instance",
|
|
}},
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestInsertWorkspaceResource(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
insert := func(db database.Store, jobID uuid.UUID, resource *sdkproto.Resource) error {
|
|
return provisionerdserver.InsertWorkspaceResource(ctx, db, jobID, database.WorkspaceTransitionStart, resource, &telemetry.Snapshot{})
|
|
}
|
|
t.Run("NoAgents", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := dbmem.New()
|
|
job := uuid.New()
|
|
err := insert(db, job, &sdkproto.Resource{
|
|
Name: "something",
|
|
Type: "aws_instance",
|
|
})
|
|
require.NoError(t, err)
|
|
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
|
|
require.NoError(t, err)
|
|
require.Len(t, resources, 1)
|
|
})
|
|
t.Run("InvalidAgentToken", func(t *testing.T) {
|
|
t.Parallel()
|
|
err := insert(dbmem.New(), uuid.New(), &sdkproto.Resource{
|
|
Name: "something",
|
|
Type: "aws_instance",
|
|
Agents: []*sdkproto.Agent{{
|
|
Auth: &sdkproto.Agent_Token{
|
|
Token: "bananas",
|
|
},
|
|
}},
|
|
})
|
|
require.ErrorContains(t, err, "invalid UUID length")
|
|
})
|
|
t.Run("DuplicateApps", func(t *testing.T) {
|
|
t.Parallel()
|
|
err := insert(dbmem.New(), uuid.New(), &sdkproto.Resource{
|
|
Name: "something",
|
|
Type: "aws_instance",
|
|
Agents: []*sdkproto.Agent{{
|
|
Apps: []*sdkproto.App{{
|
|
Slug: "a",
|
|
}, {
|
|
Slug: "a",
|
|
}},
|
|
}},
|
|
})
|
|
require.ErrorContains(t, err, "duplicate app slug")
|
|
})
|
|
t.Run("Success", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := dbmem.New()
|
|
job := uuid.New()
|
|
err := insert(db, job, &sdkproto.Resource{
|
|
Name: "something",
|
|
Type: "aws_instance",
|
|
DailyCost: 10,
|
|
Agents: []*sdkproto.Agent{{
|
|
Name: "dev",
|
|
Env: map[string]string{
|
|
"something": "test",
|
|
},
|
|
OperatingSystem: "linux",
|
|
Architecture: "amd64",
|
|
Auth: &sdkproto.Agent_Token{
|
|
Token: uuid.NewString(),
|
|
},
|
|
Apps: []*sdkproto.App{{
|
|
Slug: "a",
|
|
}},
|
|
ExtraEnvs: []*sdkproto.Env{
|
|
{
|
|
Name: "something", // Duplicate, already set by Env.
|
|
Value: "I should be discarded!",
|
|
},
|
|
{
|
|
Name: "else",
|
|
Value: "I laugh in the face of danger.",
|
|
},
|
|
},
|
|
Scripts: []*sdkproto.Script{{
|
|
DisplayName: "Startup",
|
|
Icon: "/test.png",
|
|
}},
|
|
DisplayApps: &sdkproto.DisplayApps{
|
|
Vscode: true,
|
|
PortForwardingHelper: true,
|
|
SshHelper: true,
|
|
},
|
|
}},
|
|
})
|
|
require.NoError(t, err)
|
|
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
|
|
require.NoError(t, err)
|
|
require.Len(t, resources, 1)
|
|
require.EqualValues(t, 10, resources[0].DailyCost)
|
|
agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 1)
|
|
agent := agents[0]
|
|
require.Equal(t, "amd64", agent.Architecture)
|
|
require.Equal(t, "linux", agent.OperatingSystem)
|
|
want, err := json.Marshal(map[string]string{
|
|
"something": "test",
|
|
"else": "I laugh in the face of danger.",
|
|
})
|
|
require.NoError(t, err)
|
|
got, err := agent.EnvironmentVariables.RawMessage.MarshalJSON()
|
|
require.NoError(t, err)
|
|
require.Equal(t, want, got)
|
|
require.ElementsMatch(t, []database.DisplayApp{
|
|
database.DisplayAppPortForwardingHelper,
|
|
database.DisplayAppSSHHelper,
|
|
database.DisplayAppVscode,
|
|
}, agent.DisplayApps)
|
|
})
|
|
|
|
t.Run("AllDisplayApps", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := dbmem.New()
|
|
job := uuid.New()
|
|
err := insert(db, job, &sdkproto.Resource{
|
|
Name: "something",
|
|
Type: "aws_instance",
|
|
Agents: []*sdkproto.Agent{{
|
|
DisplayApps: &sdkproto.DisplayApps{
|
|
Vscode: true,
|
|
VscodeInsiders: true,
|
|
SshHelper: true,
|
|
PortForwardingHelper: true,
|
|
WebTerminal: true,
|
|
},
|
|
}},
|
|
})
|
|
require.NoError(t, err)
|
|
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
|
|
require.NoError(t, err)
|
|
require.Len(t, resources, 1)
|
|
agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 1)
|
|
agent := agents[0]
|
|
require.ElementsMatch(t, database.AllDisplayAppValues(), agent.DisplayApps)
|
|
})
|
|
|
|
t.Run("DisableDefaultApps", func(t *testing.T) {
|
|
t.Parallel()
|
|
db := dbmem.New()
|
|
job := uuid.New()
|
|
err := insert(db, job, &sdkproto.Resource{
|
|
Name: "something",
|
|
Type: "aws_instance",
|
|
Agents: []*sdkproto.Agent{{
|
|
DisplayApps: &sdkproto.DisplayApps{},
|
|
}},
|
|
})
|
|
require.NoError(t, err)
|
|
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job)
|
|
require.NoError(t, err)
|
|
require.Len(t, resources, 1)
|
|
agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID})
|
|
require.NoError(t, err)
|
|
require.Len(t, agents, 1)
|
|
agent := agents[0]
|
|
// An empty array (as opposed to nil) should be returned to indicate
|
|
// that all apps are disabled.
|
|
require.Equal(t, []database.DisplayApp{}, agent.DisplayApps)
|
|
})
|
|
}
|
|
|
|
type overrides struct {
|
|
ctx context.Context
|
|
deploymentValues *codersdk.DeploymentValues
|
|
externalAuthConfigs []*externalauth.Config
|
|
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
|
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
|
|
timeNowFn func() time.Time
|
|
acquireJobLongPollDuration time.Duration
|
|
heartbeatFn func(ctx context.Context) error
|
|
heartbeatInterval time.Duration
|
|
auditor audit.Auditor
|
|
}
|
|
|
|
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub, database.ProvisionerDaemon) {
|
|
t.Helper()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
db := dbmem.New()
|
|
ps := pubsub.NewInMemory()
|
|
defOrg, err := db.GetDefaultOrganization(context.Background())
|
|
require.NoError(t, err, "default org not found")
|
|
|
|
deploymentValues := &codersdk.DeploymentValues{}
|
|
var externalAuthConfigs []*externalauth.Config
|
|
tss := testTemplateScheduleStore()
|
|
uqhss := testUserQuietHoursScheduleStore()
|
|
var timeNowFn func() time.Time
|
|
pollDur := time.Duration(0)
|
|
if ov == nil {
|
|
ov = &overrides{}
|
|
}
|
|
if ov.ctx == nil {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
t.Cleanup(cancel)
|
|
ov.ctx = ctx
|
|
}
|
|
if ov.heartbeatInterval == 0 {
|
|
ov.heartbeatInterval = testutil.IntervalMedium
|
|
}
|
|
if ov.deploymentValues != nil {
|
|
deploymentValues = ov.deploymentValues
|
|
}
|
|
if ov.externalAuthConfigs != nil {
|
|
externalAuthConfigs = ov.externalAuthConfigs
|
|
}
|
|
if ov.templateScheduleStore != nil {
|
|
ttss := tss.Load()
|
|
// keep the initial test value if the override hasn't set the atomic pointer.
|
|
tss = ov.templateScheduleStore
|
|
if tss.Load() == nil {
|
|
swapped := tss.CompareAndSwap(nil, ttss)
|
|
require.True(t, swapped)
|
|
}
|
|
}
|
|
if ov.userQuietHoursScheduleStore != nil {
|
|
tuqhss := uqhss.Load()
|
|
// keep the initial test value if the override hasn't set the atomic pointer.
|
|
uqhss = ov.userQuietHoursScheduleStore
|
|
if uqhss.Load() == nil {
|
|
swapped := uqhss.CompareAndSwap(nil, tuqhss)
|
|
require.True(t, swapped)
|
|
}
|
|
}
|
|
if ov.timeNowFn != nil {
|
|
timeNowFn = ov.timeNowFn
|
|
}
|
|
auditPtr := &atomic.Pointer[audit.Auditor]{}
|
|
var auditor audit.Auditor = audit.NewMock()
|
|
if ov.auditor != nil {
|
|
auditor = ov.auditor
|
|
}
|
|
auditPtr.Store(&auditor)
|
|
pollDur = ov.acquireJobLongPollDuration
|
|
|
|
daemon, err := db.UpsertProvisionerDaemon(ov.ctx, database.UpsertProvisionerDaemonParams{
|
|
Name: "test",
|
|
CreatedAt: dbtime.Now(),
|
|
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
Tags: database.StringMap{},
|
|
LastSeenAt: sql.NullTime{},
|
|
Version: buildinfo.Version(),
|
|
APIVersion: proto.CurrentVersion.String(),
|
|
OrganizationID: defOrg.ID,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
srv, err := provisionerdserver.NewServer(
|
|
ov.ctx,
|
|
&url.URL{},
|
|
daemon.ID,
|
|
defOrg.ID,
|
|
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
|
|
[]database.ProvisionerType{database.ProvisionerTypeEcho},
|
|
provisionerdserver.Tags(daemon.Tags),
|
|
db,
|
|
ps,
|
|
provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps),
|
|
telemetry.NewNoop(),
|
|
trace.NewNoopTracerProvider().Tracer("noop"),
|
|
&atomic.Pointer[proto.QuotaCommitter]{},
|
|
auditPtr,
|
|
tss,
|
|
uqhss,
|
|
deploymentValues,
|
|
provisionerdserver.Options{
|
|
ExternalAuthConfigs: externalAuthConfigs,
|
|
TimeNowFn: timeNowFn,
|
|
OIDCConfig: &oauth2.Config{},
|
|
AcquireJobLongPollDur: pollDur,
|
|
HeartbeatInterval: ov.heartbeatInterval,
|
|
HeartbeatFn: ov.heartbeatFn,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
return srv, db, ps, daemon
|
|
}
|
|
|
|
// must unwraps a (value, error) pair: it returns value when err is nil and
// panics with err otherwise. Intended for test fixtures that cannot fail.
func must[T any](value T, err error) T {
	if err == nil {
		return value
	}
	panic(err)
}
|
|
|
|
var (
|
|
errUnimplemented = xerrors.New("unimplemented")
|
|
errClosed = xerrors.New("closed")
|
|
)
|
|
|
|
type fakeStream struct {
|
|
ctx context.Context
|
|
c *sync.Cond
|
|
closed bool
|
|
canceled bool
|
|
sendCalled bool
|
|
job *proto.AcquiredJob
|
|
}
|
|
|
|
func newFakeStream(ctx context.Context) *fakeStream {
|
|
return &fakeStream{
|
|
ctx: ctx,
|
|
c: sync.NewCond(&sync.Mutex{}),
|
|
}
|
|
}
|
|
|
|
func (s *fakeStream) Send(j *proto.AcquiredJob) error {
|
|
s.c.L.Lock()
|
|
defer s.c.L.Unlock()
|
|
s.sendCalled = true
|
|
s.job = j
|
|
s.c.Broadcast()
|
|
return nil
|
|
}
|
|
|
|
func (s *fakeStream) Recv() (*proto.CancelAcquire, error) {
|
|
s.c.L.Lock()
|
|
defer s.c.L.Unlock()
|
|
for !(s.canceled || s.closed) {
|
|
s.c.Wait()
|
|
}
|
|
if s.canceled {
|
|
return &proto.CancelAcquire{}, nil
|
|
}
|
|
return nil, io.EOF
|
|
}
|
|
|
|
// Context returns the context associated with the stream. It is canceled
|
|
// when the Stream is closed and no more messages will ever be sent or
|
|
// received on it.
|
|
func (s *fakeStream) Context() context.Context {
|
|
return s.ctx
|
|
}
|
|
|
|
// MsgSend sends the Message to the remote.
|
|
func (*fakeStream) MsgSend(drpc.Message, drpc.Encoding) error {
|
|
return errUnimplemented
|
|
}
|
|
|
|
// MsgRecv receives a Message from the remote.
|
|
func (*fakeStream) MsgRecv(drpc.Message, drpc.Encoding) error {
|
|
return errUnimplemented
|
|
}
|
|
|
|
// CloseSend signals to the remote that we will no longer send any messages.
|
|
func (*fakeStream) CloseSend() error {
|
|
return errUnimplemented
|
|
}
|
|
|
|
// Close closes the stream.
|
|
func (s *fakeStream) Close() error {
|
|
s.c.L.Lock()
|
|
defer s.c.L.Unlock()
|
|
s.closed = true
|
|
s.c.Broadcast()
|
|
return nil
|
|
}
|
|
|
|
func (s *fakeStream) waitForJob() (*proto.AcquiredJob, error) {
|
|
s.c.L.Lock()
|
|
defer s.c.L.Unlock()
|
|
for !(s.sendCalled || s.closed) {
|
|
s.c.Wait()
|
|
}
|
|
if s.sendCalled {
|
|
return s.job, nil
|
|
}
|
|
return nil, errClosed
|
|
}
|
|
|
|
func (s *fakeStream) cancel() {
|
|
s.c.L.Lock()
|
|
defer s.c.L.Unlock()
|
|
s.canceled = true
|
|
s.c.Broadcast()
|
|
}
|