diff --git a/agent/agent.go b/agent/agent.go index c0a61fa97f..e0256d2e22 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -150,13 +150,17 @@ func New(options Options) Agent { options.Syscaller = agentproc.NewSyscaller() } - ctx, cancelFunc := context.WithCancel(context.Background()) + hardCtx, hardCancel := context.WithCancel(context.Background()) + gracefulCtx, gracefulCancel := context.WithCancel(hardCtx) a := &agent{ tailnetListenPort: options.TailnetListenPort, reconnectingPTYTimeout: options.ReconnectingPTYTimeout, logger: options.Logger, - closeCancel: cancelFunc, - closed: make(chan struct{}), + gracefulCtx: gracefulCtx, + gracefulCancel: gracefulCancel, + hardCtx: hardCtx, + hardCancel: hardCancel, + coordDisconnected: make(chan struct{}), environmentVariables: options.EnvironmentVariables, client: options.Client, exchangeToken: options.ExchangeToken, @@ -181,9 +185,14 @@ func New(options Options) Agent { prometheusRegistry: prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), } + // Initially, we have a closed channel, reflecting the fact that we are not initially connected. + // Each time we connect we replace the channel (while holding the closeMutex) with a new one + // that gets closed on disconnection. This is used to wait for graceful disconnection from the + // coordinator during shut down. + close(a.coordDisconnected) a.serviceBanner.Store(new(codersdk.ServiceBannerConfig)) a.sessionToken.Store(new(string)) - a.init(ctx) + a.init() return a } @@ -206,10 +215,16 @@ type agent struct { reconnectingPTYs sync.Map reconnectingPTYTimeout time.Duration - connCloseWait sync.WaitGroup - closeCancel context.CancelFunc - closeMutex sync.Mutex - closed chan struct{} + // we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time + // to start gracefully shutting down and "hard" which is Done when it is time to close + // everything down (regardless of whether graceful shutdown completed). + gracefulCtx context.Context + gracefulCancel context.CancelFunc + hardCtx context.Context + hardCancel context.CancelFunc + closeWaitGroup sync.WaitGroup + closeMutex sync.Mutex + coordDisconnected chan struct{} environmentVariables map[string]string @@ -249,8 +264,9 @@ func (a *agent) TailnetConn() *tailnet.Conn { return a.network } -func (a *agent) init(ctx context.Context) { - sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{ +func (a *agent) init() { + // pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown. + sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{ MaxTimeout: a.sshMaxTimeout, MOTDFile: func() string { return a.manifest.Load().MOTDFile }, ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() }, @@ -272,22 +288,24 @@ func (a *agent) init(ctx context.Context) { // Register runner metrics. If the prom registry is nil, the metrics // will not report anywhere. a.scriptRunner.RegisterMetrics(a.prometheusRegistry) - go a.runLoop(ctx) + go a.runLoop() } // runLoop attempts to start the agent in a retry loop. // Coder may be offline temporarily, a connection issue // may be happening, but regardless after the intermittent // failure, you'll want the agent to reconnect. 
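For illustration, a minimal, runnable sketch (separate from the patch) of the two-context scheme introduced above: the graceful context is derived from the hard one, so a hard cancel always implies a graceful cancel, while a graceful cancel leaves the hard context alive for shutdown-related work.

```go
package main

import (
	"context"
	"fmt"
)

func main() {
	// "hard" context: Done when everything must stop, graceful or not.
	hardCtx, hardCancel := context.WithCancel(context.Background())
	// "graceful" context: derived from hardCtx, Done when graceful shutdown begins.
	gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)

	gracefulCancel()
	fmt.Println(gracefulCtx.Err(), hardCtx.Err()) // context canceled <nil>

	hardCancel()
	fmt.Println(gracefulCtx.Err(), hardCtx.Err()) // context canceled context canceled
}
```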
-func (a *agent) runLoop(ctx context.Context) { - go a.reportLifecycleLoop(ctx) - go a.reportMetadataLoop(ctx) - go a.manageProcessPriorityLoop(ctx) +func (a *agent) runLoop() { + go a.reportLifecycleUntilClose() + go a.reportMetadataUntilGracefulShutdown() + go a.manageProcessPriorityUntilGracefulShutdown() + // need to keep retrying up to the hardCtx so that we can send graceful shutdown-related + // messages. + ctx := a.hardCtx for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { a.logger.Info(ctx, "connecting to coderd") - err := a.run(ctx) - // Cancel after the run is complete to clean up any leaked resources! + err := a.run() if err == nil { continue } @@ -386,7 +404,9 @@ func (t *trySingleflight) Do(key string, fn func()) { fn() } -func (a *agent) reportMetadataLoop(ctx context.Context) { +func (a *agent) reportMetadataUntilGracefulShutdown() { + // metadata reporting can cease as soon as we start gracefully shutting down. + ctx := a.gracefulCtx tickerDone := make(chan struct{}) collectDone := make(chan struct{}) ctx, cancel := context.WithCancel(ctx) @@ -595,9 +615,12 @@ func (a *agent) reportMetadataLoop(ctx context.Context) { } } -// reportLifecycleLoop reports the current lifecycle state once. All state +// reportLifecycleUntilClose reports the current lifecycle state once. All state // changes are reported in order. -func (a *agent) reportLifecycleLoop(ctx context.Context) { +func (a *agent) reportLifecycleUntilClose() { + // part of graceful shut down is reporting the final lifecycle states, e.g "ShuttingDown" so the + // lifecycle reporting has to be via the "hard" context. + ctx := a.hardCtx lastReportedIndex := 0 // Start off with the created state without reporting it. for { select { @@ -623,6 +646,8 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) { err := a.client.PostLifecycle(ctx, report) if err == nil { + a.logger.Debug(ctx, "successfully reported lifecycle state", slog.F("payload", report)) + r.Reset() // don't back off when we are successful lastReportedIndex++ select { case a.lifecycleReported <- report.State: @@ -638,6 +663,7 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) { break } if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + a.logger.Debug(ctx, "canceled reporting lifecycle state", slog.F("payload", report)) return } // If we fail to report the state we probably shouldn't exit, log only. @@ -648,7 +674,7 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) { // setLifecycle sets the lifecycle state and notifies the lifecycle loop. // The state is only updated if it's a valid state transition. 
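The lifecycle reporter above retries with backoff but resets the backoff after each successful report. A self-contained sketch of that pattern, using the same retry API the patch already relies on (retry.New, Wait, Reset); the post closure is a stand-in for client.PostLifecycle:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/coder/retry"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	attempts := 0
	post := func() error { // stand-in for client.PostLifecycle
		attempts++
		if attempts < 3 {
			return errors.New("coderd unreachable")
		}
		return nil
	}

	for r := retry.New(50*time.Millisecond, time.Second); r.Wait(ctx); {
		if err := post(); err != nil {
			fmt.Println("report failed, backing off:", err)
			continue
		}
		r.Reset() // don't carry backoff from an old failure into the next report
		fmt.Println("reported lifecycle state")
		break
	}
}
```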
-func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentLifecycle) { +func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) { report := agentsdk.PostLifecycleRequest{ State: state, ChangedAt: dbtime.Now(), @@ -657,12 +683,12 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL a.lifecycleMu.Lock() lastReport := a.lifecycleStates[len(a.lifecycleStates)-1] if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastReport.State) >= slices.Index(codersdk.WorkspaceAgentLifecycleOrder, report.State) { - a.logger.Warn(ctx, "attempted to set lifecycle state to a previous state", slog.F("last", lastReport), slog.F("current", report)) + a.logger.Warn(context.Background(), "attempted to set lifecycle state to a previous state", slog.F("last", lastReport), slog.F("current", report)) a.lifecycleMu.Unlock() return } a.lifecycleStates = append(a.lifecycleStates, report) - a.logger.Debug(ctx, "set lifecycle state", slog.F("current", report), slog.F("last", lastReport)) + a.logger.Debug(context.Background(), "set lifecycle state", slog.F("current", report), slog.F("last", lastReport)) a.lifecycleMu.Unlock() select { @@ -674,7 +700,8 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL // fetchServiceBannerLoop fetches the service banner on an interval. It will // not be fetched immediately; the expectation is that it is primed elsewhere // (and must be done before the session actually starts). -func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgentClient) error { +func (a *agent) fetchServiceBannerLoop(ctx context.Context, conn drpc.Conn) error { + aAPI := proto.NewDRPCAgentClient(conn) ticker := time.NewTicker(a.serviceBannerRefreshInterval) defer ticker.Stop() for { @@ -696,205 +723,272 @@ func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgent } } -func (a *agent) run(ctx context.Context) error { +func (a *agent) run() (retErr error) { // This allows the agent to refresh it's token if necessary. // For instance identity this is required, since the instance // may not have re-provisioned, but a new agent ID was created. - sessionToken, err := a.exchangeToken(ctx) + sessionToken, err := a.exchangeToken(a.hardCtx) if err != nil { return xerrors.Errorf("exchange token: %w", err) } a.sessionToken.Store(&sessionToken) // ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs - conn, err := a.client.ConnectRPC(ctx) + conn, err := a.client.ConnectRPC(a.hardCtx) if err != nil { return err } defer func() { cErr := conn.Close() if cErr != nil { - a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err)) + a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(err)) } }() - aAPI := proto.NewDRPCAgentClient(conn) - sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{}) - if err != nil { - return xerrors.Errorf("fetch service banner: %w", err) - } - serviceBanner := agentsdk.ServiceBannerFromProto(sbp) - a.serviceBanner.Store(&serviceBanner) + // A lot of routines need the agent API / tailnet API connection. We run them in their own + // goroutines in parallel, but errors in any routine will cause them all to exit so we can + // redial the coder server and retry. 
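The comment above describes the core of the new run(): each per-connection routine runs in an errgroup, so the first failure cancels the shared context, every routine exits, and the outer retry loop can redial. A generic sketch of that behavior (names here are illustrative, not the patch's API):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

func main() {
	eg, ctx := errgroup.WithContext(context.Background())

	eg.Go(func() error {
		<-ctx.Done() // e.g. a long-running subscriber; exits once the group is canceled
		return nil
	})
	eg.Go(func() error {
		time.Sleep(10 * time.Millisecond)
		return fmt.Errorf("connection lost") // first failure cancels ctx for everyone
	})

	// Wait returns the first error; the caller would now redial coderd and retry.
	fmt.Println(eg.Wait()) // connection lost
}
```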
+ connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger, conn) - mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{}) - if err != nil { - return xerrors.Errorf("fetch metadata: %w", err) - } - a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp)) - manifest, err := agentsdk.ManifestFromProto(mp) - if err != nil { - a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err)) - return xerrors.Errorf("convert manifest: %w", err) - } - if manifest.AgentID == uuid.Nil { - return xerrors.New("nil agentID returned by manifest") - } - a.client.RewriteDERPMap(manifest.DERPMap) - - // Expand the directory and send it back to coderd so external - // applications that rely on the directory can use it. - // - // An example is VS Code Remote, which must know the directory - // before initializing a connection. - manifest.Directory, err = expandDirectory(manifest.Directory) - if err != nil { - return xerrors.Errorf("expand directory: %w", err) - } - subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems) - if err != nil { - a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err)) - return xerrors.Errorf("failed to convert subsystems: %w", err) - } - _, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{ - Version: buildinfo.Version(), - ExpandedDirectory: manifest.Directory, - Subsystems: subsys, - }}) - if err != nil { - return xerrors.Errorf("update workspace agent startup: %w", err) - } - - oldManifest := a.manifest.Swap(&manifest) - - // The startup script should only execute on the first run! - if oldManifest == nil { - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStarting) - - // Perform overrides early so that Git auth can work even if users - // connect to a workspace that is not yet ready. We don't run this - // concurrently with the startup script to avoid conflicts between - // them. - if manifest.GitAuthConfigs > 0 { - // If this fails, we should consider surfacing the error in the - // startup log and setting the lifecycle state to be "start_error" - // (after startup script completion), but for now we'll just log it. 
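The manifestOK and networkOK channels in the diagram above are simple gates: the producer closes the channel once its prerequisite is ready, and each consumer selects on either the gate or its context. A minimal sketch of the pattern (waitForGate is an invented helper, not part of the patch):

```go
package main

import (
	"context"
	"fmt"
)

func waitForGate(ctx context.Context, gate <-chan struct{}) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-gate:
		return nil
	}
}

func main() {
	ctx := context.Background()
	manifestOK := make(chan struct{})

	go func() {
		// ... fetch and store the manifest ...
		close(manifestOK) // wakes every waiter at once
	}()

	if err := waitForGate(ctx, manifestOK); err == nil {
		fmt.Println("manifest ready, safe to start dependent routines")
	}
}
```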
- err := gitauth.OverrideVSCodeConfigs(a.filesystem) + connMan.start("init service banner", gracefulShutdownBehaviorStop, + func(ctx context.Context, conn drpc.Conn) error { + aAPI := proto.NewDRPCAgentClient(conn) + sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{}) if err != nil { - a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(err)) + return xerrors.Errorf("fetch service banner: %w", err) } - } + serviceBanner := agentsdk.ServiceBannerFromProto(sbp) + a.serviceBanner.Store(&serviceBanner) + return nil + }, + ) - err = a.scriptRunner.Init(manifest.Scripts) - if err != nil { - return xerrors.Errorf("init script runner: %w", err) - } - err = a.trackConnGoroutine(func() { - start := time.Now() - err := a.scriptRunner.Execute(ctx, func(script codersdk.WorkspaceAgentScript) bool { - return script.RunOnStart - }) - // Measure the time immediately after the script has finished - dur := time.Since(start).Seconds() - if err != nil { - a.logger.Warn(ctx, "startup script(s) failed", slog.Error(err)) - if errors.Is(err, agentscripts.ErrTimeout) { - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartTimeout) - } else { - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartError) - } - } else { - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleReady) - } + // channels to sync goroutines below + // handle manifest + // | + // manifestOK + // | | + // | +----------------------+ + // V | + // app health reporter | + // V + // create or update network + // | + // networkOK + // | + // coordination <--------------------------+ + // derp map subscriber <----------------+ + // stats report loop <---------------+ + networkOK := make(chan struct{}) + manifestOK := make(chan struct{}) - label := "false" - if err == nil { - label = "true" + connMan.start("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK)) + + connMan.start("app health reporter", gracefulShutdownBehaviorStop, + func(ctx context.Context, conn drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-manifestOK: + manifest := a.manifest.Load() + NewWorkspaceAppHealthReporter( + a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)), + )(ctx) + return nil } - a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur) - a.scriptRunner.StartCron() }) - if err != nil { - return xerrors.Errorf("track conn goroutine: %w", err) + + connMan.start("create or update network", gracefulShutdownBehaviorStop, + a.createOrUpdateNetwork(manifestOK, networkOK)) + + connMan.start("coordination", gracefulShutdownBehaviorStop, + func(ctx context.Context, conn drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-networkOK: + } + return a.runCoordinator(ctx, conn, a.network) + }, + ) + + connMan.start("derp map subscriber", gracefulShutdownBehaviorStop, + func(ctx context.Context, conn drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-networkOK: + } + return a.runDERPMapSubscriber(ctx, conn, a.network) + }) + + connMan.start("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop) + + connMan.start("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-networkOK: } + return a.statsReporter.reportLoop(ctx, proto.NewDRPCAgentClient(conn)) + }) + + return connMan.wait() +} + +// handleManifest returns a function that fetches and processes the manifest +func (a *agent) 
handleManifest(manifestOK chan<- struct{}) func(ctx context.Context, conn drpc.Conn) error { + return func(ctx context.Context, conn drpc.Conn) error { + aAPI := proto.NewDRPCAgentClient(conn) + mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{}) + if err != nil { + return xerrors.Errorf("fetch metadata: %w", err) + } + a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp)) + manifest, err := agentsdk.ManifestFromProto(mp) + if err != nil { + a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err)) + return xerrors.Errorf("convert manifest: %w", err) + } + if manifest.AgentID == uuid.Nil { + return xerrors.New("nil agentID returned by manifest") + } + a.client.RewriteDERPMap(manifest.DERPMap) + + // Expand the directory and send it back to coderd so external + // applications that rely on the directory can use it. + // + // An example is VS Code Remote, which must know the directory + // before initializing a connection. + manifest.Directory, err = expandDirectory(manifest.Directory) + if err != nil { + return xerrors.Errorf("expand directory: %w", err) + } + subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems) + if err != nil { + a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err)) + return xerrors.Errorf("failed to convert subsystems: %w", err) + } + _, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{ + Version: buildinfo.Version(), + ExpandedDirectory: manifest.Directory, + Subsystems: subsys, + }}) + if err != nil { + if xerrors.Is(err, context.Canceled) { + return nil + } + return xerrors.Errorf("update workspace agent startup: %w", err) + } + + oldManifest := a.manifest.Swap(&manifest) + close(manifestOK) + + // The startup script should only execute on the first run! + if oldManifest == nil { + a.setLifecycle(codersdk.WorkspaceAgentLifecycleStarting) + + // Perform overrides early so that Git auth can work even if users + // connect to a workspace that is not yet ready. We don't run this + // concurrently with the startup script to avoid conflicts between + // them. + if manifest.GitAuthConfigs > 0 { + // If this fails, we should consider surfacing the error in the + // startup log and setting the lifecycle state to be "start_error" + // (after startup script completion), but for now we'll just log it. + err := gitauth.OverrideVSCodeConfigs(a.filesystem) + if err != nil { + a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(err)) + } + } + + err = a.scriptRunner.Init(manifest.Scripts) + if err != nil { + return xerrors.Errorf("init script runner: %w", err) + } + err = a.trackGoroutine(func() { + start := time.Now() + // here we use the graceful context because the script runner is not directly tied + // to the agent API. 
+ err := a.scriptRunner.Execute(a.gracefulCtx, func(script codersdk.WorkspaceAgentScript) bool { + return script.RunOnStart + }) + // Measure the time immediately after the script has finished + dur := time.Since(start).Seconds() + if err != nil { + a.logger.Warn(ctx, "startup script(s) failed", slog.Error(err)) + if errors.Is(err, agentscripts.ErrTimeout) { + a.setLifecycle(codersdk.WorkspaceAgentLifecycleStartTimeout) + } else { + a.setLifecycle(codersdk.WorkspaceAgentLifecycleStartError) + } + } else { + a.setLifecycle(codersdk.WorkspaceAgentLifecycleReady) + } + + label := "false" + if err == nil { + label = "true" + } + a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur) + a.scriptRunner.StartCron() + }) + if err != nil { + return xerrors.Errorf("track conn goroutine: %w", err) + } + } + return nil } +} - // This automatically closes when the context ends! - appReporterCtx, appReporterCtxCancel := context.WithCancel(ctx) - defer appReporterCtxCancel() - go NewWorkspaceAppHealthReporter( - a.logger, manifest.Apps, agentsdk.AppHealthPoster(aAPI))(appReporterCtx) - - a.closeMutex.Lock() - network := a.network - a.closeMutex.Unlock() - if network == nil { - network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections) - if err != nil { - return xerrors.Errorf("create tailnet: %w", err) +// createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates +// the tailnet using the information in the manifest +func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan<- struct{}) func(context.Context, drpc.Conn) error { + return func(ctx context.Context, _ drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-manifestOK: } + var err error + manifest := a.manifest.Load() a.closeMutex.Lock() - // Re-check if agent was closed while initializing the network. - closed := a.isClosed() - if !closed { - a.network = network - a.statsReporter = newStatsReporter(a.logger, network, a) - } + network := a.network a.closeMutex.Unlock() - if closed { - _ = network.Close() - return xerrors.New("agent is closed") + if network == nil { + // use the graceful context here, because creating the tailnet is not itself tied to the + // agent API. + network, err = a.createTailnet(a.gracefulCtx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections) + if err != nil { + return xerrors.Errorf("create tailnet: %w", err) + } + a.closeMutex.Lock() + // Re-check if agent was closed while initializing the network. + closed := a.isClosed() + if !closed { + a.network = network + a.statsReporter = newStatsReporter(a.logger, network, a) + } + a.closeMutex.Unlock() + if closed { + _ = network.Close() + return xerrors.New("agent is closed") + } + } else { + // Update the wireguard IPs if the agent ID changed. + err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) + if err != nil { + a.logger.Error(a.gracefulCtx, "update tailnet addresses", slog.Error(err)) + } + // Update the DERP map, force WebSocket setting and allow/disallow + // direct connections. + network.SetDERPMap(manifest.DERPMap) + network.SetDERPForceWebSockets(manifest.DERPForceWebSockets) + network.SetBlockEndpoints(manifest.DisableDirectConnections) } - } else { - // Update the wireguard IPs if the agent ID changed. 
- err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) - if err != nil { - a.logger.Error(ctx, "update tailnet addresses", slog.Error(err)) - } - // Update the DERP map, force WebSocket setting and allow/disallow - // direct connections. - network.SetDERPMap(manifest.DERPMap) - network.SetDERPForceWebSockets(manifest.DERPForceWebSockets) - network.SetBlockEndpoints(manifest.DisableDirectConnections) + close(networkOK) + return nil } - - eg, egCtx := errgroup.WithContext(ctx) - eg.Go(func() error { - a.logger.Debug(egCtx, "running tailnet connection coordinator") - err := a.runCoordinator(egCtx, conn, network) - if err != nil { - return xerrors.Errorf("run coordinator: %w", err) - } - return nil - }) - - eg.Go(func() error { - a.logger.Debug(egCtx, "running derp map subscriber") - err := a.runDERPMapSubscriber(egCtx, conn, network) - if err != nil { - return xerrors.Errorf("run derp map subscriber: %w", err) - } - return nil - }) - - eg.Go(func() error { - a.logger.Debug(egCtx, "running fetch server banner loop") - err := a.fetchServiceBannerLoop(egCtx, aAPI) - if err != nil { - return xerrors.Errorf("fetch server banner loop: %w", err) - } - return nil - }) - - eg.Go(func() error { - a.logger.Debug(egCtx, "running stats report loop") - err := a.statsReporter.reportLoop(egCtx, aAPI) - if err != nil { - return xerrors.Errorf("report stats loop: %w", err) - } - return nil - }) - - return eg.Wait() } // updateCommandEnv updates the provided command environment with the @@ -995,15 +1089,15 @@ func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix { return a.addresses } -func (a *agent) trackConnGoroutine(fn func()) error { +func (a *agent) trackGoroutine(fn func()) error { a.closeMutex.Lock() defer a.closeMutex.Unlock() if a.isClosed() { return xerrors.New("track conn goroutine: agent is closed") } - a.connCloseWait.Add(1) + a.closeWaitGroup.Add(1) go func() { - defer a.connCloseWait.Done() + defer a.closeWaitGroup.Done() fn() }() return nil @@ -1037,7 +1131,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t _ = sshListener.Close() } }() - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { _ = a.sshServer.Serve(sshListener) }); err != nil { return nil, err @@ -1052,7 +1146,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t _ = reconnectingPTYListener.Close() } }() - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { logger := a.logger.Named("reconnecting-pty") var wg sync.WaitGroup for { @@ -1072,7 +1166,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t go func() { select { case <-closed: - case <-a.closed: + case <-a.hardCtx.Done(): _ = conn.Close() } wg.Done() @@ -1115,7 +1209,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t _ = speedtestListener.Close() } }() - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { var wg sync.WaitGroup for { conn, err := speedtestListener.Accept() @@ -1134,7 +1228,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t go func() { select { case <-closed: - case <-a.closed: + case <-a.hardCtx.Done(): _ = conn.Close() } wg.Done() @@ -1163,7 +1257,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t _ = apiListener.Close() } }() - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { defer apiListener.Close() server := &http.Server{ 
Handler: a.apiHandler(), @@ -1175,7 +1269,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t go func() { select { case <-ctx.Done(): - case <-a.closed: + case <-a.hardCtx.Done(): } _ = server.Close() }() @@ -1196,7 +1290,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error { defer a.logger.Debug(ctx, "disconnected from coordination RPC") tClient := tailnetproto.NewDRPCTailnetClient(conn) - coordinate, err := tClient.Coordinate(ctx) + // we run the RPC on the hardCtx so that we have a chance to send the disconnect message if we + // gracefully shut down. + coordinate, err := tClient.Coordinate(a.hardCtx) if err != nil { return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err) } @@ -1207,13 +1303,34 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai } }() a.logger.Info(ctx, "connected to coordination RPC") - coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil) - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-coordination.Error(): - return err + + // This allows the Close() routine to wait for the coordinator to gracefully disconnect. + a.closeMutex.Lock() + if a.isClosed() { + return nil } + disconnected := make(chan struct{}) + a.coordDisconnected = disconnected + defer close(disconnected) + a.closeMutex.Unlock() + + coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil) + + errCh := make(chan error, 1) + go func() { + defer close(errCh) + select { + case <-ctx.Done(): + err := coordination.Close() + if err != nil { + a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err)) + } + return + case err := <-coordination.Error(): + errCh <- err + } + }() + return <-errCh } // runDERPMapSubscriber runs a coordinator and returns if a reconnect should occur. @@ -1311,7 +1428,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m Metrics: a.metrics.reconnectingPTYErrors, }, logger.With(slog.F("message_id", msg.ID))) - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { rpty.Wait() a.reconnectingPTYs.Delete(msg.ID) }); err != nil { @@ -1406,7 +1523,9 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect var prioritizedProcs = []string{"coder agent"} -func (a *agent) manageProcessPriorityLoop(ctx context.Context) { +func (a *agent) manageProcessPriorityUntilGracefulShutdown() { + // process priority can stop as soon as we are gracefully shutting down + ctx := a.gracefulCtx defer func() { if r := recover(); r != nil { a.logger.Critical(ctx, "recovered from panic", @@ -1515,12 +1634,7 @@ func (a *agent) manageProcessPriority(ctx context.Context) ([]*agentproc.Process // isClosed returns whether the API is closed or not. 
func (a *agent) isClosed() bool { - select { - case <-a.closed: - return true - default: - return false - } + return a.hardCtx.Err() != nil } func (a *agent) HTTPDebug() http.Handler { @@ -1584,59 +1698,82 @@ func (a *agent) Close() error { return nil } - ctx := context.Background() - a.logger.Info(ctx, "shutting down agent") - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown) + a.logger.Info(a.hardCtx, "shutting down agent") + a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown) // Attempt to gracefully shut down all active SSH connections and // stop accepting new ones. - err := a.sshServer.Shutdown(ctx) + err := a.sshServer.Shutdown(a.hardCtx) if err != nil { - a.logger.Error(ctx, "ssh server shutdown", slog.Error(err)) + a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err)) } + err = a.sshServer.Close() + if err != nil { + a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err)) + } + // wait for SSH to shut down before the general graceful cancel, because + // this triggers a disconnect in the tailnet layer, telling all clients to + // shut down their wireguard tunnels to us. If SSH sessions are still up, + // they might hang instead of being closed. + a.gracefulCancel() lifecycleState := codersdk.WorkspaceAgentLifecycleOff - err = a.scriptRunner.Execute(ctx, func(script codersdk.WorkspaceAgentScript) bool { + err = a.scriptRunner.Execute(a.hardCtx, func(script codersdk.WorkspaceAgentScript) bool { return script.RunOnStop }) if err != nil { - a.logger.Warn(ctx, "shutdown script(s) failed", slog.Error(err)) + a.logger.Warn(a.hardCtx, "shutdown script(s) failed", slog.Error(err)) if errors.Is(err, agentscripts.ErrTimeout) { lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownTimeout } else { lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError } } - a.setLifecycle(ctx, lifecycleState) + a.setLifecycle(lifecycleState) err = a.scriptRunner.Close() if err != nil { - a.logger.Error(ctx, "script runner close", slog.Error(err)) + a.logger.Error(a.hardCtx, "script runner close", slog.Error(err)) } - // Wait for the lifecycle to be reported, but don't wait forever so + // Wait for the graceful shutdown to complete, but don't wait forever so // that we don't break user expectations. - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() + go func() { + defer a.hardCancel() + select { + case <-a.hardCtx.Done(): + case <-time.After(5 * time.Second): + } + }() + + // Wait for lifecycle to be reported lifecycleWaitLoop: for { select { - case <-ctx.Done(): + case <-a.hardCtx.Done(): + a.logger.Warn(context.Background(), "failed to report final lifecycle state") break lifecycleWaitLoop case s := <-a.lifecycleReported: if s == lifecycleState { + a.logger.Debug(context.Background(), "reported final lifecycle state") break lifecycleWaitLoop } } } - close(a.closed) - a.closeCancel() - _ = a.sshServer.Close() + // Wait for graceful disconnect from the Coordinator RPC + select { + case <-a.hardCtx.Done(): + a.logger.Warn(context.Background(), "timed out waiting for Coordinator RPC disconnect") + case <-a.coordDisconnected: + a.logger.Debug(context.Background(), "coordinator RPC disconnected") + } + + a.hardCancel() if a.network != nil { _ = a.network.Close() } - a.connCloseWait.Wait() + a.closeWaitGroup.Wait() return nil } @@ -1688,3 +1825,94 @@ func expandDirectory(dir string) (string, error) { // specialized environment in which the agent is running // (e.g. envbox, envbuilder). 
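Close() above bounds the whole graceful phase with a watchdog goroutine that hard-cancels after five seconds, so the waits on the final lifecycle report and on coordinator disconnect can never block forever. Sketched in isolation (lifecycleReported here stands in for the agent's channel):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	hardCtx, hardCancel := context.WithCancel(context.Background())

	// Watchdog: whatever happens, hardCtx is canceled at most 5 seconds
	// after shutdown starts.
	go func() {
		defer hardCancel()
		select {
		case <-hardCtx.Done():
		case <-time.After(5 * time.Second):
		}
	}()

	lifecycleReported := make(chan string, 1)
	lifecycleReported <- "off" // in the agent this is fed by the lifecycle reporting loop

	select {
	case <-hardCtx.Done():
		fmt.Println("gave up waiting for the final lifecycle report")
	case s := <-lifecycleReported:
		fmt.Println("final lifecycle state reported:", s)
	}
	hardCancel()
}
```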
const EnvAgentSubsystem = "CODER_AGENT_SUBSYSTEM" + +// eitherContext returns a context that is canceled when either context ends. +func eitherContext(a, b context.Context) context.Context { + ctx, cancel := context.WithCancel(a) + go func() { + defer cancel() + select { + case <-a.Done(): + case <-b.Done(): + } + }() + return ctx +} + +type gracefulShutdownBehavior int + +const ( + gracefulShutdownBehaviorStop gracefulShutdownBehavior = iota + gracefulShutdownBehaviorRemain +) + +type apiConnRoutineManager struct { + logger slog.Logger + conn drpc.Conn + eg *errgroup.Group + stopCtx context.Context + remainCtx context.Context +} + +func newAPIConnRoutineManager(gracefulCtx, hardCtx context.Context, logger slog.Logger, conn drpc.Conn) *apiConnRoutineManager { + // routines that remain in operation during graceful shutdown use the remainCtx. They'll still + // exit if the errgroup hits an error, which usually means a problem with the conn. + eg, remainCtx := errgroup.WithContext(hardCtx) + + // routines that stop operation during graceful shutdown use the stopCtx, which ends when the + // first of remainCtx or gracefulContext ends (an error or start of graceful shutdown). + // + // +------------------------------------------+ + // | hardCtx | + // | +------------------------------------+ | + // | | stopCtx | | + // | | +--------------+ +--------------+ | | + // | | | remainCtx | | gracefulCtx | | | + // | | +--------------+ +--------------+ | | + // | +------------------------------------+ | + // +------------------------------------------+ + stopCtx := eitherContext(remainCtx, gracefulCtx) + return &apiConnRoutineManager{ + logger: logger, + conn: conn, + eg: eg, + stopCtx: stopCtx, + remainCtx: remainCtx, + } +} + +func (a *apiConnRoutineManager) start(name string, b gracefulShutdownBehavior, f func(context.Context, drpc.Conn) error) { + logger := a.logger.With(slog.F("name", name)) + var ctx context.Context + switch b { + case gracefulShutdownBehaviorStop: + ctx = a.stopCtx + case gracefulShutdownBehaviorRemain: + ctx = a.remainCtx + default: + panic("unknown behavior") + } + a.eg.Go(func() error { + logger.Debug(ctx, "starting routine") + err := f(ctx, a.conn) + if xerrors.Is(err, context.Canceled) && ctx.Err() != nil { + logger.Debug(ctx, "swallowing context canceled") + // Don't propagate context canceled errors to the error group, because we don't want the + // graceful context being canceled to halt the work of routines with + // gracefulShutdownBehaviorRemain. Note that we check both that the error is + // context.Canceled and that *our* context is currently canceled, because when Coderd + // unilaterally closes the API connection (for example if the build is outdated), it can + // sometimes show up as context.Canceled in our RPC calls. 
+ return nil + } + logger.Debug(ctx, "routine exited", slog.Error(err)) + if err != nil { + return xerrors.Errorf("error in routine %s: %w", name, err) + } + return nil + }) +} + +func (a *apiConnRoutineManager) wait() error { + return a.eg.Wait() +} diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 81019788c7..cc88bc52d7 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -162,7 +162,13 @@ func TestSSH(t *testing.T) { coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspaceBuild.ID) // Update template version - version = coderdtest.UpdateTemplateVersion(t, ownerClient, owner.OrganizationID, echoResponses, template.ID) + authToken2 := uuid.NewString() + echoResponses2 := &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.PlanComplete, + ProvisionApply: echo.ProvisionApplyWithAgent(authToken2), + } + version = coderdtest.UpdateTemplateVersion(t, ownerClient, owner.OrganizationID, echoResponses2, template.ID) coderdtest.AwaitTemplateVersionJobCompleted(t, ownerClient, version.ID) err := ownerClient.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{ ID: version.ID, @@ -184,7 +190,7 @@ func TestSSH(t *testing.T) { // When the agent connects, the workspace was started, and we should // have access to the shell. - _ = agenttest.New(t, client.URL, authToken) + _ = agenttest.New(t, client.URL, authToken2) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index b1c496e4ba..e75c32f9b0 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -193,7 +193,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can options = &Options{} } if options.Logger == nil { - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("coderd") options.Logger = &logger } if options.GoogleTokenValidator == nil { diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 20d1e221e2..7a8b2e5d3b 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -534,7 +534,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) { tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect") crdErr := coordination.Close() if crdErr != nil { - tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(err)) + tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err)) } case err = <-coordination.Error(): if err != nil && diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index f69fbff8d4..1564b6587e 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -231,13 +231,14 @@ func TestAuditLogging(t *testing.T) { }, DontAddLicense: true, }) - workspace, agent := setupWorkspaceAgent(t, client, user, 0) - conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil) //nolint:gocritic // RBAC is not the purpose of this test + r := setupWorkspaceAgent(t, client, user, 0) + conn, err := client.DialWorkspaceAgent(ctx, r.sdkAgent.ID, nil) //nolint:gocritic // RBAC is not the purpose of this test require.NoError(t, err) defer conn.Close() connected := conn.AwaitReachable(ctx) require.True(t, connected) - build := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop) + _ = 
r.agent.Close() // close first so we don't drop error logs from outdated build + build := coderdtest.CreateWorkspaceBuild(t, client, r.workspace, database.WorkspaceTransitionStop) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID) }) } diff --git a/enterprise/coderd/replicas_test.go b/enterprise/coderd/replicas_test.go index 1081ec81e3..6d348db782 100644 --- a/enterprise/coderd/replicas_test.go +++ b/enterprise/coderd/replicas_test.go @@ -81,8 +81,8 @@ func TestReplicas(t *testing.T) { require.NoError(t, err) require.Len(t, replicas, 2) - _, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0) - conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{ + r := setupWorkspaceAgent(t, firstClient, firstUser, 0) + conn, err := secondClient.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, &codersdk.DialWorkspaceAgentOptions{ BlockEndpoints: true, Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), }) @@ -127,8 +127,8 @@ func TestReplicas(t *testing.T) { require.NoError(t, err) require.Len(t, replicas, 2) - _, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0) - conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{ + r := setupWorkspaceAgent(t, firstClient, firstUser, 0) + conn, err := secondClient.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, &codersdk.DialWorkspaceAgentOptions{ BlockEndpoints: true, Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), }) diff --git a/enterprise/coderd/workspaceagents_test.go b/enterprise/coderd/workspaceagents_test.go index 7745eb7289..a6cf84a594 100644 --- a/enterprise/coderd/workspaceagents_test.go +++ b/enterprise/coderd/workspaceagents_test.go @@ -44,9 +44,9 @@ func TestBlockNonBrowser(t *testing.T) { }, }, }) - _, agent := setupWorkspaceAgent(t, client, user, 0) + r := setupWorkspaceAgent(t, client, user, 0) //nolint:gocritic // Testing that even the owner gets blocked. - _, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil) + _, err := client.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, nil) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusConflict, apiErr.StatusCode()) @@ -63,15 +63,21 @@ func TestBlockNonBrowser(t *testing.T) { }, }, }) - _, agent := setupWorkspaceAgent(t, client, user, 0) + r := setupWorkspaceAgent(t, client, user, 0) //nolint:gocritic // Testing RBAC is not the point of this test. 
- conn, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil) + conn, err := client.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, nil) require.NoError(t, err) _ = conn.Close() }) } -func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.CreateFirstUserResponse, appPort uint16) (codersdk.Workspace, codersdk.WorkspaceAgent) { +type setupResp struct { + workspace codersdk.Workspace + sdkAgent codersdk.WorkspaceAgent + agent agent.Agent +} + +func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.CreateFirstUserResponse, appPort uint16) setupResp { authToken := uuid.NewString() version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ Parse: echo.ParseComplete, @@ -127,20 +133,20 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr }, } agentClient.SetSessionToken(authToken) - agentCloser := agent.New(agent.Options{ + agnt := agent.New(agent.Options{ Client: agentClient, Logger: slogtest.Make(t, nil).Named("agent"), }) t.Cleanup(func() { - _ = agentCloser.Close() + _ = agnt.Close() }) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - agnt, err := client.WorkspaceAgent(ctx, resources[0].Agents[0].ID) + sdkAgent, err := client.WorkspaceAgent(ctx, resources[0].Agents[0].ID) require.NoError(t, err) - return workspace, agnt + return setupResp{workspace, sdkAgent, agnt} } diff --git a/enterprise/coderd/workspaceportshare_test.go b/enterprise/coderd/workspaceportshare_test.go index 04d2d83967..1a8543db68 100644 --- a/enterprise/coderd/workspaceportshare_test.go +++ b/enterprise/coderd/workspaceportshare_test.go @@ -31,7 +31,7 @@ func TestWorkspacePortShare(t *testing.T) { }, }) client, user := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin()) - workspace, agent := setupWorkspaceAgent(t, client, codersdk.CreateFirstUserResponse{ + r := setupWorkspaceAgent(t, client, codersdk.CreateFirstUserResponse{ UserID: user.ID, OrganizationID: owner.OrganizationID, }, 0) @@ -39,8 +39,8 @@ func TestWorkspacePortShare(t *testing.T) { defer cancel() // try to update port share with template max port share level owner - _, err := client.UpsertWorkspaceAgentPortShare(ctx, workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{ - AgentName: agent.Name, + _, err := client.UpsertWorkspaceAgentPortShare(ctx, r.workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{ + AgentName: r.sdkAgent.Name, Port: 8080, ShareLevel: codersdk.WorkspaceAgentPortShareLevelPublic, }) @@ -48,13 +48,13 @@ func TestWorkspacePortShare(t *testing.T) { // update the template max port share level to public var level codersdk.WorkspaceAgentPortShareLevel = codersdk.WorkspaceAgentPortShareLevelPublic - client.UpdateTemplateMeta(ctx, workspace.TemplateID, codersdk.UpdateTemplateMeta{ + client.UpdateTemplateMeta(ctx, r.workspace.TemplateID, codersdk.UpdateTemplateMeta{ MaxPortShareLevel: &level, }) // OK - ps, err := client.UpsertWorkspaceAgentPortShare(ctx, workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{ - AgentName: agent.Name, + ps, err := client.UpsertWorkspaceAgentPortShare(ctx, r.workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{ + AgentName: r.sdkAgent.Name, Port: 8080, ShareLevel: codersdk.WorkspaceAgentPortShareLevelPublic, }) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 530b42aea3..842a6bcbfa 100644 
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -131,7 +131,8 @@ func (c *remoteCoordination) Close() (retErr error) {
 		}
 	}()
 	err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
-	if err != nil {
+	if err != nil && !xerrors.Is(err, io.EOF) {
+		// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
 		return xerrors.Errorf("send disconnect: %w", err)
 	}
 	c.logger.Debug(context.Background(), "sent disconnect")
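A rough, runnable model of the eitherContext helper added in agent/agent.go above: the combined context ends as soon as either input ends, which is how the routine manager builds stopCtx from remainCtx and gracefulCtx.

```go
package main

import (
	"context"
	"fmt"
)

// eitherContext mirrors the helper in the patch: cancel the derived context
// as soon as either input is done.
func eitherContext(a, b context.Context) context.Context {
	ctx, cancel := context.WithCancel(a)
	go func() {
		defer cancel()
		select {
		case <-a.Done():
		case <-b.Done():
		}
	}()
	return ctx
}

func main() {
	a, cancelA := context.WithCancel(context.Background())
	defer cancelA()
	b, cancelB := context.WithCancel(context.Background())

	ctx := eitherContext(a, b)
	cancelB() // canceling either parent ends the combined context
	<-ctx.Done()
	fmt.Println(ctx.Err()) // context canceled
}
```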