fix: Add resiliency to daemon connections (#1116)

Connections could fail when massive payloads were transmitted. This fixes
an upstream bug in dRPC where the connection would end with a
context-canceled error if a message was too large.
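
For illustration, a minimal standalone sketch of the size guard this change introduces. The real check lives in coderd's AcquireJob below and compares against the new provisionersdk.MaxMessageSize (4 MiB); checkSize, maxMessageSize, and the wrapperspb demo payload here are made up for the example and are not part of the change.

package main

import (
	"fmt"

	protobuf "google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/wrapperspb"
)

// maxMessageSize mirrors provisionersdk.MaxMessageSize (4 MiB).
const maxMessageSize = 4 << 20

// checkSize fails fast on oversized payloads instead of letting dRPC
// tear the connection down with a context-canceled error.
func checkSize(msg protobuf.Message) error {
	if size := protobuf.Size(msg); size > maxMessageSize {
		return fmt.Errorf("payload was too big: %d > %d", size, maxMessageSize)
	}
	return nil
}

func main() {
	// A 5 MiB payload trips the guard; a small one passes.
	fmt.Println(checkSize(wrapperspb.Bytes(make([]byte, 5<<20)))) // payload was too big: ...
	fmt.Println(checkSize(wrapperspb.Bytes([]byte("ok"))))        // <nil>
}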

This also adds retransmission of job completions and failures. If
Coder somehow loses its connection to a provisioner daemon, the job
state is reported properly on the next connection.
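
The retransmission boils down to a retry loop around the dRPC call: back off while the daemon is reconnecting, treat a dead session (yamux shutdown or EOF) as "resend on the next connection", and only give up on other errors. A condensed sketch of the pattern used by completeJob and failActiveJob below; connected and send stand in for the daemon's client() accessor and its CompleteJob/FailJob calls, and the retry package is assumed to be github.com/coder/retry, which the daemon's dial loop already uses.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/coder/retry"
	"github.com/hashicorp/yamux"
)

// sendWithRetry keeps retrying send with exponential backoff until it either
// succeeds, fails with a non-connection error, or ctx is canceled.
func sendWithRetry(ctx context.Context, connected func() bool, send func() error) error {
	var lastErr error
	for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); {
		if !connected() {
			// Not connected yet; back off until the dialer
			// re-establishes a connection to coderd.
			continue
		}
		lastErr = send()
		if errors.Is(lastErr, yamux.ErrSessionShutdown) || errors.Is(lastErr, io.EOF) {
			// The session died mid-send; resend on the next connection.
			continue
		}
		return lastErr
	}
	return lastErr
}

func main() {
	attempts := 0
	err := sendWithRetry(context.Background(),
		func() bool { return true },
		func() error {
			attempts++
			if attempts < 3 {
				return io.EOF // pretend the session dropped twice
			}
			return nil
		})
	fmt.Println(attempts, err) // 3 <nil>
}
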
Kyle Carberry, 2022-04-24 20:33:19 -05:00 (committed via GitHub)
commit db7ed4d019 (parent be974cf280)
8 changed files with 351 additions and 65 deletions


@ -17,6 +17,7 @@ import (
"github.com/moby/moby/pkg/namesgenerator"
"github.com/tabbed/pqtype"
"golang.org/x/xerrors"
protobuf "google.golang.org/protobuf/proto"
"nhooyr.io/websocket"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
@ -27,6 +28,7 @@ import (
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
)
@ -47,6 +49,8 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request
})
return
}
// Align with the frame size of yamux.
conn.SetReadLimit(256 * 1024)
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
ID: uuid.New(),
@ -82,9 +86,17 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))
return
}
server := drpcserver.New(mux)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err))
},
})
err = server.Serve(r.Context(), session)
if err != nil {
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return
}
@ -253,6 +265,9 @@ func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty
default:
return nil, failJob(fmt.Sprintf("unsupported storage method: %s", job.StorageMethod))
}
if protobuf.Size(protoJob) > provisionersdk.MaxMessageSize {
return nil, failJob(fmt.Sprintf("payload was too big: %d > %d", protobuf.Size(protoJob), provisionersdk.MaxMessageSize))
}
return protoJob, err
}


@ -0,0 +1,48 @@
package coderd_test
import (
"context"
"crypto/rand"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
)
func TestProvisionerDaemons(t *testing.T) {
t.Parallel()
t.Run("PayloadTooBig", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// Takes too long to allocate memory on Windows!
t.Skip()
}
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdtest.NewProvisionerDaemon(t, client)
data := make([]byte, provisionersdk.MaxMessageSize)
rand.Read(data)
resp, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data)
require.NoError(t, err)
t.Log(resp.Hash)
version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{
StorageMethod: database.ProvisionerStorageMethodFile,
StorageSource: resp.Hash,
Provisioner: database.ProvisionerTypeEcho,
})
require.NoError(t, err)
require.Eventually(t, func() bool {
var err error
version, err = client.TemplateVersion(context.Background(), version.ID)
require.NoError(t, err)
return version.Job.Error != ""
}, 5*time.Second, 25*time.Millisecond)
})
}


@ -70,8 +70,8 @@ func (c *Client) ListenProvisionerDaemon(ctx context.Context) (proto.DRPCProvisi
}
return nil, readBodyAsError(res)
}
// Allow _somewhat_ large payloads.
conn.SetReadLimit((1 << 20) * 2)
// Align with the frame size of yamux.
conn.SetReadLimit(256 * 1024)
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
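
For reference, a quick standalone check (not from the diff) of where 256 * 1024 comes from: it equals yamux's default stream window, so the websocket read limit and yamux's framing agree on the largest frame either side emits. This assumes hashicorp/yamux's DefaultConfig, which this code already uses.

package main

import (
	"fmt"

	"github.com/hashicorp/yamux"
)

func main() {
	// yamux's default configuration uses a 256 KiB stream window, which is
	// why the websocket read limit above is set to 256 * 1024.
	cfg := yamux.DefaultConfig()
	fmt.Println(cfg.MaxStreamWindowSize == 256*1024) // true
}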

go.mod

@ -17,6 +17,9 @@ replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220
// Required until https://github.com/briandowns/spinner/pull/136 is merged.
replace github.com/briandowns/spinner => github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e
// Required until https://github.com/storj/drpc/pull/31 is merged.
replace storj.io/drpc => github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff
// opencensus-go leaks a goroutine by default.
replace go.opencensus.io => github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b

go.sum

@ -1107,6 +1107,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/ktrysmt/go-bitbucket v0.6.4/go.mod h1:9u0v3hsd2rqCHRIpbir1oP7F58uo5dq19sBYvuMoyQ4=
github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff h1:7qg425aXdULnZWCCQNPOzHO7c+M6BpbTfOUJLrk5+3w=
github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff/go.mod h1:6rcOyR/QQkSTX/9L5ZGtlZaE2PtXTTZl8d+ulSeeYEg=
github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b h1:1Y1X6aR78kMEQE1iCjQodB3lA7VO4jB88Wf8ZrzXSsA=
github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8/go.mod h1:n/KX1BZoN1m9EwoXkn/xAV4fd3k8c++gGBsgLONaPOY=
@ -2544,5 +2546,3 @@ sigs.k8s.io/structured-merge-diff/v4 v4.0.3/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK
sigs.k8s.io/structured-merge-diff/v4 v4.1.0/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc=
storj.io/drpc v0.0.30 h1:jqPe4T9KEu3CDBI05A2hCMgMSHLtd/E0N0yTF9QreIE=
storj.io/drpc v0.0.30/go.mod h1:6rcOyR/QQkSTX/9L5ZGtlZaE2PtXTTZl8d+ulSeeYEg=


@ -68,8 +68,8 @@ func New(clientDialer Dialer, opts *Options) *Server {
clientDialer: clientDialer,
opts: opts,
closeCancel: ctxCancel,
closed: make(chan struct{}),
closeContext: ctx,
closeCancel: ctxCancel,
shutdown: make(chan struct{}),
@ -87,13 +87,13 @@ type Server struct {
opts *Options
clientDialer Dialer
client proto.DRPCProvisionerDaemonClient
clientValue atomic.Value
// Locked when closing the daemon.
closeMutex sync.Mutex
closeCancel context.CancelFunc
closed chan struct{}
closeError error
closeMutex sync.Mutex
closeContext context.Context
closeCancel context.CancelFunc
closeError error
shutdownMutex sync.Mutex
shutdown chan struct{}
@ -108,11 +108,10 @@ type Server struct {
// Connect establishes a connection to coderd.
func (p *Server) connect(ctx context.Context) {
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
p.client, err = p.clientDialer(ctx)
client, err := p.clientDialer(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return
@ -126,6 +125,7 @@ func (p *Server) connect(ctx context.Context) {
p.closeMutex.Unlock()
continue
}
p.clientValue.Store(client)
p.opts.Logger.Debug(context.Background(), "connected")
break
}
@ -139,10 +139,14 @@ func (p *Server) connect(ctx context.Context) {
if p.isClosed() {
return
}
select {
case <-p.closed:
client, ok := p.client()
if !ok {
return
case <-p.client.DRPCConn().Closed():
}
select {
case <-p.closeContext.Done():
return
case <-client.DRPCConn().Closed():
// We use the update stream to detect when the connection
// has been interrupted. This works well, because logs need
// to buffer if a job is running in the background.
@ -158,10 +162,14 @@ func (p *Server) connect(ctx context.Context) {
ticker := time.NewTicker(p.opts.PollInterval)
defer ticker.Stop()
for {
select {
case <-p.closed:
client, ok := p.client()
if !ok {
return
case <-p.client.DRPCConn().Closed():
}
select {
case <-p.closeContext.Done():
return
case <-client.DRPCConn().Closed():
return
case <-ticker.C:
p.acquireJob(ctx)
@ -170,6 +178,15 @@ func (p *Server) connect(ctx context.Context) {
}()
}
func (p *Server) client() (proto.DRPCProvisionerDaemonClient, bool) {
rawClient := p.clientValue.Load()
if rawClient == nil {
return nil, false
}
client, ok := rawClient.(proto.DRPCProvisionerDaemonClient)
return client, ok
}
func (p *Server) isRunningJob() bool {
select {
case <-p.jobRunning:
@ -195,7 +212,11 @@ func (p *Server) acquireJob(ctx context.Context) {
return
}
var err error
job, err := p.client.AcquireJob(ctx, &proto.Empty{})
client, ok := p.client()
if !ok {
return
}
job, err := client.AcquireJob(ctx, &proto.Empty{})
if err != nil {
if errors.Is(err, context.Canceled) {
return
@ -231,7 +252,7 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
defer ticker.Stop()
for {
select {
case <-p.closed:
case <-p.closeContext.Done():
return
case <-ctx.Done():
return
@ -241,9 +262,16 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
return
case <-ticker.C:
}
resp, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
client, ok := p.client()
if !ok {
continue
}
resp, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.JobId,
})
if errors.Is(err, yamux.ErrSessionShutdown) || errors.Is(err, io.EOF) {
continue
}
if err != nil {
p.failActiveJobf("send periodic update: %s", err)
return
@ -297,7 +325,12 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
return
}
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
client, ok := p.client()
if !ok {
p.failActiveJobf("client disconnected")
return
}
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.GetJobId(),
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER_DAEMON,
@ -387,10 +420,14 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
return
}
client, ok = p.client()
if !ok {
return
}
// Ensure the job is still running to output.
// It's possible the job has failed.
if p.isRunningJob() {
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.GetJobId(),
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER_DAEMON,
@ -409,7 +446,12 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
}
func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) {
_, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
client, ok := p.client()
if !ok {
p.failActiveJobf("client disconnected")
return
}
_, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.GetJobId(),
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER_DAEMON,
@ -429,7 +471,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
return
}
updateResponse, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
updateResponse, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.JobId,
ParameterSchemas: parameterSchemas,
})
@ -450,7 +492,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
}
}
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.GetJobId(),
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER_DAEMON,
@ -471,7 +513,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
p.failActiveJobf("template import provision for start: %s", err)
return
}
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.GetJobId(),
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER_DAEMON,
@ -493,7 +535,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
return
}
_, err = p.client.CompleteJob(ctx, &proto.CompletedJob{
p.completeJob(&proto.CompletedJob{
JobId: job.JobId,
Type: &proto.CompletedJob_TemplateImport_{
TemplateImport: &proto.CompletedJob_TemplateImport{
@ -502,14 +544,14 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
},
},
})
if err != nil {
p.failActiveJobf("complete job: %s", err)
return
}
}
// Parses parameter schemas from source.
func (p *Server) runTemplateImportParse(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) ([]*sdkproto.ParameterSchema, error) {
client, ok := p.client()
if !ok {
return nil, xerrors.New("client disconnected")
}
stream, err := provisioner.Parse(ctx, &sdkproto.Parse_Request{
Directory: p.opts.WorkDirectory,
})
@ -529,7 +571,7 @@ func (p *Server) runTemplateImportParse(ctx context.Context, provisioner sdkprot
slog.F("output", msgType.Log.Output),
)
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.JobId,
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER,
@ -599,8 +641,11 @@ func (p *Server) runTemplateImportProvision(ctx, shutdown context.Context, provi
slog.F("level", msgType.Log.Level),
slog.F("output", msgType.Log.Output),
)
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
client, ok := p.client()
if !ok {
continue
}
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.JobId,
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER,
@ -638,7 +683,12 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
stage = "Destroying workspace"
}
_, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
client, ok := p.client()
if !ok {
p.failActiveJobf("client disconnected")
return
}
_, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.GetJobId(),
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER_DAEMON,
@ -699,7 +749,7 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
slog.F("workspace_build_id", job.GetWorkspaceBuild().WorkspaceBuildId),
)
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.JobId,
Logs: []*proto.Log{{
Source: proto.LogSource_PROVISIONER,
@ -729,15 +779,7 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
return
}
p.opts.Logger.Info(context.Background(), "provision successful; marking job as complete",
slog.F("resource_count", len(msgType.Complete.Resources)),
slog.F("resources", msgType.Complete.Resources),
slog.F("state_length", len(msgType.Complete.State)),
)
// Complete job may need to be async if we disconnected...
// When we reconnect we can flush any of these cached values.
_, err = p.client.CompleteJob(ctx, &proto.CompletedJob{
p.completeJob(&proto.CompletedJob{
JobId: job.JobId,
Type: &proto.CompletedJob_WorkspaceBuild_{
WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{
@ -746,11 +788,12 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
},
},
})
if err != nil {
p.failActiveJobf("complete job: %s", err)
return
}
// Return so we stop looping!
p.opts.Logger.Info(context.Background(), "provision successful; marked job as complete",
slog.F("resource_count", len(msgType.Complete.Resources)),
slog.F("resources", msgType.Complete.Resources),
slog.F("state_length", len(msgType.Complete.State)),
)
// Stop looping!
return
default:
p.failActiveJobf("invalid message type %T received from provisioner", msg.Type)
@ -759,6 +802,26 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
}
}
func (p *Server) completeJob(job *proto.CompletedJob) {
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); {
client, ok := p.client()
if !ok {
continue
}
// Complete job may need to be async if we disconnected...
// When we reconnect we can flush any of these cached values.
_, err := client.CompleteJob(p.closeContext, job)
if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) {
continue
}
if err != nil {
p.opts.Logger.Warn(p.closeContext, "failed to complete job", slog.Error(err))
return
}
break
}
}
func (p *Server) failActiveJobf(format string, args ...interface{}) {
p.failActiveJob(&proto.FailedJob{
Error: fmt.Sprintf(format, args...),
@ -786,18 +849,31 @@ func (p *Server) failActiveJob(failedJob *proto.FailedJob) {
slog.F("job_id", p.jobID),
)
failedJob.JobId = p.jobID
_, err := p.client.FailJob(context.Background(), failedJob)
if err != nil {
p.opts.Logger.Warn(context.Background(), "failed to notify of error; job is no longer running", slog.Error(err))
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); {
client, ok := p.client()
if !ok {
continue
}
_, err := client.FailJob(p.closeContext, failedJob)
if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) {
continue
}
if err != nil {
if p.isClosed() {
return
}
p.opts.Logger.Warn(context.Background(), "failed to notify of error; job is no longer running", slog.Error(err))
return
}
p.opts.Logger.Debug(context.Background(), "marked running job as failed")
return
}
p.opts.Logger.Debug(context.Background(), "marked running job as failed")
}
// isClosed returns whether the API is closed or not.
func (p *Server) isClosed() bool {
select {
case <-p.closed:
case <-p.closeContext.Done():
return true
default:
return false
@ -847,7 +923,6 @@ func (p *Server) closeWithError(err error) error {
return p.closeError
}
p.closeError = err
close(p.closed)
errMsg := "provisioner daemon was shutdown gracefully"
if err != nil {


@ -11,6 +11,7 @@ import (
"testing"
"time"
"github.com/hashicorp/yamux"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
@ -126,6 +127,7 @@ func TestProvisionerd(t *testing.T) {
// Ensures tars with "../../../etc/passwd" as the path
// are not allowed to run, and will fail the job.
t.Parallel()
var complete sync.Once
completeChan := make(chan struct{})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
@ -145,7 +147,9 @@ func TestProvisionerd(t *testing.T) {
},
updateJob: noopUpdateJob,
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
close(completeChan)
complete.Do(func() {
close(completeChan)
})
return &proto.Empty{}, nil
},
}), nil
@ -158,6 +162,7 @@ func TestProvisionerd(t *testing.T) {
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
t.Parallel()
var complete sync.Once
completeChan := make(chan struct{})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
@ -176,11 +181,9 @@ func TestProvisionerd(t *testing.T) {
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
select {
case <-completeChan:
default:
complete.Do(func() {
close(completeChan)
}
})
return &proto.UpdateJobResponse{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
@ -492,6 +495,7 @@ func TestProvisionerd(t *testing.T) {
t.Run("ShutdownFromJob", func(t *testing.T) {
t.Parallel()
var updated sync.Once
updateChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
@ -513,7 +517,9 @@ func TestProvisionerd(t *testing.T) {
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) > 0 && update.Logs[0].Source == proto.LogSource_PROVISIONER {
// Close on a log so we know when the job is in progress!
close(updateChan)
updated.Do(func() {
close(updateChan)
})
}
return &proto.UpdateJobResponse{
Canceled: true,
@ -558,6 +564,139 @@ func TestProvisionerd(t *testing.T) {
<-completeChan
require.NoError(t, server.Close())
})
t.Run("ReconnectAndFail", func(t *testing.T) {
t.Parallel()
var second atomic.Bool
failChan := make(chan struct{})
failedChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if second.Load() {
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return &proto.UpdateJobResponse{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
if second.Load() {
close(completeChan)
return &proto.Empty{}, nil
}
close(failChan)
<-failedChan
return &proto.Empty{}, nil
},
})
if !second.Load() {
go func() {
<-failChan
_ = client.DRPCConn().Close()
second.Store(true)
close(failedChan)
}()
}
return client, nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
return stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{
Error: "some error",
},
},
})
},
}),
})
<-completeChan
require.NoError(t, server.Close())
})
t.Run("ReconnectAndComplete", func(t *testing.T) {
t.Parallel()
var second atomic.Bool
failChan := make(chan struct{})
failedChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if second.Load() {
close(completeChan)
return &proto.AcquiredJob{}, nil
}
return &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: createTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Provision_Metadata{},
},
},
}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
return nil, yamux.ErrSessionShutdown
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
if second.Load() {
return &proto.Empty{}, nil
}
close(failChan)
<-failedChan
return &proto.Empty{}, nil
},
})
if !second.Load() {
go func() {
<-failChan
_ = client.DRPCConn().Close()
second.Store(true)
close(failedChan)
}()
}
return client, nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
return stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
Complete: &sdkproto.Provision_Complete{},
},
})
},
}),
})
<-completeChan
require.NoError(t, server.Close())
})
}
// Creates an in-memory tar of the files provided.


@ -9,6 +9,12 @@ import (
"storj.io/drpc/drpcconn"
)
const (
// MaxMessageSize is the maximum payload size that can be
// transported without error.
MaxMessageSize = 4 << 20
)
// TransportPipe creates an in-memory pipe for dRPC transport.
func TransportPipe() (*yamux.Session, *yamux.Session) {
clientReader, clientWriter := io.Pipe()