diff --git a/cli/agent.go b/cli/agent.go index af54bfc969..b900db48d0 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -125,7 +125,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { args := append(os.Args, "--no-reap") err := reaper.ForkReap( reaper.WithExecArgs(args...), - reaper.WithCatchSignals(InterruptSignals...), + reaper.WithCatchSignals(StopSignals...), ) if err != nil { logger.Error(ctx, "agent process reaper unable to fork", slog.Error(err)) @@ -144,7 +144,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { // Note that we don't want to handle these signals in the // process that runs as PID 1, that's why we do this after // the reaper forked. - ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, stopNotify := inv.SignalNotifyContext(ctx, StopSignals...) defer stopNotify() // DumpHandler does signal handling, so we call it after the diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go index fc8f503f3a..fe93fe26cd 100644 --- a/cli/exp_scaletest.go +++ b/cli/exp_scaletest.go @@ -890,7 +890,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd { Handler: func(inv *clibase.Invocation) (err error) { ctx := inv.Context() - notifyCtx, stop := signal.NotifyContext(ctx, InterruptSignals...) // Checked later. + notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...) // Checked later. defer stop() ctx = notifyCtx diff --git a/cli/externalauth.go b/cli/externalauth.go index 675d795491..52b897b64b 100644 --- a/cli/externalauth.go +++ b/cli/externalauth.go @@ -65,7 +65,7 @@ fi Handler: func(inv *clibase.Invocation) error { ctx := inv.Context() - ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) defer stop() client, err := r.createAgentClient() diff --git a/cli/gitaskpass.go b/cli/gitaskpass.go index ddfd05af9d..30e82ab90a 100644 --- a/cli/gitaskpass.go +++ b/cli/gitaskpass.go @@ -25,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd { Handler: func(inv *clibase.Invocation) error { ctx := inv.Context() - ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) defer stop() user, host, err := gitauth.ParseAskpass(inv.Args[0]) diff --git a/cli/gitssh.go b/cli/gitssh.go index b627b3911b..479ec094f0 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -29,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd { // Catch interrupt signals to ensure the temporary private // key file is cleaned up on most cases. - ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) defer stop() // Early check so errors are reported immediately. diff --git a/cli/server.go b/cli/server.go index e02a891022..937f290aff 100644 --- a/cli/server.go +++ b/cli/server.go @@ -337,7 +337,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // Register signals early on so that graceful shutdown can't // be interrupted by additional signals. Note that we avoid - // shadowing cancel() (from above) here because notifyStop() + // shadowing cancel() (from above) here because stopCancel() // restores default behavior for the signals. This protects // the shutdown sequence from abruptly terminating things // like: database migrations, provisioner work, workspace @@ -345,8 +345,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // // To get out of a graceful shutdown, the user can send // SIGQUIT with ctrl+\ or SIGKILL with `kill -9`. - notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...) - defer notifyStop() + stopCtx, stopCancel := signalNotifyContext(ctx, inv, StopSignalsNoInterrupt...) + defer stopCancel() + interruptCtx, interruptCancel := signalNotifyContext(ctx, inv, InterruptSignals...) + defer interruptCancel() cacheDir := vals.CacheDir.String() err = os.MkdirAll(cacheDir, 0o700) @@ -1028,13 +1030,18 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. hangDetector.Start() defer hangDetector.Close() + waitForProvisionerJobs := false // Currently there is no way to ask the server to shut // itself down, so any exit signal will result in a non-zero // exit of the server. var exitErr error select { - case <-notifyCtx.Done(): - exitErr = notifyCtx.Err() + case <-stopCtx.Done(): + exitErr = stopCtx.Err() + waitForProvisionerJobs = true + _, _ = io.WriteString(inv.Stdout, cliui.Bold("Stop caught, waiting for provisioner jobs to complete and gracefully exiting. Use ctrl+\\ to force quit")) + case <-interruptCtx.Done(): + exitErr = interruptCtx.Err() _, _ = io.WriteString(inv.Stdout, cliui.Bold("Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit")) case <-tunnelDone: exitErr = xerrors.New("dev tunnel closed unexpectedly") @@ -1082,7 +1089,16 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. defer wg.Done() r.Verbosef(inv, "Shutting down provisioner daemon %d...", id) - err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second) + timeout := 5 * time.Second + if waitForProvisionerJobs { + // It can last for a long time... + timeout = 30 * time.Minute + } + + err := shutdownWithTimeout(func(ctx context.Context) error { + // We only want to cancel active jobs if we aren't exiting gracefully. + return provisionerDaemon.Shutdown(ctx, !waitForProvisionerJobs) + }, timeout) if err != nil { cliui.Errorf(inv.Stderr, "Failed to shut down provisioner daemon %d: %s\n", id, err) return @@ -2512,3 +2528,12 @@ func escapePostgresURLUserInfo(v string) (string, error) { return v, nil } + +func signalNotifyContext(ctx context.Context, inv *clibase.Invocation, sig ...os.Signal) (context.Context, context.CancelFunc) { + // On Windows, some of our signal functions lack support. + // If we pass in no signals, we should just return the context as-is. + if len(sig) == 0 { + return context.WithCancel(ctx) + } + return inv.SignalNotifyContext(ctx, sig...) +} diff --git a/cli/server_createadminuser.go b/cli/server_createadminuser.go index 7491afac3c..43f78ea784 100644 --- a/cli/server_createadminuser.go +++ b/cli/server_createadminuser.go @@ -47,7 +47,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd { logger = logger.Leveled(slog.LevelDebug) } - ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...) + ctx, cancel := inv.SignalNotifyContext(ctx, StopSignals...) defer cancel() if newUserDBURL == "" { diff --git a/cli/server_test.go b/cli/server_test.go index 9699e8a48e..4ce4d2b5f5 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -21,6 +21,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "runtime" "strconv" "strings" @@ -1605,7 +1606,7 @@ func TestServer_Production(t *testing.T) { } //nolint:tparallel,paralleltest // This test cannot be run in parallel due to signal handling. -func TestServer_Shutdown(t *testing.T) { +func TestServer_InterruptShutdown(t *testing.T) { t.Skip("This test issues an interrupt signal which will propagate to the test runner.") if runtime.GOOS == "windows" { @@ -1638,6 +1639,46 @@ func TestServer_Shutdown(t *testing.T) { require.NoError(t, err) } +func TestServer_GracefulShutdown(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + // Sending interrupt signal isn't supported on Windows! + t.SkipNow() + } + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + root, cfg := clitest.New(t, + "server", + "--in-memory", + "--http-address", ":0", + "--access-url", "http://example.com", + "--provisioner-daemons", "1", + "--cache-dir", t.TempDir(), + ) + var stopFunc context.CancelFunc + root = root.WithTestSignalNotifyContext(t, func(parent context.Context, signals ...os.Signal) (context.Context, context.CancelFunc) { + if !reflect.DeepEqual(cli.StopSignalsNoInterrupt, signals) { + return context.WithCancel(ctx) + } + var ctx context.Context + ctx, stopFunc = context.WithCancel(parent) + return ctx, stopFunc + }) + serverErr := make(chan error, 1) + pty := ptytest.New(t).Attach(root) + go func() { + serverErr <- root.WithContext(ctx).Run() + }() + _ = waitAccessURL(t, cfg) + // It's fair to assume `stopFunc` isn't nil here, because the server + // has started and access URL is propagated. + stopFunc() + pty.ExpectMatch("waiting for provisioner jobs to complete") + err := <-serverErr + require.NoError(t, err) +} + func BenchmarkServerHelp(b *testing.B) { // server --help is a good proxy for measuring the // constant overhead of each command. diff --git a/cli/signal_unix.go b/cli/signal_unix.go index 05d619c023..9cb6f3f899 100644 --- a/cli/signal_unix.go +++ b/cli/signal_unix.go @@ -7,8 +7,23 @@ import ( "syscall" ) -var InterruptSignals = []os.Signal{ +// StopSignals is the list of signals that are used for handling +// shutdown behavior. +var StopSignals = []os.Signal{ os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, } + +// StopSignals is the list of signals that are used for handling +// graceful shutdown behavior. +var StopSignalsNoInterrupt = []os.Signal{ + syscall.SIGTERM, + syscall.SIGHUP, +} + +// InterruptSignals is the list of signals that are used for handling +// immediate shutdown behavior. +var InterruptSignals = []os.Signal{ + os.Interrupt, +} diff --git a/cli/signal_windows.go b/cli/signal_windows.go index 3624415a64..8d9b8518e6 100644 --- a/cli/signal_windows.go +++ b/cli/signal_windows.go @@ -6,4 +6,12 @@ import ( "os" ) -var InterruptSignals = []os.Signal{os.Interrupt} +var StopSignals = []os.Signal{ + os.Interrupt, +} + +var StopSignalsNoInterrupt = []os.Signal{} + +var InterruptSignals = []os.Signal{ + os.Interrupt, +} diff --git a/cli/ssh.go b/cli/ssh.go index 21437ee6ae..023c5307da 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -73,7 +73,7 @@ func (r *RootCmd) ssh() *clibase.Cmd { // session can persist for up to 72 hours, since we set a long // timeout on the Agent side of the connection. In particular, // OpenSSH sends SIGHUP to terminate a proxy command. - ctx, stop := inv.SignalNotifyContext(inv.Context(), InterruptSignals...) + ctx, stop := inv.SignalNotifyContext(inv.Context(), StopSignals...) defer stop() ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/cli/templatepull_test.go b/cli/templatepull_test.go index ec7beb6196..1b1d51b0cc 100644 --- a/cli/templatepull_test.go +++ b/cli/templatepull_test.go @@ -328,7 +328,7 @@ func TestTemplatePull_ToDir(t *testing.T) { require.NoError(t, inv.Run()) - // Validate behaviour of choosing template name in the absence of an output path argument. + // Validate behavior of choosing template name in the absence of an output path argument. destPath := actualDest if destPath == "" { destPath = template.Name diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 85d92a5ef6..4d315c3e2b 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -498,7 +498,7 @@ func (c *provisionerdCloser) Close() error { c.closed = true ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - shutdownErr := c.d.Shutdown(ctx) + shutdownErr := c.d.Shutdown(ctx, true) closeErr := c.d.Close() if shutdownErr != nil { return shutdownErr diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index 5943758a7d..6f356e541d 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -88,8 +88,10 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, agpl.InterruptSignals...) - defer notifyStop() + stopCtx, stopCancel := inv.SignalNotifyContext(ctx, agpl.StopSignalsNoInterrupt...) + defer stopCancel() + interruptCtx, interruptCancel := inv.SignalNotifyContext(ctx, agpl.InterruptSignals...) + defer interruptCancel() tags, err := agpl.ParseProvisionerTags(rawTags) if err != nil { @@ -212,10 +214,17 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { Metrics: metrics, }) + waitForProvisionerJobs := false var exitErr error select { - case <-notifyCtx.Done(): - exitErr = notifyCtx.Err() + case <-stopCtx.Done(): + exitErr = stopCtx.Err() + _, _ = fmt.Fprintln(inv.Stdout, cliui.Bold( + "Stop caught, waiting for provisioner jobs to complete and gracefully exiting. Use ctrl+\\ to force quit", + )) + waitForProvisionerJobs = true + case <-interruptCtx.Done(): + exitErr = interruptCtx.Err() _, _ = fmt.Fprintln(inv.Stdout, cliui.Bold( "Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit", )) @@ -225,7 +234,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { cliui.Errorf(inv.Stderr, "Unexpected error, shutting down server: %s\n", exitErr) } - err = srv.Shutdown(ctx) + err = srv.Shutdown(ctx, waitForProvisionerJobs) if err != nil { return xerrors.Errorf("shutdown: %w", err) } diff --git a/enterprise/cli/proxyserver.go b/enterprise/cli/proxyserver.go index a31d2fe829..68ec04b966 100644 --- a/enterprise/cli/proxyserver.go +++ b/enterprise/cli/proxyserver.go @@ -142,7 +142,7 @@ func (r *RootCmd) proxyServer() *clibase.Cmd { // // To get out of a graceful shutdown, the user can send // SIGQUIT with ctrl+\ or SIGKILL with `kill -9`. - notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, cli.InterruptSignals...) + notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, cli.StopSignals...) defer notifyStop() // Clean up idle connections at the end, e.g. diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index caa65c8850..c62a91593d 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -441,7 +441,7 @@ func TestProvisionerDaemonServe(t *testing.T) { build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) - err = pd.Shutdown(ctx) + err = pd.Shutdown(ctx, false) require.NoError(t, err) err = terraformServer.Close() require.NoError(t, err) diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 52414db4af..3e49648700 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -474,15 +474,18 @@ func (p *Server) isClosed() bool { } } -// Shutdown triggers a graceful exit of each registered provisioner. -func (p *Server) Shutdown(ctx context.Context) error { +// Shutdown gracefully exists with the option to cancel the active job. +// If false, it will wait for the job to complete. +// +//nolint:revive +func (p *Server) Shutdown(ctx context.Context, cancelActiveJob bool) error { p.mutex.Lock() p.opts.Logger.Info(ctx, "attempting graceful shutdown") if !p.shuttingDownB { close(p.shuttingDownCh) p.shuttingDownB = true } - if p.activeJob != nil { + if cancelActiveJob && p.activeJob != nil { p.activeJob.Cancel() } p.mutex.Unlock() diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index a04196e6b4..2031fa6c39 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -671,7 +671,7 @@ func TestProvisionerd(t *testing.T) { }), }) require.Condition(t, closedWithin(updateChan, testutil.WaitShort)) - err := server.Shutdown(context.Background()) + err := server.Shutdown(context.Background(), true) require.NoError(t, err) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, server.Close()) @@ -762,7 +762,7 @@ func TestProvisionerd(t *testing.T) { require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - require.NoError(t, server.Shutdown(ctx)) + require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) @@ -853,7 +853,7 @@ func TestProvisionerd(t *testing.T) { require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - require.NoError(t, server.Shutdown(ctx)) + require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) @@ -944,7 +944,7 @@ func TestProvisionerd(t *testing.T) { t.Log("completeChan closed") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - require.NoError(t, server.Shutdown(ctx)) + require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) }) @@ -1039,7 +1039,7 @@ func TestProvisionerd(t *testing.T) { require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - require.NoError(t, server.Shutdown(ctx)) + require.NoError(t, server.Shutdown(ctx, true)) require.NoError(t, server.Close()) assert.Equal(t, ops[len(ops)-1], "CompleteJob") assert.Contains(t, ops[0:len(ops)-1], "Log: Cleaning Up | ") @@ -1076,7 +1076,7 @@ func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, connector prov t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - _ = server.Shutdown(ctx) + _ = server.Shutdown(ctx, true) _ = server.Close() }) return server