coder/provisionerd/provisionerd_test.go

1167 lines
37 KiB
Go

package provisionerd_test
import (
"archive/tar"
"bytes"
"context"
"fmt"
"io"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/hashicorp/yamux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"golang.org/x/xerrors"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/provisionerd"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionerd/runner"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func closedWithin(c chan struct{}, d time.Duration) func() bool {
return func() bool {
select {
case <-c:
return true
case <-time.After(d):
return false
}
}
}
func TestProvisionerd(t *testing.T) {
t.Parallel()
noopUpdateJob := func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return &proto.UpdateJobResponse{}, nil
}
t.Run("InstantClose", func(t *testing.T) {
t.Parallel()
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{}), nil
}, provisionerd.Provisioners{})
require.NoError(t, closer.Close())
})
t.Run("ConnectErrorClose", func(t *testing.T) {
t.Parallel()
completeChan := make(chan struct{})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
defer close(completeChan)
return nil, xerrors.New("an error")
}, provisionerd.Provisioners{})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("AcquireEmptyJob", func(t *testing.T) {
// The provisioner daemon is supposed to skip the job acquire if
// the job provided is empty. This is to show it successfully
// tried to get a job, but none were available.
t.Parallel()
completeChan := make(chan struct{})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
acquireJobAttempt := 0
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if acquireJobAttempt == 1 {
close(completeChan)
}
acquireJobAttempt++
return &proto.AcquiredJob{}, nil
},
updateJob: noopUpdateJob,
}), nil
}, provisionerd.Provisioners{})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("CloseCancelsJob", func(t *testing.T) {
t.Parallel()
completeChan := make(chan struct{})
var completed sync.Once
var closer io.Closer
var closerMutex sync.Mutex
closerMutex.Lock()
closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: noopUpdateJob,
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
completed.Do(func() {
close(completeChan)
})
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
closerMutex.Lock()
defer closerMutex.Unlock()
return closer.Close()
},
}),
})
closerMutex.Unlock()
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("MaliciousTar", func(t *testing.T) {
// Ensures tars with "../../../etc/passwd" as the path
// are not allowed to run, and will fail the job.
t.Parallel()
var (
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"../../../etc/passwd": "content",
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: noopUpdateJob,
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
completeOnce.Do(func() { close(completeChan) })
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
t.Parallel()
var (
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
completeOnce.Do(func() { close(completeChan) })
return &proto.UpdateJobResponse{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
<-stream.Context().Done()
return nil
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("TemplateImport", func(t *testing.T) {
t.Parallel()
var (
didComplete atomic.Bool
didLog atomic.Bool
didAcquireJob atomic.Bool
didDryRun = atomic.NewBool(true)
didReadme atomic.Bool
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
runner.ReadmeFile: "# A cool template 😎\n",
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) > 0 {
didLog.Store(true)
}
if len(update.Readme) > 0 {
didReadme.Store(true)
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
didComplete.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt"))
require.NoError(t, err)
require.Equal(t, "content", string(data))
err = stream.Send(&sdkproto.Parse_Response{
Type: &sdkproto.Parse_Response_Log{
Log: &sdkproto.Log{
Level: sdkproto.LogLevel_INFO,
Output: "hello",
},
},
})
require.NoError(t, err)
err = stream.Send(&sdkproto.Parse_Response{
Type: &sdkproto.Parse_Response_Complete{
Complete: &sdkproto.Parse_Complete{},
},
})
require.NoError(t, err)
return nil
},
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
request, err := stream.Recv()
require.NoError(t, err)
if request.GetApply() != nil {
didDryRun.Store(false)
}
err = stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
Log: &sdkproto.Log{
Level: sdkproto.LogLevel_INFO,
Output: "hello",
},
},
})
require.NoError(t, err)
err = stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{
Resources: []*sdkproto.Resource{},
},
},
})
require.NoError(t, err)
return nil
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didLog.Load(), "should log some updates")
assert.True(t, didComplete.Load(), "should complete the job")
assert.True(t, didDryRun.Load(), "should be a dry run")
})
t.Run("TemplateDryRun", func(t *testing.T) {
t.Parallel()
var (
didComplete atomic.Bool
didLog atomic.Bool
didAcquireJob atomic.Bool
completeChan = make(chan struct{})
completeOnce sync.Once
parameterValues = []*sdkproto.ParameterValue{
{
DestinationScheme: sdkproto.ParameterDestination_PROVISIONER_VARIABLE,
Name: "test_var",
Value: "dean was here",
},
{
DestinationScheme: sdkproto.ParameterDestination_PROVISIONER_VARIABLE,
Name: "test_var_2",
Value: "1234",
},
}
metadata = &sdkproto.Provision_Metadata{}
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_TemplateDryRun_{
TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{
ParameterValues: parameterValues,
Metadata: metadata,
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) == 0 {
t.Log("provisionerDaemonTestServer: no log messages")
return &proto.UpdateJobResponse{}, nil
}
didLog.Store(true)
for _, msg := range update.Logs {
t.Log("provisionerDaemonTestServer", "msg:", msg)
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
didComplete.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{
Resources: []*sdkproto.Resource{},
},
},
})
require.NoError(t, err)
return nil
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didLog.Load(), "should log some updates")
assert.True(t, didComplete.Load(), "should complete the job")
})
t.Run("WorkspaceBuild", func(t *testing.T) {
t.Parallel()
var (
didComplete atomic.Bool
didLog atomic.Bool
didAcquireJob atomic.Bool
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) != 0 {
didLog.Store(true)
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
didComplete.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
Log: &sdkproto.Log{
Level: sdkproto.LogLevel_DEBUG,
Output: "wow",
},
},
})
require.NoError(t, err)
err = stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{},
},
})
require.NoError(t, err)
return nil
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didLog.Load(), "should log some updates")
assert.True(t, didComplete.Load(), "should complete the job")
})
t.Run("WorkspaceBuildQuotaExceeded", func(t *testing.T) {
t.Parallel()
var (
didComplete atomic.Bool
didLog atomic.Bool
didAcquireJob atomic.Bool
didFail atomic.Bool
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) != 0 {
didLog.Store(true)
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
didComplete.Store(true)
return &proto.Empty{}, nil
},
commitQuota: func(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
return &proto.CommitQuotaResponse{
Ok: com.DailyCost < 20,
}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
didFail.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
Log: &sdkproto.Log{
Level: sdkproto.LogLevel_DEBUG,
Output: "wow",
},
},
})
require.NoError(t, err)
err = stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{
Resources: []*sdkproto.Resource{
{
DailyCost: 10,
},
{
DailyCost: 15,
},
},
},
},
})
require.NoError(t, err)
return nil
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didLog.Load(), "should log some updates")
assert.False(t, didComplete.Load(), "should complete the job")
assert.True(t, didFail.Load(), "should fail the job")
})
t.Run("WorkspaceBuildFailComplete", func(t *testing.T) {
t.Parallel()
var (
didFail atomic.Bool
didAcquireJob atomic.Bool
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: noopUpdateJob,
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
didFail.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
return stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{
Error: "some error",
},
},
})
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didFail.Load(), "should fail the job")
})
t.Run("Shutdown", func(t *testing.T) {
t.Parallel()
var updated sync.Once
var completed sync.Once
updateChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) > 0 {
for _, log := range update.Logs {
if log.Source != proto.LogSource_PROVISIONER {
continue
}
// Close on a log so we know when the job is in progress!
updated.Do(func() {
close(updateChan)
})
break
}
}
return &proto.UpdateJobResponse{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
completed.Do(func() {
close(completeChan)
})
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
Log: &sdkproto.Log{
Level: sdkproto.LogLevel_DEBUG,
Output: "in progress",
},
},
})
require.NoError(t, err)
msg, err := stream.Recv()
require.NoError(t, err)
require.NotNil(t, msg.GetCancel())
return stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{
Error: "some error",
},
},
})
},
}),
})
require.Condition(t, closedWithin(updateChan, testutil.WaitShort))
err := server.Shutdown(context.Background())
require.NoError(t, err)
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, server.Close())
})
t.Run("ShutdownFromJob", func(t *testing.T) {
t.Parallel()
var completed sync.Once
var updated sync.Once
updateChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
resp := &proto.UpdateJobResponse{}
if len(update.Logs) > 0 {
for _, log := range update.Logs {
if log.Source != proto.LogSource_PROVISIONER {
continue
}
// Close on a log so we know when the job is in progress!
updated.Do(func() {
close(updateChan)
})
break
}
}
// start returning Canceled once we've gotten at least one log.
select {
case <-updateChan:
resp.Canceled = true
default:
// pass
}
return resp, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
completed.Do(func() {
close(completeChan)
})
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
Log: &sdkproto.Log{
Level: sdkproto.LogLevel_DEBUG,
Output: "in progress",
},
},
})
require.NoError(t, err)
msg, err := stream.Recv()
require.NoError(t, err)
require.NotNil(t, msg.GetCancel())
return stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{
Error: "some error",
},
},
})
},
}),
})
require.Condition(t, closedWithin(updateChan, testutil.WaitShort))
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, server.Close())
})
t.Run("ReconnectAndFail", func(t *testing.T) {
t.Parallel()
var (
second atomic.Bool
failChan = make(chan struct{})
failOnce sync.Once
failedChan = make(chan struct{})
failedOnce sync.Once
completeChan = make(chan struct{})
completeOnce sync.Once
)
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if second.Load() {
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return &proto.UpdateJobResponse{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
assert.Equal(t, job.JobId, "test")
if second.Load() {
completeOnce.Do(func() { close(completeChan) })
return &proto.Empty{}, nil
}
failOnce.Do(func() { close(failChan) })
<-failedChan
return &proto.Empty{}, nil
},
})
if !second.Load() {
go func() {
<-failChan
_ = client.DRPCConn().Close()
second.Store(true)
time.Sleep(50 * time.Millisecond)
failedOnce.Do(func() { close(failedChan) })
}()
}
return client, nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
return stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{
Error: "some error",
},
},
})
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, server.Close())
})
t.Run("ReconnectAndComplete", func(t *testing.T) {
t.Parallel()
var (
second atomic.Bool
failChan = make(chan struct{})
failOnce sync.Once
failedChan = make(chan struct{})
failedOnce sync.Once
completeChan = make(chan struct{})
completeOnce sync.Once
)
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if second.Load() {
completeOnce.Do(func() { close(completeChan) })
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
return nil, yamux.ErrSessionShutdown
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
if second.Load() {
return &proto.Empty{}, nil
}
failOnce.Do(func() { close(failChan) })
<-failedChan
return &proto.Empty{}, nil
},
})
if !second.Load() {
go func() {
<-failChan
_ = client.DRPCConn().Close()
second.Store(true)
time.Sleep(50 * time.Millisecond)
failedOnce.Do(func() { close(failedChan) })
}()
}
return client, nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
return stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{},
},
})
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, server.Close())
})
t.Run("UpdatesBeforeComplete", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
m := sync.Mutex{}
var ops []string
completeChan := make(chan struct{})
completeOnce := sync.Once{}
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
m.Lock()
defer m.Unlock()
logger.Info(ctx, "AcquiredJob called.")
if len(ops) > 0 {
return &proto.AcquiredJob{}, nil
}
ops = append(ops, "AcquireJob")
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
m.Lock()
defer m.Unlock()
logger.Info(ctx, "UpdateJob called.")
ops = append(ops, "UpdateJob")
for _, log := range update.Logs {
ops = append(ops, fmt.Sprintf("Log: %s | %s", log.Stage, log.Output))
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
m.Lock()
defer m.Unlock()
logger.Info(ctx, "CompleteJob called.")
ops = append(ops, "CompleteJob")
completeOnce.Do(func() { close(completeChan) })
return &proto.Empty{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
m.Lock()
defer m.Unlock()
logger.Info(ctx, "FailJob called.")
ops = append(ops, "FailJob")
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
Log: &sdkproto.Log{
Level: sdkproto.LogLevel_DEBUG,
Output: "wow",
},
},
})
require.NoError(t, err)
err = stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{},
},
})
require.NoError(t, err)
return nil
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, server.Close())
assert.Equal(t, ops[len(ops)-1], "CompleteJob")
assert.Contains(t, ops[0:len(ops)-1], "Log: Cleaning Up | ")
})
}
// Creates an in-memory tar of the files provided.
func createTar(t *testing.T, files map[string]string) []byte {
var buffer bytes.Buffer
writer := tar.NewWriter(&buffer)
for path, content := range files {
err := writer.WriteHeader(&tar.Header{
Name: path,
Size: int64(len(content)),
})
require.NoError(t, err)
_, err = writer.Write([]byte(content))
require.NoError(t, err)
}
err := writer.Flush()
require.NoError(t, err)
return buffer.Bytes()
}
// Creates a provisionerd implementation with the provided dialer and provisioners.
func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners provisionerd.Provisioners) *provisionerd.Server {
server := provisionerd.New(dialer, &provisionerd.Options{
Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug),
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 50 * time.Millisecond,
Provisioners: provisioners,
WorkDirectory: t.TempDir(),
})
t.Cleanup(func() {
_ = server.Close()
})
return server
}
// Creates a provisionerd protobuf client that's connected
// to the server implementation provided.
func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient {
t.Helper()
if server.failJob == nil {
// Default to asserting the error from the failure, otherwise
// it can be lost in tests!
server.failJob = func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
assert.Fail(t, job.Error)
return &proto.Empty{}, nil
}
}
clientPipe, serverPipe := provisionersdk.MemTransportPipe()
t.Cleanup(func() {
_ = clientPipe.Close()
_ = serverPipe.Close()
})
mux := drpcmux.New()
err := proto.DRPCRegisterProvisionerDaemon(mux, &server)
require.NoError(t, err)
srv := drpcserver.New(mux)
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(cancelFunc)
go func() {
_ = srv.Serve(ctx, serverPipe)
}()
return proto.NewDRPCProvisionerDaemonClient(clientPipe)
}
// Creates a provisioner protobuf client that's connected
// to the server implementation provided.
func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
t.Helper()
clientPipe, serverPipe := provisionersdk.MemTransportPipe()
t.Cleanup(func() {
_ = clientPipe.Close()
_ = serverPipe.Close()
})
mux := drpcmux.New()
err := sdkproto.DRPCRegisterProvisioner(mux, &server)
require.NoError(t, err)
srv := drpcserver.New(mux)
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(cancelFunc)
go func() {
_ = srv.Serve(ctx, serverPipe)
}()
return sdkproto.NewDRPCProvisionerClient(clientPipe)
}
type provisionerTestServer struct {
parse func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error
provision func(stream sdkproto.DRPCProvisioner_ProvisionStream) error
}
func (p *provisionerTestServer) Parse(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
return p.parse(request, stream)
}
func (p *provisionerTestServer) Provision(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
return p.provision(stream)
}
// Fulfills the protobuf interface for a ProvisionerDaemon with
// passable functions for dynamic functionality.
type provisionerDaemonTestServer struct {
acquireJob func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error)
commitQuota func(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error)
updateJob func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error)
failJob func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error)
completeJob func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error)
}
func (p *provisionerDaemonTestServer) AcquireJob(ctx context.Context, empty *proto.Empty) (*proto.AcquiredJob, error) {
return p.acquireJob(ctx, empty)
}
func (p *provisionerDaemonTestServer) CommitQuota(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
if p.commitQuota == nil {
return &proto.CommitQuotaResponse{
Ok: true,
}, nil
}
return p.commitQuota(ctx, com)
}
func (p *provisionerDaemonTestServer) UpdateJob(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return p.updateJob(ctx, update)
}
func (p *provisionerDaemonTestServer) FailJob(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
return p.failJob(ctx, job)
}
func (p *provisionerDaemonTestServer) CompleteJob(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
return p.completeJob(ctx, job)
}