chore: refactor agent routines that use the v2 API (#12223)

In anticipation of needing the `LogSender` to run on a context that doesn't get immediately canceled when you `Close()` the agent, I've undertaken a little refactor to manage the goroutines that get run against the Tailnet and Agent API connection.

This handles controlling two contexts, one that gets canceled right away at the start of graceful shutdown, and another that stays up to allow graceful shutdown to complete.
This commit is contained in:
Spike Curtis 2024-02-23 11:04:23 +04:00 committed by GitHub
parent 66585f042f
commit af3fdc68c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 501 additions and 259 deletions

View File

@ -150,13 +150,17 @@ func New(options Options) Agent {
options.Syscaller = agentproc.NewSyscaller()
}
ctx, cancelFunc := context.WithCancel(context.Background())
hardCtx, hardCancel := context.WithCancel(context.Background())
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
a := &agent{
tailnetListenPort: options.TailnetListenPort,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
gracefulCtx: gracefulCtx,
gracefulCancel: gracefulCancel,
hardCtx: hardCtx,
hardCancel: hardCancel,
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
exchangeToken: options.ExchangeToken,
@ -181,9 +185,14 @@ func New(options Options) Agent {
prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
// that gets closed on disconnection. This is used to wait for graceful disconnection from the
// coordinator during shut down.
close(a.coordDisconnected)
a.serviceBanner.Store(new(codersdk.ServiceBannerConfig))
a.sessionToken.Store(new(string))
a.init(ctx)
a.init()
return a
}
@ -206,10 +215,16 @@ type agent struct {
reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration
connCloseWait sync.WaitGroup
closeCancel context.CancelFunc
closeMutex sync.Mutex
closed chan struct{}
// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
// to start gracefully shutting down and "hard" which is Done when it is time to close
// everything down (regardless of whether graceful shutdown completed).
gracefulCtx context.Context
gracefulCancel context.CancelFunc
hardCtx context.Context
hardCancel context.CancelFunc
closeWaitGroup sync.WaitGroup
closeMutex sync.Mutex
coordDisconnected chan struct{}
environmentVariables map[string]string
@ -249,8 +264,9 @@ func (a *agent) TailnetConn() *tailnet.Conn {
return a.network
}
func (a *agent) init(ctx context.Context) {
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
func (a *agent) init() {
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() },
@ -272,22 +288,24 @@ func (a *agent) init(ctx context.Context) {
// Register runner metrics. If the prom registry is nil, the metrics
// will not report anywhere.
a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
go a.runLoop(ctx)
go a.runLoop()
}
// runLoop attempts to start the agent in a retry loop.
// Coder may be offline temporarily, a connection issue
// may be happening, but regardless after the intermittent
// failure, you'll want the agent to reconnect.
func (a *agent) runLoop(ctx context.Context) {
go a.reportLifecycleLoop(ctx)
go a.reportMetadataLoop(ctx)
go a.manageProcessPriorityLoop(ctx)
func (a *agent) runLoop() {
go a.reportLifecycleUntilClose()
go a.reportMetadataUntilGracefulShutdown()
go a.manageProcessPriorityUntilGracefulShutdown()
// need to keep retrying up to the hardCtx so that we can send graceful shutdown-related
// messages.
ctx := a.hardCtx
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
a.logger.Info(ctx, "connecting to coderd")
err := a.run(ctx)
// Cancel after the run is complete to clean up any leaked resources!
err := a.run()
if err == nil {
continue
}
@ -386,7 +404,9 @@ func (t *trySingleflight) Do(key string, fn func()) {
fn()
}
func (a *agent) reportMetadataLoop(ctx context.Context) {
func (a *agent) reportMetadataUntilGracefulShutdown() {
// metadata reporting can cease as soon as we start gracefully shutting down.
ctx := a.gracefulCtx
tickerDone := make(chan struct{})
collectDone := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
@ -595,9 +615,12 @@ func (a *agent) reportMetadataLoop(ctx context.Context) {
}
}
// reportLifecycleLoop reports the current lifecycle state once. All state
// reportLifecycleUntilClose reports the current lifecycle state once. All state
// changes are reported in order.
func (a *agent) reportLifecycleLoop(ctx context.Context) {
func (a *agent) reportLifecycleUntilClose() {
// part of graceful shut down is reporting the final lifecycle states, e.g. "ShuttingDown", so the
// lifecycle reporting has to be via the "hard" context.
ctx := a.hardCtx
lastReportedIndex := 0 // Start off with the created state without reporting it.
for {
select {
@ -623,6 +646,8 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
err := a.client.PostLifecycle(ctx, report)
if err == nil {
a.logger.Debug(ctx, "successfully reported lifecycle state", slog.F("payload", report))
r.Reset() // don't back off when we are successful
lastReportedIndex++
select {
case a.lifecycleReported <- report.State:
@ -638,6 +663,7 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
break
}
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
a.logger.Debug(ctx, "canceled reporting lifecycle state", slog.F("payload", report))
return
}
// If we fail to report the state we probably shouldn't exit, log only.
@ -648,7 +674,7 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) {
// setLifecycle sets the lifecycle state and notifies the lifecycle loop.
// The state is only updated if it's a valid state transition.
func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentLifecycle) {
func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
report := agentsdk.PostLifecycleRequest{
State: state,
ChangedAt: dbtime.Now(),
@ -657,12 +683,12 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL
a.lifecycleMu.Lock()
lastReport := a.lifecycleStates[len(a.lifecycleStates)-1]
if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastReport.State) >= slices.Index(codersdk.WorkspaceAgentLifecycleOrder, report.State) {
a.logger.Warn(ctx, "attempted to set lifecycle state to a previous state", slog.F("last", lastReport), slog.F("current", report))
a.logger.Warn(context.Background(), "attempted to set lifecycle state to a previous state", slog.F("last", lastReport), slog.F("current", report))
a.lifecycleMu.Unlock()
return
}
a.lifecycleStates = append(a.lifecycleStates, report)
a.logger.Debug(ctx, "set lifecycle state", slog.F("current", report), slog.F("last", lastReport))
a.logger.Debug(context.Background(), "set lifecycle state", slog.F("current", report), slog.F("last", lastReport))
a.lifecycleMu.Unlock()
select {
@ -674,7 +700,8 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL
// fetchServiceBannerLoop fetches the service banner on an interval. It will
// not be fetched immediately; the expectation is that it is primed elsewhere
// (and must be done before the session actually starts).
func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgentClient) error {
func (a *agent) fetchServiceBannerLoop(ctx context.Context, conn drpc.Conn) error {
aAPI := proto.NewDRPCAgentClient(conn)
ticker := time.NewTicker(a.serviceBannerRefreshInterval)
defer ticker.Stop()
for {
@ -696,205 +723,272 @@ func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgent
}
}
func (a *agent) run(ctx context.Context) error {
func (a *agent) run() (retErr error) {
// This allows the agent to refresh its token if necessary.
// For instance identity this is required, since the instance
// may not have re-provisioned, but a new agent ID was created.
sessionToken, err := a.exchangeToken(ctx)
sessionToken, err := a.exchangeToken(a.hardCtx)
if err != nil {
return xerrors.Errorf("exchange token: %w", err)
}
a.sessionToken.Store(&sessionToken)
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
conn, err := a.client.ConnectRPC(ctx)
conn, err := a.client.ConnectRPC(a.hardCtx)
if err != nil {
return err
}
defer func() {
cErr := conn.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(err))
}
}()
aAPI := proto.NewDRPCAgentClient(conn)
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
if err != nil {
return xerrors.Errorf("fetch service banner: %w", err)
}
serviceBanner := agentsdk.ServiceBannerFromProto(sbp)
a.serviceBanner.Store(&serviceBanner)
// A lot of routines need the agent API / tailnet API connection. We run them in their own
// goroutines in parallel, but errors in any routine will cause them all to exit so we can
// redial the coder server and retry.
connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger, conn)
mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{})
if err != nil {
return xerrors.Errorf("fetch metadata: %w", err)
}
a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp))
manifest, err := agentsdk.ManifestFromProto(mp)
if err != nil {
a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err))
return xerrors.Errorf("convert manifest: %w", err)
}
if manifest.AgentID == uuid.Nil {
return xerrors.New("nil agentID returned by manifest")
}
a.client.RewriteDERPMap(manifest.DERPMap)
// Expand the directory and send it back to coderd so external
// applications that rely on the directory can use it.
//
// An example is VS Code Remote, which must know the directory
// before initializing a connection.
manifest.Directory, err = expandDirectory(manifest.Directory)
if err != nil {
return xerrors.Errorf("expand directory: %w", err)
}
subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems)
if err != nil {
a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err))
return xerrors.Errorf("failed to convert subsystems: %w", err)
}
_, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{
Version: buildinfo.Version(),
ExpandedDirectory: manifest.Directory,
Subsystems: subsys,
}})
if err != nil {
return xerrors.Errorf("update workspace agent startup: %w", err)
}
oldManifest := a.manifest.Swap(&manifest)
// The startup script should only execute on the first run!
if oldManifest == nil {
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStarting)
// Perform overrides early so that Git auth can work even if users
// connect to a workspace that is not yet ready. We don't run this
// concurrently with the startup script to avoid conflicts between
// them.
if manifest.GitAuthConfigs > 0 {
// If this fails, we should consider surfacing the error in the
// startup log and setting the lifecycle state to be "start_error"
// (after startup script completion), but for now we'll just log it.
err := gitauth.OverrideVSCodeConfigs(a.filesystem)
connMan.start("init service banner", gracefulShutdownBehaviorStop,
func(ctx context.Context, conn drpc.Conn) error {
aAPI := proto.NewDRPCAgentClient(conn)
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
if err != nil {
a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(err))
return xerrors.Errorf("fetch service banner: %w", err)
}
}
serviceBanner := agentsdk.ServiceBannerFromProto(sbp)
a.serviceBanner.Store(&serviceBanner)
return nil
},
)
err = a.scriptRunner.Init(manifest.Scripts)
if err != nil {
return xerrors.Errorf("init script runner: %w", err)
}
err = a.trackConnGoroutine(func() {
start := time.Now()
err := a.scriptRunner.Execute(ctx, func(script codersdk.WorkspaceAgentScript) bool {
return script.RunOnStart
})
// Measure the time immediately after the script has finished
dur := time.Since(start).Seconds()
if err != nil {
a.logger.Warn(ctx, "startup script(s) failed", slog.Error(err))
if errors.Is(err, agentscripts.ErrTimeout) {
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartTimeout)
} else {
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartError)
}
} else {
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleReady)
}
// channels to sync goroutines below
// handle manifest
// |
// manifestOK
// | |
// | +----------------------+
// V |
// app health reporter |
// V
// create or update network
// |
// networkOK
// |
// coordination <--------------------------+
// derp map subscriber <----------------+
// stats report loop <---------------+
networkOK := make(chan struct{})
manifestOK := make(chan struct{})
label := "false"
if err == nil {
label = "true"
connMan.start("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK))
connMan.start("app health reporter", gracefulShutdownBehaviorStop,
func(ctx context.Context, conn drpc.Conn) error {
select {
case <-ctx.Done():
return nil
case <-manifestOK:
manifest := a.manifest.Load()
NewWorkspaceAppHealthReporter(
a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)),
)(ctx)
return nil
}
a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur)
a.scriptRunner.StartCron()
})
if err != nil {
return xerrors.Errorf("track conn goroutine: %w", err)
connMan.start("create or update network", gracefulShutdownBehaviorStop,
a.createOrUpdateNetwork(manifestOK, networkOK))
connMan.start("coordination", gracefulShutdownBehaviorStop,
func(ctx context.Context, conn drpc.Conn) error {
select {
case <-ctx.Done():
return nil
case <-networkOK:
}
return a.runCoordinator(ctx, conn, a.network)
},
)
connMan.start("derp map subscriber", gracefulShutdownBehaviorStop,
func(ctx context.Context, conn drpc.Conn) error {
select {
case <-ctx.Done():
return nil
case <-networkOK:
}
return a.runDERPMapSubscriber(ctx, conn, a.network)
})
connMan.start("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop)
connMan.start("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error {
select {
case <-ctx.Done():
return nil
case <-networkOK:
}
return a.statsReporter.reportLoop(ctx, proto.NewDRPCAgentClient(conn))
})
return connMan.wait()
}
// handleManifest builds the routine that retrieves the agent manifest over the agent API,
// reports startup details back through UpdateStartup and stores the manifest on the agent.
// manifestOK is closed right after the manifest has been stored so that waiting routines can
// proceed. When no manifest was stored before, the routine also initializes the script runner
// and launches the start scripts in a tracked goroutine.
func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Context, conn drpc.Conn) error {
	return func(ctx context.Context, conn drpc.Conn) error {
		client := proto.NewDRPCAgentClient(conn)

		protoManifest, err := client.GetManifest(ctx, &proto.GetManifestRequest{})
		if err != nil {
			return xerrors.Errorf("fetch metadata: %w", err)
		}
		a.logger.Info(ctx, "fetched manifest", slog.F("manifest", protoManifest))

		manifest, err := agentsdk.ManifestFromProto(protoManifest)
		if err != nil {
			a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", protoManifest), slog.Error(err))
			return xerrors.Errorf("convert manifest: %w", err)
		}
		if manifest.AgentID == uuid.Nil {
			return xerrors.New("nil agentID returned by manifest")
		}
		a.client.RewriteDERPMap(manifest.DERPMap)

		// External applications (VS Code Remote, for example) must know the expanded
		// directory before they initialize a connection, so it is expanded here and sent
		// back to coderd as part of the startup report.
		manifest.Directory, err = expandDirectory(manifest.Directory)
		if err != nil {
			return xerrors.Errorf("expand directory: %w", err)
		}
		subsystems, err := agentsdk.ProtoFromSubsystems(a.subsystems)
		if err != nil {
			a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err))
			return xerrors.Errorf("failed to convert subsystems: %w", err)
		}
		startup := &proto.Startup{
			Version:           buildinfo.Version(),
			ExpandedDirectory: manifest.Directory,
			Subsystems:        subsystems,
		}
		if _, err = client.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: startup}); err != nil {
			// A canceled startup report is not treated as a routine failure.
			if xerrors.Is(err, context.Canceled) {
				return nil
			}
			return xerrors.Errorf("update workspace agent startup: %w", err)
		}

		previous := a.manifest.Swap(&manifest)
		close(manifestOK)
		if previous != nil {
			// Start scripts are only executed the first time a manifest is stored.
			return nil
		}

		a.setLifecycle(codersdk.WorkspaceAgentLifecycleStarting)

		// The Git auth overrides are applied before the start scripts rather than
		// concurrently with them, to avoid conflicts, and early enough that Git auth works
		// for users connecting to a workspace that is not ready yet.
		if manifest.GitAuthConfigs > 0 {
			// A failure is only logged for now; it could later be surfaced in the startup
			// log and reported as "start_error" once the start scripts have completed.
			if overrideErr := gitauth.OverrideVSCodeConfigs(a.filesystem); overrideErr != nil {
				a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(overrideErr))
			}
		}

		if err = a.scriptRunner.Init(manifest.Scripts); err != nil {
			return xerrors.Errorf("init script runner: %w", err)
		}

		runStartScripts := func() {
			began := time.Now()
			// The script runner is not directly tied to the agent API, so it runs on the
			// graceful context instead of the routine's context.
			execErr := a.scriptRunner.Execute(a.gracefulCtx, func(script codersdk.WorkspaceAgentScript) bool {
				return script.RunOnStart
			})
			// Take the measurement immediately after the scripts return.
			elapsed := time.Since(began).Seconds()

			state := codersdk.WorkspaceAgentLifecycleReady
			successLabel := "true"
			if execErr != nil {
				a.logger.Warn(ctx, "startup script(s) failed", slog.Error(execErr))
				successLabel = "false"
				state = codersdk.WorkspaceAgentLifecycleStartError
				if errors.Is(execErr, agentscripts.ErrTimeout) {
					state = codersdk.WorkspaceAgentLifecycleStartTimeout
				}
			}
			a.setLifecycle(state)

			a.metrics.startupScriptSeconds.WithLabelValues(successLabel).Set(elapsed)
			a.scriptRunner.StartCron()
		}
		if err = a.trackGoroutine(runStartScripts); err != nil {
			return xerrors.Errorf("track conn goroutine: %w", err)
		}
		return nil
	}
}
// This automatically closes when the context ends!
appReporterCtx, appReporterCtxCancel := context.WithCancel(ctx)
defer appReporterCtxCancel()
go NewWorkspaceAppHealthReporter(
a.logger, manifest.Apps, agentsdk.AppHealthPoster(aAPI))(appReporterCtx)
a.closeMutex.Lock()
network := a.network
a.closeMutex.Unlock()
if network == nil {
network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections)
if err != nil {
return xerrors.Errorf("create tailnet: %w", err)
// createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates
// the tailnet using the information in the manifest
func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan<- struct{}) func(context.Context, drpc.Conn) error {
return func(ctx context.Context, _ drpc.Conn) error {
select {
case <-ctx.Done():
return nil
case <-manifestOK:
}
var err error
manifest := a.manifest.Load()
a.closeMutex.Lock()
// Re-check if agent was closed while initializing the network.
closed := a.isClosed()
if !closed {
a.network = network
a.statsReporter = newStatsReporter(a.logger, network, a)
}
network := a.network
a.closeMutex.Unlock()
if closed {
_ = network.Close()
return xerrors.New("agent is closed")
if network == nil {
// use the graceful context here, because creating the tailnet is not itself tied to the
// agent API.
network, err = a.createTailnet(a.gracefulCtx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections)
if err != nil {
return xerrors.Errorf("create tailnet: %w", err)
}
a.closeMutex.Lock()
// Re-check if agent was closed while initializing the network.
closed := a.isClosed()
if !closed {
a.network = network
a.statsReporter = newStatsReporter(a.logger, network, a)
}
a.closeMutex.Unlock()
if closed {
_ = network.Close()
return xerrors.New("agent is closed")
}
} else {
// Update the wireguard IPs if the agent ID changed.
err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID))
if err != nil {
a.logger.Error(a.gracefulCtx, "update tailnet addresses", slog.Error(err))
}
// Update the DERP map, force WebSocket setting and allow/disallow
// direct connections.
network.SetDERPMap(manifest.DERPMap)
network.SetDERPForceWebSockets(manifest.DERPForceWebSockets)
network.SetBlockEndpoints(manifest.DisableDirectConnections)
}
} else {
// Update the wireguard IPs if the agent ID changed.
err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID))
if err != nil {
a.logger.Error(ctx, "update tailnet addresses", slog.Error(err))
}
// Update the DERP map, force WebSocket setting and allow/disallow
// direct connections.
network.SetDERPMap(manifest.DERPMap)
network.SetDERPForceWebSockets(manifest.DERPForceWebSockets)
network.SetBlockEndpoints(manifest.DisableDirectConnections)
close(networkOK)
return nil
}
eg, egCtx := errgroup.WithContext(ctx)
eg.Go(func() error {
a.logger.Debug(egCtx, "running tailnet connection coordinator")
err := a.runCoordinator(egCtx, conn, network)
if err != nil {
return xerrors.Errorf("run coordinator: %w", err)
}
return nil
})
eg.Go(func() error {
a.logger.Debug(egCtx, "running derp map subscriber")
err := a.runDERPMapSubscriber(egCtx, conn, network)
if err != nil {
return xerrors.Errorf("run derp map subscriber: %w", err)
}
return nil
})
eg.Go(func() error {
a.logger.Debug(egCtx, "running fetch server banner loop")
err := a.fetchServiceBannerLoop(egCtx, aAPI)
if err != nil {
return xerrors.Errorf("fetch server banner loop: %w", err)
}
return nil
})
eg.Go(func() error {
a.logger.Debug(egCtx, "running stats report loop")
err := a.statsReporter.reportLoop(egCtx, aAPI)
if err != nil {
return xerrors.Errorf("report stats loop: %w", err)
}
return nil
})
return eg.Wait()
}
// updateCommandEnv updates the provided command environment with the
@ -995,15 +1089,15 @@ func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
return a.addresses
}
func (a *agent) trackConnGoroutine(fn func()) error {
func (a *agent) trackGoroutine(fn func()) error {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
return xerrors.New("track conn goroutine: agent is closed")
}
a.connCloseWait.Add(1)
a.closeWaitGroup.Add(1)
go func() {
defer a.connCloseWait.Done()
defer a.closeWaitGroup.Done()
fn()
}()
return nil
@ -1037,7 +1131,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
_ = sshListener.Close()
}
}()
if err = a.trackConnGoroutine(func() {
if err = a.trackGoroutine(func() {
_ = a.sshServer.Serve(sshListener)
}); err != nil {
return nil, err
@ -1052,7 +1146,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
_ = reconnectingPTYListener.Close()
}
}()
if err = a.trackConnGoroutine(func() {
if err = a.trackGoroutine(func() {
logger := a.logger.Named("reconnecting-pty")
var wg sync.WaitGroup
for {
@ -1072,7 +1166,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
go func() {
select {
case <-closed:
case <-a.closed:
case <-a.hardCtx.Done():
_ = conn.Close()
}
wg.Done()
@ -1115,7 +1209,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
_ = speedtestListener.Close()
}
}()
if err = a.trackConnGoroutine(func() {
if err = a.trackGoroutine(func() {
var wg sync.WaitGroup
for {
conn, err := speedtestListener.Accept()
@ -1134,7 +1228,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
go func() {
select {
case <-closed:
case <-a.closed:
case <-a.hardCtx.Done():
_ = conn.Close()
}
wg.Done()
@ -1163,7 +1257,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
_ = apiListener.Close()
}
}()
if err = a.trackConnGoroutine(func() {
if err = a.trackGoroutine(func() {
defer apiListener.Close()
server := &http.Server{
Handler: a.apiHandler(),
@ -1175,7 +1269,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
go func() {
select {
case <-ctx.Done():
case <-a.closed:
case <-a.hardCtx.Done():
}
_ = server.Close()
}()
@ -1196,7 +1290,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error {
defer a.logger.Debug(ctx, "disconnected from coordination RPC")
tClient := tailnetproto.NewDRPCTailnetClient(conn)
coordinate, err := tClient.Coordinate(ctx)
// we run the RPC on the hardCtx so that we have a chance to send the disconnect message if we
// gracefully shut down.
coordinate, err := tClient.Coordinate(a.hardCtx)
if err != nil {
return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err)
}
@ -1207,13 +1303,34 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
}
}()
a.logger.Info(ctx, "connected to coordination RPC")
coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
select {
case <-ctx.Done():
return ctx.Err()
case err := <-coordination.Error():
return err
// This allows the Close() routine to wait for the coordinator to gracefully disconnect.
a.closeMutex.Lock()
if a.isClosed() {
return nil
}
disconnected := make(chan struct{})
a.coordDisconnected = disconnected
defer close(disconnected)
a.closeMutex.Unlock()
coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
errCh := make(chan error, 1)
go func() {
defer close(errCh)
select {
case <-ctx.Done():
err := coordination.Close()
if err != nil {
a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
}
return
case err := <-coordination.Error():
errCh <- err
}
}()
return <-errCh
}
// runDERPMapSubscriber runs a coordinator and returns if a reconnect should occur.
@ -1311,7 +1428,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
Metrics: a.metrics.reconnectingPTYErrors,
}, logger.With(slog.F("message_id", msg.ID)))
if err = a.trackConnGoroutine(func() {
if err = a.trackGoroutine(func() {
rpty.Wait()
a.reconnectingPTYs.Delete(msg.ID)
}); err != nil {
@ -1406,7 +1523,9 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
var prioritizedProcs = []string{"coder agent"}
func (a *agent) manageProcessPriorityLoop(ctx context.Context) {
func (a *agent) manageProcessPriorityUntilGracefulShutdown() {
// process priority can stop as soon as we are gracefully shutting down
ctx := a.gracefulCtx
defer func() {
if r := recover(); r != nil {
a.logger.Critical(ctx, "recovered from panic",
@ -1515,12 +1634,7 @@ func (a *agent) manageProcessPriority(ctx context.Context) ([]*agentproc.Process
// isClosed returns whether the agent is closed or not.
func (a *agent) isClosed() bool {
select {
case <-a.closed:
return true
default:
return false
}
return a.hardCtx.Err() != nil
}
func (a *agent) HTTPDebug() http.Handler {
@ -1584,59 +1698,82 @@ func (a *agent) Close() error {
return nil
}
ctx := context.Background()
a.logger.Info(ctx, "shutting down agent")
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown)
a.logger.Info(a.hardCtx, "shutting down agent")
a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)
// Attempt to gracefully shut down all active SSH connections and
// stop accepting new ones.
err := a.sshServer.Shutdown(ctx)
err := a.sshServer.Shutdown(a.hardCtx)
if err != nil {
a.logger.Error(ctx, "ssh server shutdown", slog.Error(err))
a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err))
}
err = a.sshServer.Close()
if err != nil {
a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err))
}
// wait for SSH to shut down before the general graceful cancel, because
// this triggers a disconnect in the tailnet layer, telling all clients to
// shut down their wireguard tunnels to us. If SSH sessions are still up,
// they might hang instead of being closed.
a.gracefulCancel()
lifecycleState := codersdk.WorkspaceAgentLifecycleOff
err = a.scriptRunner.Execute(ctx, func(script codersdk.WorkspaceAgentScript) bool {
err = a.scriptRunner.Execute(a.hardCtx, func(script codersdk.WorkspaceAgentScript) bool {
return script.RunOnStop
})
if err != nil {
a.logger.Warn(ctx, "shutdown script(s) failed", slog.Error(err))
a.logger.Warn(a.hardCtx, "shutdown script(s) failed", slog.Error(err))
if errors.Is(err, agentscripts.ErrTimeout) {
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownTimeout
} else {
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
}
}
a.setLifecycle(ctx, lifecycleState)
a.setLifecycle(lifecycleState)
err = a.scriptRunner.Close()
if err != nil {
a.logger.Error(ctx, "script runner close", slog.Error(err))
a.logger.Error(a.hardCtx, "script runner close", slog.Error(err))
}
// Wait for the lifecycle to be reported, but don't wait forever so
// Wait for the graceful shutdown to complete, but don't wait forever so
// that we don't break user expectations.
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
go func() {
defer a.hardCancel()
select {
case <-a.hardCtx.Done():
case <-time.After(5 * time.Second):
}
}()
// Wait for lifecycle to be reported
lifecycleWaitLoop:
for {
select {
case <-ctx.Done():
case <-a.hardCtx.Done():
a.logger.Warn(context.Background(), "failed to report final lifecycle state")
break lifecycleWaitLoop
case s := <-a.lifecycleReported:
if s == lifecycleState {
a.logger.Debug(context.Background(), "reported final lifecycle state")
break lifecycleWaitLoop
}
}
}
close(a.closed)
a.closeCancel()
_ = a.sshServer.Close()
// Wait for graceful disconnect from the Coordinator RPC
select {
case <-a.hardCtx.Done():
a.logger.Warn(context.Background(), "timed out waiting for Coordinator RPC disconnect")
case <-a.coordDisconnected:
a.logger.Debug(context.Background(), "coordinator RPC disconnected")
}
a.hardCancel()
if a.network != nil {
_ = a.network.Close()
}
a.connCloseWait.Wait()
a.closeWaitGroup.Wait()
return nil
}
@ -1688,3 +1825,94 @@ func expandDirectory(dir string) (string, error) {
// specialized environment in which the agent is running
// (e.g. envbox, envbuilder).
const EnvAgentSubsystem = "CODER_AGENT_SUBSYSTEM"
// eitherContext derives a context from a that is additionally canceled as soon as b is done.
// Since the result is a child of a, it ends with a as well; b contributes only its
// cancellation signal.
func eitherContext(a, b context.Context) context.Context {
	merged, stop := context.WithCancel(a)
	// The watcher returns once either input is done, releasing the derived context.
	go func() {
		select {
		case <-b.Done():
		case <-a.Done():
		}
		stop()
	}()
	return merged
}
// gracefulShutdownBehavior selects which of the apiConnRoutineManager's contexts a routine is
// started with.
type gracefulShutdownBehavior int
const (
// gracefulShutdownBehaviorStop starts the routine with the manager's stopCtx.
gracefulShutdownBehaviorStop gracefulShutdownBehavior = iota
// gracefulShutdownBehaviorRemain starts the routine with the manager's remainCtx.
gracefulShutdownBehaviorRemain
)
// apiConnRoutineManager runs named routines that share one dRPC connection inside a single
// error group.
type apiConnRoutineManager struct {
// logger is the base logger; start() derives a per-routine logger from it.
logger slog.Logger
// conn is handed to every routine started through the manager.
conn drpc.Conn
// eg runs the routines; wait() returns the result of eg.Wait().
eg *errgroup.Group
// stopCtx is passed to routines started with gracefulShutdownBehaviorStop.
stopCtx context.Context
// remainCtx is passed to routines started with gracefulShutdownBehaviorRemain.
remainCtx context.Context
}
// newAPIConnRoutineManager creates a manager for routines that share conn.
//
// Two contexts are prepared for the routines it will start:
//   - remainCtx belongs to an error group built on hardCtx. Routines that keep operating
//     during graceful shutdown use it; they still exit if the error group hits an error,
//     which usually means a problem with the conn.
//   - stopCtx ends when the first of remainCtx or gracefulCtx ends, i.e. on an error or at
//     the start of graceful shutdown. Routines that cease operation during graceful shutdown
//     use it.
func newAPIConnRoutineManager(gracefulCtx, hardCtx context.Context, logger slog.Logger, conn drpc.Conn) *apiConnRoutineManager {
	group, remainCtx := errgroup.WithContext(hardCtx)
	manager := &apiConnRoutineManager{
		logger:    logger,
		conn:      conn,
		eg:        group,
		remainCtx: remainCtx,
	}
	manager.stopCtx = eitherContext(remainCtx, gracefulCtx)
	return manager
}
// start runs f in the manager's errgroup, passing it the manager's conn and the context selected
// by b: stopCtx for gracefulShutdownBehaviorStop, remainCtx for gracefulShutdownBehaviorRemain.
// It panics on any other behavior value. The name is attached to the routine's log lines and to
// any error it reports to the errgroup.
func (a *apiConnRoutineManager) start(name string, b gracefulShutdownBehavior, f func(context.Context, drpc.Conn) error) {
	routineLogger := a.logger.With(slog.F("name", name))

	var routineCtx context.Context
	switch b {
	case gracefulShutdownBehaviorRemain:
		routineCtx = a.remainCtx
	case gracefulShutdownBehaviorStop:
		routineCtx = a.stopCtx
	default:
		panic("unknown behavior")
	}

	a.eg.Go(func() error {
		routineLogger.Debug(routineCtx, "starting routine")
		err := f(routineCtx, a.conn)

		// Only treat the error as a benign cancellation when it is context.Canceled AND *our*
		// context has actually ended. When Coderd unilaterally closes the API connection (for
		// example if the build is outdated), the RPC calls can also surface context.Canceled,
		// and that case must still be reported.
		ourCancel := xerrors.Is(err, context.Canceled) && routineCtx.Err() != nil
		if !ourCancel {
			routineLogger.Debug(routineCtx, "routine exited", slog.Error(err))
			if err == nil {
				return nil
			}
			return xerrors.Errorf("error in routine %s: %w", name, err)
		}

		// Keep our own cancellation out of the error group: the graceful context being canceled
		// must not halt routines running with gracefulShutdownBehaviorRemain.
		routineLogger.Debug(routineCtx, "swallowing context canceled")
		return nil
	})
}
// wait blocks on the manager's errgroup and returns whatever error eg.Wait() reports.
func (a *apiConnRoutineManager) wait() error {
	err := a.eg.Wait()
	return err
}

View File

@ -162,7 +162,13 @@ func TestSSH(t *testing.T) {
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspaceBuild.ID)
// Update template version
version = coderdtest.UpdateTemplateVersion(t, ownerClient, owner.OrganizationID, echoResponses, template.ID)
authToken2 := uuid.NewString()
echoResponses2 := &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: echo.PlanComplete,
ProvisionApply: echo.ProvisionApplyWithAgent(authToken2),
}
version = coderdtest.UpdateTemplateVersion(t, ownerClient, owner.OrganizationID, echoResponses2, template.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, ownerClient, version.ID)
err := ownerClient.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
ID: version.ID,
@ -184,7 +190,7 @@ func TestSSH(t *testing.T) {
// When the agent connects, the workspace was started, and we should
// have access to the shell.
_ = agenttest.New(t, client.URL, authToken)
_ = agenttest.New(t, client.URL, authToken2)
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
// Shells on Mac, Windows, and Linux all exit shells with the "exit" command.

View File

@ -193,7 +193,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
options = &Options{}
}
if options.Logger == nil {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("coderd")
options.Logger = &logger
}
if options.GoogleTokenValidator == nil {

View File

@ -534,7 +534,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
crdErr := coordination.Close()
if crdErr != nil {
tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(err))
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
}
case err = <-coordination.Error():
if err != nil &&

View File

@ -231,13 +231,14 @@ func TestAuditLogging(t *testing.T) {
},
DontAddLicense: true,
})
workspace, agent := setupWorkspaceAgent(t, client, user, 0)
conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil) //nolint:gocritic // RBAC is not the purpose of this test
r := setupWorkspaceAgent(t, client, user, 0)
conn, err := client.DialWorkspaceAgent(ctx, r.sdkAgent.ID, nil) //nolint:gocritic // RBAC is not the purpose of this test
require.NoError(t, err)
defer conn.Close()
connected := conn.AwaitReachable(ctx)
require.True(t, connected)
build := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop)
_ = r.agent.Close() // close first so we don't drop error logs from outdated build
build := coderdtest.CreateWorkspaceBuild(t, client, r.workspace, database.WorkspaceTransitionStop)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)
})
}

View File

@ -81,8 +81,8 @@ func TestReplicas(t *testing.T) {
require.NoError(t, err)
require.Len(t, replicas, 2)
_, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0)
conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{
r := setupWorkspaceAgent(t, firstClient, firstUser, 0)
conn, err := secondClient.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, &codersdk.DialWorkspaceAgentOptions{
BlockEndpoints: true,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
})
@ -127,8 +127,8 @@ func TestReplicas(t *testing.T) {
require.NoError(t, err)
require.Len(t, replicas, 2)
_, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0)
conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{
r := setupWorkspaceAgent(t, firstClient, firstUser, 0)
conn, err := secondClient.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, &codersdk.DialWorkspaceAgentOptions{
BlockEndpoints: true,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})

View File

@ -44,9 +44,9 @@ func TestBlockNonBrowser(t *testing.T) {
},
},
})
_, agent := setupWorkspaceAgent(t, client, user, 0)
r := setupWorkspaceAgent(t, client, user, 0)
//nolint:gocritic // Testing that even the owner gets blocked.
_, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil)
_, err := client.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, nil)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
@ -63,15 +63,21 @@ func TestBlockNonBrowser(t *testing.T) {
},
},
})
_, agent := setupWorkspaceAgent(t, client, user, 0)
r := setupWorkspaceAgent(t, client, user, 0)
//nolint:gocritic // Testing RBAC is not the point of this test.
conn, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil)
conn, err := client.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, nil)
require.NoError(t, err)
_ = conn.Close()
})
}
func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.CreateFirstUserResponse, appPort uint16) (codersdk.Workspace, codersdk.WorkspaceAgent) {
// setupResp bundles the values returned by setupWorkspaceAgent.
type setupResp struct {
// workspace is the workspace whose agents setupWorkspaceAgent waits for.
workspace codersdk.Workspace
// sdkAgent is the agent record fetched from the API with client.WorkspaceAgent.
sdkAgent codersdk.WorkspaceAgent
// agent is the running agent created with agent.New. It is closed in test cleanup, and tests
// may also Close() it earlier themselves.
agent agent.Agent
}
func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.CreateFirstUserResponse, appPort uint16) setupResp {
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
@ -127,20 +133,20 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr
},
}
agentClient.SetSessionToken(authToken)
agentCloser := agent.New(agent.Options{
agnt := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
t.Cleanup(func() {
_ = agentCloser.Close()
_ = agnt.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
agnt, err := client.WorkspaceAgent(ctx, resources[0].Agents[0].ID)
sdkAgent, err := client.WorkspaceAgent(ctx, resources[0].Agents[0].ID)
require.NoError(t, err)
return workspace, agnt
return setupResp{workspace, sdkAgent, agnt}
}

View File

@ -31,7 +31,7 @@ func TestWorkspacePortShare(t *testing.T) {
},
})
client, user := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin())
workspace, agent := setupWorkspaceAgent(t, client, codersdk.CreateFirstUserResponse{
r := setupWorkspaceAgent(t, client, codersdk.CreateFirstUserResponse{
UserID: user.ID,
OrganizationID: owner.OrganizationID,
}, 0)
@ -39,8 +39,8 @@ func TestWorkspacePortShare(t *testing.T) {
defer cancel()
// try to update port share with template max port share level owner
_, err := client.UpsertWorkspaceAgentPortShare(ctx, workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{
AgentName: agent.Name,
_, err := client.UpsertWorkspaceAgentPortShare(ctx, r.workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{
AgentName: r.sdkAgent.Name,
Port: 8080,
ShareLevel: codersdk.WorkspaceAgentPortShareLevelPublic,
})
@ -48,13 +48,13 @@ func TestWorkspacePortShare(t *testing.T) {
// update the template max port share level to public
var level codersdk.WorkspaceAgentPortShareLevel = codersdk.WorkspaceAgentPortShareLevelPublic
client.UpdateTemplateMeta(ctx, workspace.TemplateID, codersdk.UpdateTemplateMeta{
client.UpdateTemplateMeta(ctx, r.workspace.TemplateID, codersdk.UpdateTemplateMeta{
MaxPortShareLevel: &level,
})
// OK
ps, err := client.UpsertWorkspaceAgentPortShare(ctx, workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{
AgentName: agent.Name,
ps, err := client.UpsertWorkspaceAgentPortShare(ctx, r.workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{
AgentName: r.sdkAgent.Name,
Port: 8080,
ShareLevel: codersdk.WorkspaceAgentPortShareLevelPublic,
})

View File

@ -131,7 +131,8 @@ func (c *remoteCoordination) Close() (retErr error) {
}
}()
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
if err != nil && !xerrors.Is(err, io.EOF) {
// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
return xerrors.Errorf("send disconnect: %w", err)
}
c.logger.Debug(context.Background(), "sent disconnect")