mirror of https://github.com/coder/coder.git
fix: Add resiliency to daemon connections (#1116)
Connections could fail when massive payloads were transmitted. This fixes an upstream bug in dRPC where the connection would end with a context canceled if a message was too large. This adds retransmission of completion and failures too. If Coder somehow loses connection with a provisioner daemon, upon the next connection the state will be properly reported.
This commit is contained in:
parent
be974cf280
commit
db7ed4d019
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"github.com/tabbed/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
protobuf "google.golang.org/protobuf/proto"
|
||||
"nhooyr.io/websocket"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
@ -27,6 +28,7 @@ import (
|
|||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/parameter"
|
||||
"github.com/coder/coder/provisionerd/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
sdkproto "github.com/coder/coder/provisionersdk/proto"
|
||||
)
|
||||
|
||||
|
@ -47,6 +49,8 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request
|
|||
})
|
||||
return
|
||||
}
|
||||
// Align with the frame size of yamux.
|
||||
conn.SetReadLimit(256 * 1024)
|
||||
|
||||
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
|
||||
ID: uuid.New(),
|
||||
|
@ -82,9 +86,17 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request
|
|||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))
|
||||
return
|
||||
}
|
||||
server := drpcserver.New(mux)
|
||||
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err))
|
||||
},
|
||||
})
|
||||
err = server.Serve(r.Context(), session)
|
||||
if err != nil {
|
||||
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
|
||||
return
|
||||
}
|
||||
|
@ -253,6 +265,9 @@ func (server *provisionerdServer) AcquireJob(ctx context.Context, _ *proto.Empty
|
|||
default:
|
||||
return nil, failJob(fmt.Sprintf("unsupported storage method: %s", job.StorageMethod))
|
||||
}
|
||||
if protobuf.Size(protoJob) > provisionersdk.MaxMessageSize {
|
||||
return nil, failJob(fmt.Sprintf("payload was too big: %d > %d", protobuf.Size(protoJob), provisionersdk.MaxMessageSize))
|
||||
}
|
||||
|
||||
return protoJob, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
)
|
||||
|
||||
func TestProvisionerDaemons(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("PayloadTooBig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
// Takes too long to allocate memory on Windows!
|
||||
t.Skip()
|
||||
}
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
coderdtest.NewProvisionerDaemon(t, client)
|
||||
data := make([]byte, provisionersdk.MaxMessageSize)
|
||||
rand.Read(data)
|
||||
resp, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data)
|
||||
require.NoError(t, err)
|
||||
t.Log(resp.Hash)
|
||||
|
||||
version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
StorageSource: resp.Hash,
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
version, err = client.TemplateVersion(context.Background(), version.ID)
|
||||
require.NoError(t, err)
|
||||
return version.Job.Error != ""
|
||||
}, 5*time.Second, 25*time.Millisecond)
|
||||
})
|
||||
}
|
|
@ -70,8 +70,8 @@ func (c *Client) ListenProvisionerDaemon(ctx context.Context) (proto.DRPCProvisi
|
|||
}
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
// Allow _somewhat_ large payloads.
|
||||
conn.SetReadLimit((1 << 20) * 2)
|
||||
// Align with the frame size of yamux.
|
||||
conn.SetReadLimit(256 * 1024)
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
|
|
3
go.mod
3
go.mod
|
@ -17,6 +17,9 @@ replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220
|
|||
// Required until https://github.com/briandowns/spinner/pull/136 is merged.
|
||||
replace github.com/briandowns/spinner => github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e
|
||||
|
||||
// Required until https://github.com/storj/drpc/pull/31 is merged.
|
||||
replace storj.io/drpc => github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff
|
||||
|
||||
// opencensus-go leaks a goroutine by default.
|
||||
replace go.opencensus.io => github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b
|
||||
|
||||
|
|
4
go.sum
4
go.sum
|
@ -1107,6 +1107,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/ktrysmt/go-bitbucket v0.6.4/go.mod h1:9u0v3hsd2rqCHRIpbir1oP7F58uo5dq19sBYvuMoyQ4=
|
||||
github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff h1:7qg425aXdULnZWCCQNPOzHO7c+M6BpbTfOUJLrk5+3w=
|
||||
github.com/kylecarbs/drpc v0.0.31-0.20220424193521-8ebbaf48bdff/go.mod h1:6rcOyR/QQkSTX/9L5ZGtlZaE2PtXTTZl8d+ulSeeYEg=
|
||||
github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b h1:1Y1X6aR78kMEQE1iCjQodB3lA7VO4jB88Wf8ZrzXSsA=
|
||||
github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
|
||||
github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8/go.mod h1:n/KX1BZoN1m9EwoXkn/xAV4fd3k8c++gGBsgLONaPOY=
|
||||
|
@ -2544,5 +2546,3 @@ sigs.k8s.io/structured-merge-diff/v4 v4.0.3/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK
|
|||
sigs.k8s.io/structured-merge-diff/v4 v4.1.0/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw=
|
||||
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
|
||||
sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc=
|
||||
storj.io/drpc v0.0.30 h1:jqPe4T9KEu3CDBI05A2hCMgMSHLtd/E0N0yTF9QreIE=
|
||||
storj.io/drpc v0.0.30/go.mod h1:6rcOyR/QQkSTX/9L5ZGtlZaE2PtXTTZl8d+ulSeeYEg=
|
||||
|
|
|
@ -68,8 +68,8 @@ func New(clientDialer Dialer, opts *Options) *Server {
|
|||
clientDialer: clientDialer,
|
||||
opts: opts,
|
||||
|
||||
closeCancel: ctxCancel,
|
||||
closed: make(chan struct{}),
|
||||
closeContext: ctx,
|
||||
closeCancel: ctxCancel,
|
||||
|
||||
shutdown: make(chan struct{}),
|
||||
|
||||
|
@ -87,13 +87,13 @@ type Server struct {
|
|||
opts *Options
|
||||
|
||||
clientDialer Dialer
|
||||
client proto.DRPCProvisionerDaemonClient
|
||||
clientValue atomic.Value
|
||||
|
||||
// Locked when closing the daemon.
|
||||
closeMutex sync.Mutex
|
||||
closeCancel context.CancelFunc
|
||||
closed chan struct{}
|
||||
closeError error
|
||||
closeMutex sync.Mutex
|
||||
closeContext context.Context
|
||||
closeCancel context.CancelFunc
|
||||
closeError error
|
||||
|
||||
shutdownMutex sync.Mutex
|
||||
shutdown chan struct{}
|
||||
|
@ -108,11 +108,10 @@ type Server struct {
|
|||
|
||||
// Connect establishes a connection to coderd.
|
||||
func (p *Server) connect(ctx context.Context) {
|
||||
var err error
|
||||
// An exponential back-off occurs when the connection is failing to dial.
|
||||
// This is to prevent server spam in case of a coderd outage.
|
||||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
|
||||
p.client, err = p.clientDialer(ctx)
|
||||
client, err := p.clientDialer(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
|
@ -126,6 +125,7 @@ func (p *Server) connect(ctx context.Context) {
|
|||
p.closeMutex.Unlock()
|
||||
continue
|
||||
}
|
||||
p.clientValue.Store(client)
|
||||
p.opts.Logger.Debug(context.Background(), "connected")
|
||||
break
|
||||
}
|
||||
|
@ -139,10 +139,14 @@ func (p *Server) connect(ctx context.Context) {
|
|||
if p.isClosed() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.closed:
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
return
|
||||
case <-p.client.DRPCConn().Closed():
|
||||
}
|
||||
select {
|
||||
case <-p.closeContext.Done():
|
||||
return
|
||||
case <-client.DRPCConn().Closed():
|
||||
// We use the update stream to detect when the connection
|
||||
// has been interrupted. This works well, because logs need
|
||||
// to buffer if a job is running in the background.
|
||||
|
@ -158,10 +162,14 @@ func (p *Server) connect(ctx context.Context) {
|
|||
ticker := time.NewTicker(p.opts.PollInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.closed:
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
return
|
||||
case <-p.client.DRPCConn().Closed():
|
||||
}
|
||||
select {
|
||||
case <-p.closeContext.Done():
|
||||
return
|
||||
case <-client.DRPCConn().Closed():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.acquireJob(ctx)
|
||||
|
@ -170,6 +178,15 @@ func (p *Server) connect(ctx context.Context) {
|
|||
}()
|
||||
}
|
||||
|
||||
func (p *Server) client() (proto.DRPCProvisionerDaemonClient, bool) {
|
||||
rawClient := p.clientValue.Load()
|
||||
if rawClient == nil {
|
||||
return nil, false
|
||||
}
|
||||
client, ok := rawClient.(proto.DRPCProvisionerDaemonClient)
|
||||
return client, ok
|
||||
}
|
||||
|
||||
func (p *Server) isRunningJob() bool {
|
||||
select {
|
||||
case <-p.jobRunning:
|
||||
|
@ -195,7 +212,11 @@ func (p *Server) acquireJob(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
var err error
|
||||
job, err := p.client.AcquireJob(ctx, &proto.Empty{})
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
job, err := client.AcquireJob(ctx, &proto.Empty{})
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
|
@ -231,7 +252,7 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
|
|||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.closed:
|
||||
case <-p.closeContext.Done():
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
@ -241,9 +262,16 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
|
|||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
resp, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
resp, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.JobId,
|
||||
})
|
||||
if errors.Is(err, yamux.ErrSessionShutdown) || errors.Is(err, io.EOF) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
p.failActiveJobf("send periodic update: %s", err)
|
||||
return
|
||||
|
@ -297,7 +325,12 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
|
|||
return
|
||||
}
|
||||
|
||||
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
p.failActiveJobf("client disconnected")
|
||||
return
|
||||
}
|
||||
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.GetJobId(),
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER_DAEMON,
|
||||
|
@ -387,10 +420,14 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
|
|||
return
|
||||
}
|
||||
|
||||
client, ok = p.client()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Ensure the job is still running to output.
|
||||
// It's possible the job has failed.
|
||||
if p.isRunningJob() {
|
||||
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.GetJobId(),
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER_DAEMON,
|
||||
|
@ -409,7 +446,12 @@ func (p *Server) runJob(ctx context.Context, job *proto.AcquiredJob) {
|
|||
}
|
||||
|
||||
func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) {
|
||||
_, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
p.failActiveJobf("client disconnected")
|
||||
return
|
||||
}
|
||||
_, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.GetJobId(),
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER_DAEMON,
|
||||
|
@ -429,7 +471,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
|
|||
return
|
||||
}
|
||||
|
||||
updateResponse, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
updateResponse, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.JobId,
|
||||
ParameterSchemas: parameterSchemas,
|
||||
})
|
||||
|
@ -450,7 +492,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
|
|||
}
|
||||
}
|
||||
|
||||
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.GetJobId(),
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER_DAEMON,
|
||||
|
@ -471,7 +513,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
|
|||
p.failActiveJobf("template import provision for start: %s", err)
|
||||
return
|
||||
}
|
||||
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.GetJobId(),
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER_DAEMON,
|
||||
|
@ -493,7 +535,7 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
|
|||
return
|
||||
}
|
||||
|
||||
_, err = p.client.CompleteJob(ctx, &proto.CompletedJob{
|
||||
p.completeJob(&proto.CompletedJob{
|
||||
JobId: job.JobId,
|
||||
Type: &proto.CompletedJob_TemplateImport_{
|
||||
TemplateImport: &proto.CompletedJob_TemplateImport{
|
||||
|
@ -502,14 +544,14 @@ func (p *Server) runTemplateImport(ctx, shutdown context.Context, provisioner sd
|
|||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
p.failActiveJobf("complete job: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Parses parameter schemas from source.
|
||||
func (p *Server) runTemplateImportParse(ctx context.Context, provisioner sdkproto.DRPCProvisionerClient, job *proto.AcquiredJob) ([]*sdkproto.ParameterSchema, error) {
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
return nil, xerrors.New("client disconnected")
|
||||
}
|
||||
stream, err := provisioner.Parse(ctx, &sdkproto.Parse_Request{
|
||||
Directory: p.opts.WorkDirectory,
|
||||
})
|
||||
|
@ -529,7 +571,7 @@ func (p *Server) runTemplateImportParse(ctx context.Context, provisioner sdkprot
|
|||
slog.F("output", msgType.Log.Output),
|
||||
)
|
||||
|
||||
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.JobId,
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER,
|
||||
|
@ -599,8 +641,11 @@ func (p *Server) runTemplateImportProvision(ctx, shutdown context.Context, provi
|
|||
slog.F("level", msgType.Log.Level),
|
||||
slog.F("output", msgType.Log.Output),
|
||||
)
|
||||
|
||||
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.JobId,
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER,
|
||||
|
@ -638,7 +683,12 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
|
|||
stage = "Destroying workspace"
|
||||
}
|
||||
|
||||
_, err := p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
p.failActiveJobf("client disconnected")
|
||||
return
|
||||
}
|
||||
_, err := client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.GetJobId(),
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER_DAEMON,
|
||||
|
@ -699,7 +749,7 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
|
|||
slog.F("workspace_build_id", job.GetWorkspaceBuild().WorkspaceBuildId),
|
||||
)
|
||||
|
||||
_, err = p.client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
_, err = client.UpdateJob(ctx, &proto.UpdateJobRequest{
|
||||
JobId: job.JobId,
|
||||
Logs: []*proto.Log{{
|
||||
Source: proto.LogSource_PROVISIONER,
|
||||
|
@ -729,15 +779,7 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
|
|||
return
|
||||
}
|
||||
|
||||
p.opts.Logger.Info(context.Background(), "provision successful; marking job as complete",
|
||||
slog.F("resource_count", len(msgType.Complete.Resources)),
|
||||
slog.F("resources", msgType.Complete.Resources),
|
||||
slog.F("state_length", len(msgType.Complete.State)),
|
||||
)
|
||||
|
||||
// Complete job may need to be async if we disconnected...
|
||||
// When we reconnect we can flush any of these cached values.
|
||||
_, err = p.client.CompleteJob(ctx, &proto.CompletedJob{
|
||||
p.completeJob(&proto.CompletedJob{
|
||||
JobId: job.JobId,
|
||||
Type: &proto.CompletedJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{
|
||||
|
@ -746,11 +788,12 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
|
|||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
p.failActiveJobf("complete job: %s", err)
|
||||
return
|
||||
}
|
||||
// Return so we stop looping!
|
||||
p.opts.Logger.Info(context.Background(), "provision successful; marked job as complete",
|
||||
slog.F("resource_count", len(msgType.Complete.Resources)),
|
||||
slog.F("resources", msgType.Complete.Resources),
|
||||
slog.F("state_length", len(msgType.Complete.State)),
|
||||
)
|
||||
// Stop looping!
|
||||
return
|
||||
default:
|
||||
p.failActiveJobf("invalid message type %T received from provisioner", msg.Type)
|
||||
|
@ -759,6 +802,26 @@ func (p *Server) runWorkspaceBuild(ctx, shutdown context.Context, provisioner sd
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Server) completeJob(job *proto.CompletedJob) {
|
||||
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); {
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// Complete job may need to be async if we disconnected...
|
||||
// When we reconnect we can flush any of these cached values.
|
||||
_, err := client.CompleteJob(p.closeContext, job)
|
||||
if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
p.opts.Logger.Warn(p.closeContext, "failed to complete job", slog.Error(err))
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) failActiveJobf(format string, args ...interface{}) {
|
||||
p.failActiveJob(&proto.FailedJob{
|
||||
Error: fmt.Sprintf(format, args...),
|
||||
|
@ -786,18 +849,31 @@ func (p *Server) failActiveJob(failedJob *proto.FailedJob) {
|
|||
slog.F("job_id", p.jobID),
|
||||
)
|
||||
failedJob.JobId = p.jobID
|
||||
_, err := p.client.FailJob(context.Background(), failedJob)
|
||||
if err != nil {
|
||||
p.opts.Logger.Warn(context.Background(), "failed to notify of error; job is no longer running", slog.Error(err))
|
||||
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(p.closeContext); {
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
_, err := client.FailJob(p.closeContext, failedJob)
|
||||
if xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
if p.isClosed() {
|
||||
return
|
||||
}
|
||||
p.opts.Logger.Warn(context.Background(), "failed to notify of error; job is no longer running", slog.Error(err))
|
||||
return
|
||||
}
|
||||
p.opts.Logger.Debug(context.Background(), "marked running job as failed")
|
||||
return
|
||||
}
|
||||
p.opts.Logger.Debug(context.Background(), "marked running job as failed")
|
||||
}
|
||||
|
||||
// isClosed returns whether the API is closed or not.
|
||||
func (p *Server) isClosed() bool {
|
||||
select {
|
||||
case <-p.closed:
|
||||
case <-p.closeContext.Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
@ -847,7 +923,6 @@ func (p *Server) closeWithError(err error) error {
|
|||
return p.closeError
|
||||
}
|
||||
p.closeError = err
|
||||
close(p.closed)
|
||||
|
||||
errMsg := "provisioner daemon was shutdown gracefully"
|
||||
if err != nil {
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/goleak"
|
||||
|
@ -126,6 +127,7 @@ func TestProvisionerd(t *testing.T) {
|
|||
// Ensures tars with "../../../etc/passwd" as the path
|
||||
// are not allowed to run, and will fail the job.
|
||||
t.Parallel()
|
||||
var complete sync.Once
|
||||
completeChan := make(chan struct{})
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
|
@ -145,7 +147,9 @@ func TestProvisionerd(t *testing.T) {
|
|||
},
|
||||
updateJob: noopUpdateJob,
|
||||
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
|
||||
close(completeChan)
|
||||
complete.Do(func() {
|
||||
close(completeChan)
|
||||
})
|
||||
return &proto.Empty{}, nil
|
||||
},
|
||||
}), nil
|
||||
|
@ -158,6 +162,7 @@ func TestProvisionerd(t *testing.T) {
|
|||
|
||||
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var complete sync.Once
|
||||
completeChan := make(chan struct{})
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
|
@ -176,11 +181,9 @@ func TestProvisionerd(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
|
||||
select {
|
||||
case <-completeChan:
|
||||
default:
|
||||
complete.Do(func() {
|
||||
close(completeChan)
|
||||
}
|
||||
})
|
||||
return &proto.UpdateJobResponse{}, nil
|
||||
},
|
||||
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
|
||||
|
@ -492,6 +495,7 @@ func TestProvisionerd(t *testing.T) {
|
|||
|
||||
t.Run("ShutdownFromJob", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var updated sync.Once
|
||||
updateChan := make(chan struct{})
|
||||
completeChan := make(chan struct{})
|
||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
|
@ -513,7 +517,9 @@ func TestProvisionerd(t *testing.T) {
|
|||
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
|
||||
if len(update.Logs) > 0 && update.Logs[0].Source == proto.LogSource_PROVISIONER {
|
||||
// Close on a log so we know when the job is in progress!
|
||||
close(updateChan)
|
||||
updated.Do(func() {
|
||||
close(updateChan)
|
||||
})
|
||||
}
|
||||
return &proto.UpdateJobResponse{
|
||||
Canceled: true,
|
||||
|
@ -558,6 +564,139 @@ func TestProvisionerd(t *testing.T) {
|
|||
<-completeChan
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
t.Run("ReconnectAndFail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var second atomic.Bool
|
||||
failChan := make(chan struct{})
|
||||
failedChan := make(chan struct{})
|
||||
completeChan := make(chan struct{})
|
||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if second.Load() {
|
||||
return &proto.AcquiredJob{}, nil
|
||||
}
|
||||
return &proto.AcquiredJob{
|
||||
JobId: "test",
|
||||
Provisioner: "someprovisioner",
|
||||
TemplateSourceArchive: createTar(t, map[string]string{
|
||||
"test.txt": "content",
|
||||
}),
|
||||
Type: &proto.AcquiredJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
|
||||
Metadata: &sdkproto.Provision_Metadata{},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
|
||||
return &proto.UpdateJobResponse{}, nil
|
||||
},
|
||||
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
|
||||
if second.Load() {
|
||||
close(completeChan)
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
close(failChan)
|
||||
<-failedChan
|
||||
return &proto.Empty{}, nil
|
||||
},
|
||||
})
|
||||
if !second.Load() {
|
||||
go func() {
|
||||
<-failChan
|
||||
_ = client.DRPCConn().Close()
|
||||
second.Store(true)
|
||||
close(failedChan)
|
||||
}()
|
||||
}
|
||||
return client, nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
// Ignore the first provision message!
|
||||
_, _ = stream.Recv()
|
||||
return stream.Send(&sdkproto.Provision_Response{
|
||||
Type: &sdkproto.Provision_Response_Complete{
|
||||
Complete: &sdkproto.Provision_Complete{
|
||||
Error: "some error",
|
||||
},
|
||||
},
|
||||
})
|
||||
},
|
||||
}),
|
||||
})
|
||||
<-completeChan
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
t.Run("ReconnectAndComplete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var second atomic.Bool
|
||||
failChan := make(chan struct{})
|
||||
failedChan := make(chan struct{})
|
||||
completeChan := make(chan struct{})
|
||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if second.Load() {
|
||||
close(completeChan)
|
||||
return &proto.AcquiredJob{}, nil
|
||||
}
|
||||
return &proto.AcquiredJob{
|
||||
JobId: "test",
|
||||
Provisioner: "someprovisioner",
|
||||
TemplateSourceArchive: createTar(t, map[string]string{
|
||||
"test.txt": "content",
|
||||
}),
|
||||
Type: &proto.AcquiredJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
|
||||
Metadata: &sdkproto.Provision_Metadata{},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
|
||||
return nil, yamux.ErrSessionShutdown
|
||||
},
|
||||
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
|
||||
return &proto.UpdateJobResponse{}, nil
|
||||
},
|
||||
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
|
||||
if second.Load() {
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
close(failChan)
|
||||
<-failedChan
|
||||
return &proto.Empty{}, nil
|
||||
},
|
||||
})
|
||||
if !second.Load() {
|
||||
go func() {
|
||||
<-failChan
|
||||
_ = client.DRPCConn().Close()
|
||||
second.Store(true)
|
||||
close(failedChan)
|
||||
}()
|
||||
}
|
||||
return client, nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
// Ignore the first provision message!
|
||||
_, _ = stream.Recv()
|
||||
return stream.Send(&sdkproto.Provision_Response{
|
||||
Type: &sdkproto.Provision_Response_Complete{
|
||||
Complete: &sdkproto.Provision_Complete{},
|
||||
},
|
||||
})
|
||||
},
|
||||
}),
|
||||
})
|
||||
<-completeChan
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
}
|
||||
|
||||
// Creates an in-memory tar of the files provided.
|
||||
|
|
|
@ -9,6 +9,12 @@ import (
|
|||
"storj.io/drpc/drpcconn"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxMessageSize is the maximum payload size that can be
|
||||
// transported without error.
|
||||
MaxMessageSize = 4 << 20
|
||||
)
|
||||
|
||||
// TransportPipe creates an in-memory pipe for dRPC transport.
|
||||
func TransportPipe() (*yamux.Session, *yamux.Session) {
|
||||
clientReader, clientWriter := io.Pipe()
|
||||
|
|
Loading…
Reference in New Issue