diff --git a/cli/server.go b/cli/server.go index f1125da3f0..72c72679fd 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1371,10 +1371,10 @@ func newProvisionerDaemon( connector[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient) } - return provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { + return provisionerd.New(func(dialCtx context.Context) (proto.DRPCProvisionerDaemonClient, error) { // This debounces calls to listen every second. Read the comment // in provisionerdserver.go to learn more! - return coderAPI.CreateInMemoryProvisionerDaemon(ctx, name) + return coderAPI.CreateInMemoryProvisionerDaemon(dialCtx, name) }, &provisionerd.Options{ Logger: logger.Named(fmt.Sprintf("provisionerd-%s", name)), UpdateInterval: time.Second, diff --git a/coderd/coderd.go b/coderd/coderd.go index 3f16c89cb0..3e04e6a7db 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1174,7 +1174,7 @@ func compressHandler(h http.Handler) http.Handler { // CreateInMemoryProvisionerDaemon is an in-memory connection to a provisionerd. // Useful when starting coderd and provisionerd in the same process. -func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string) (client proto.DRPCProvisionerDaemonClient, err error) { +func (api *API) CreateInMemoryProvisionerDaemon(dialCtx context.Context, name string) (client proto.DRPCProvisionerDaemonClient, err error) { tracer := api.TracerProvider.Tracer(tracing.TracerName) clientSession, serverSession := drpc.MemTransportPipe() defer func() { @@ -1185,7 +1185,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string }() //nolint:gocritic // in-memory provisioners are owned by system - daemon, err := api.Database.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(ctx), database.UpsertProvisionerDaemonParams{ + daemon, err := api.Database.UpsertProvisionerDaemon(dbauthz.AsSystemRestricted(dialCtx), database.UpsertProvisionerDaemonParams{ Name: name, CreatedAt: dbtime.Now(), Provisioners: []database.ProvisionerType{ @@ -1201,7 +1201,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string } mux := drpcmux.New() - api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name)) + api.Logger.Info(dialCtx, "starting in-memory provisioner daemon", slog.F("name", name)) logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name)) srv, err := provisionerdserver.NewServer( api.ctx, // use the same ctx as the API @@ -1238,13 +1238,25 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string if xerrors.Is(err, io.EOF) { return } - logger.Debug(ctx, "drpc server error", slog.Error(err)) + logger.Debug(dialCtx, "drpc server error", slog.Error(err)) }, }, ) + // in-mem pipes aren't technically "websockets" but they have the same properties as far as the + // API is concerned: they are long-lived connections that we need to close before completing + // shutdown of the API. + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Add(1) + api.WebsocketWaitMutex.Unlock() go func() { - err := server.Serve(ctx, serverSession) - logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err)) + defer api.WebsocketWaitGroup.Done() + // here we pass the background context, since we want the server to keep serving until the + // client hangs up. If we, say, pass the API context, then when it is canceled, we could + // drop a job that we locked in the database but never passed to the provisionerd. The + // provisionerd is local, in-mem, so there isn't a danger of losing contact with it and + // having a dead connection we don't know the status of. + err := server.Serve(context.Background(), serverSession) + logger.Info(dialCtx, "provisioner daemon disconnected", slog.Error(err)) // close the sessions, so we don't leak goroutines serving them. _ = clientSession.Close() _ = serverSession.Close() diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 55060a0998..33184aede9 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -532,8 +532,8 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer { assert.NoError(t, err) }() - daemon := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { - return coderAPI.CreateInMemoryProvisionerDaemon(ctx, "test") + daemon := provisionerd.New(func(dialCtx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { + return coderAPI.CreateInMemoryProvisionerDaemon(dialCtx, "test") }, &provisionerd.Options{ Logger: coderAPI.Logger.Named("provisionerd").Leveled(slog.LevelDebug), UpdateInterval: 250 * time.Millisecond,