package agent

import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/netip"
	"os"
	"os/user"
	"path/filepath"
	"runtime"
	"runtime/debug"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/common/expfmt"
	"github.com/spf13/afero"
	"go.uber.org/atomic"
	"golang.org/x/exp/slices"
	"golang.org/x/sync/errgroup"
	"golang.org/x/xerrors"
	"storj.io/drpc"
	"tailscale.com/net/speedtest"
	"tailscale.com/tailcfg"
	"tailscale.com/types/netlogtype"
	"tailscale.com/util/clientmetric"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/agent/agentproc"
	"github.com/coder/coder/v2/agent/agentscripts"
	"github.com/coder/coder/v2/agent/agentssh"
	"github.com/coder/coder/v2/agent/proto"
	"github.com/coder/coder/v2/agent/reconnectingpty"
	"github.com/coder/coder/v2/buildinfo"
	"github.com/coder/coder/v2/cli/gitauth"
	"github.com/coder/coder/v2/coderd/database/dbtime"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/agentsdk"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/coder/v2/tailnet"
	tailnetproto "github.com/coder/coder/v2/tailnet/proto"
	"github.com/coder/retry"
)

const (
	ProtocolReconnectingPTY = "reconnecting-pty"
	ProtocolSSH             = "ssh"
	ProtocolDial            = "dial"
)

// EnvProcPrioMgmt determines whether we attempt to manage
// process CPU and OOM Killer priority.
const (
	EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
	EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
)

type Options struct {
	Filesystem                   afero.Fs
	LogDir                       string
	TempDir                      string
	ScriptDataDir                string
	ExchangeToken                func(ctx context.Context) (string, error)
	Client                       Client
	ReconnectingPTYTimeout       time.Duration
	EnvironmentVariables         map[string]string
	Logger                       slog.Logger
	IgnorePorts                  map[int]string
	PortCacheDuration            time.Duration
	SSHMaxTimeout                time.Duration
	TailnetListenPort            uint16
	Subsystems                   []codersdk.AgentSubsystem
	Addresses                    []netip.Prefix
	PrometheusRegistry           *prometheus.Registry
	ReportMetadataInterval       time.Duration
	ServiceBannerRefreshInterval time.Duration
	Syscaller                    agentproc.Syscaller
	// ModifiedProcesses is used for testing process priority management.
	ModifiedProcesses chan []*agentproc.Process
	// ProcessManagementTick is used for testing process priority management.
	ProcessManagementTick <-chan time.Time
}

type Client interface {
	ConnectRPC(ctx context.Context) (drpc.Conn, error)
	RewriteDERPMap(derpMap *tailcfg.DERPMap)
}

type Agent interface {
	HTTPDebug() http.Handler
	// TailnetConn may be nil.
	TailnetConn() *tailnet.Conn
	io.Closer
}

func New(options Options) Agent {
	if options.Filesystem == nil {
		options.Filesystem = afero.NewOsFs()
	}
	if options.TempDir == "" {
		options.TempDir = os.TempDir()
	}
	if options.LogDir == "" {
		options.LogDir = options.TempDir
		if options.TempDir != os.TempDir() {
			options.Logger.Debug(context.Background(), "log dir not set, using temp dir", slog.F("temp_dir", options.TempDir))
		} else {
			options.Logger.Debug(context.Background(), "using log dir", slog.F("log_dir", options.LogDir))
		}
	}
	if options.ScriptDataDir == "" {
		options.ScriptDataDir = options.TempDir
		if options.TempDir != os.TempDir() {
			options.Logger.Debug(context.Background(), "script data dir not set, using temp dir", slog.F("temp_dir", options.TempDir))
		} else {
			options.Logger.Debug(context.Background(), "using script data dir", slog.F("script_data_dir", options.ScriptDataDir))
		}
	}
	if options.ExchangeToken == nil {
		options.ExchangeToken = func(ctx context.Context) (string, error) {
			return "", nil
		}
	}
	if options.ReportMetadataInterval == 0 {
		options.ReportMetadataInterval = time.Second
	}
	if options.ServiceBannerRefreshInterval == 0 {
		options.ServiceBannerRefreshInterval = 2 * time.Minute
	}
	if options.PortCacheDuration == 0 {
		options.PortCacheDuration = 1 * time.Second
	}

	prometheusRegistry := options.PrometheusRegistry
	if prometheusRegistry == nil {
		prometheusRegistry = prometheus.NewRegistry()
	}

	if options.Syscaller == nil {
		options.Syscaller = agentproc.NewSyscaller()
	}

	hardCtx, hardCancel := context.WithCancel(context.Background())
	gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
	a := &agent{
		tailnetListenPort:                  options.TailnetListenPort,
		reconnectingPTYTimeout:             options.ReconnectingPTYTimeout,
		logger:                             options.Logger,
		gracefulCtx:                        gracefulCtx,
		gracefulCancel:                     gracefulCancel,
		hardCtx:                            hardCtx,
		hardCancel:                         hardCancel,
		coordDisconnected:                  make(chan struct{}),
		environmentVariables:               options.EnvironmentVariables,
		client:                             options.Client,
		exchangeToken:                      options.ExchangeToken,
		filesystem:                         options.Filesystem,
		logDir:                             options.LogDir,
		tempDir:                            options.TempDir,
		scriptDataDir:                      options.ScriptDataDir,
		lifecycleUpdate:                    make(chan struct{}, 1),
		lifecycleReported:                  make(chan codersdk.WorkspaceAgentLifecycle, 1),
		lifecycleStates:                    []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
		ignorePorts:                        options.IgnorePorts,
		portCacheDuration:                  options.PortCacheDuration,
		reportMetadataInterval:             options.ReportMetadataInterval,
		notificationBannersRefreshInterval: options.ServiceBannerRefreshInterval,
		sshMaxTimeout:                      options.SSHMaxTimeout,
		subsystems:                         options.Subsystems,
		addresses:                          options.Addresses,
		syscaller:                          options.Syscaller,
		modifiedProcs:                      options.ModifiedProcesses,
		processManagementTick:              options.ProcessManagementTick,
		logSender:                          agentsdk.NewLogSender(options.Logger),

		prometheusRegistry: prometheusRegistry,
		metrics:            newAgentMetrics(prometheusRegistry),
	}
	// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
	// Each time we connect we replace the channel (while holding the closeMutex) with a new one
	// that gets closed on disconnection. This is used to wait for graceful disconnection from the
	// coordinator during shut down.
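	// A minimal sketch of that lifecycle (the names are real, the sequencing
	// is illustrative):
	//
	//	a.coordDisconnected = make(chan struct{}) // runCoordinator, under closeMutex
	//	...
	//	close(a.coordDisconnected) // runCoordinator returns
	//	...
	//	<-a.coordDisconnected // Close() unblocks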
	close(a.coordDisconnected)
	a.notificationBanners.Store(new([]codersdk.BannerConfig))
	a.sessionToken.Store(new(string))
	a.init()
	return a
}

type agent struct {
	logger            slog.Logger
	client            Client
	exchangeToken     func(ctx context.Context) (string, error)
	tailnetListenPort uint16
	filesystem        afero.Fs
	logDir            string
	tempDir           string
	scriptDataDir     string
	// ignorePorts tells the api handler which ports to ignore when
	// listing all listening ports. This is helpful to hide ports that
	// are used by the agent, that the user does not care about.
	ignorePorts       map[int]string
	portCacheDuration time.Duration
	subsystems        []codersdk.AgentSubsystem

	reconnectingPTYs       sync.Map
	reconnectingPTYTimeout time.Duration

	// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
	// to start gracefully shutting down and "hard" which is Done when it is time to close
	// everything down (regardless of whether graceful shutdown completed).
	gracefulCtx       context.Context
	gracefulCancel    context.CancelFunc
	hardCtx           context.Context
	hardCancel        context.CancelFunc
	closeWaitGroup    sync.WaitGroup
	closeMutex        sync.Mutex
	coordDisconnected chan struct{}

	environmentVariables map[string]string

	manifest                           atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
	reportMetadataInterval             time.Duration
	scriptRunner                       *agentscripts.Runner
	notificationBanners                atomic.Pointer[[]codersdk.BannerConfig] // notificationBanners is atomic because it is periodically updated.
	notificationBannersRefreshInterval time.Duration
	sessionToken                       atomic.Pointer[string]
	sshServer                          *agentssh.Server
	sshMaxTimeout                      time.Duration

	lifecycleUpdate            chan struct{}
	lifecycleReported          chan codersdk.WorkspaceAgentLifecycle
	lifecycleMu                sync.RWMutex // Protects following.
	lifecycleStates            []agentsdk.PostLifecycleRequest
	lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported.

	network       *tailnet.Conn
	addresses     []netip.Prefix
	statsReporter *statsReporter
	logSender     *agentsdk.LogSender

	connCountReconnectingPTY atomic.Int64

	prometheusRegistry *prometheus.Registry
	// metrics are prometheus registered metrics that will be collected and
	// labeled in Coder with the agent + workspace.
	metrics   *agentMetrics
	syscaller agentproc.Syscaller

	// modifiedProcs is used for testing process priority management.
	modifiedProcs chan []*agentproc.Process
	// processManagementTick is used for testing process priority management.
	processManagementTick <-chan time.Time
}

func (a *agent) TailnetConn() *tailnet.Conn {
	return a.network
}

func (a *agent) init() {
	// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
	sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
		MaxTimeout:          a.sshMaxTimeout,
		MOTDFile:            func() string { return a.manifest.Load().MOTDFile },
		NotificationBanners: func() *[]codersdk.BannerConfig { return a.notificationBanners.Load() },
		UpdateEnv:           a.updateCommandEnv,
		WorkingDirectory:    func() string { return a.manifest.Load().Directory },
	})
	if err != nil {
		panic(err)
	}
	a.sshServer = sshSrv
	a.scriptRunner = agentscripts.New(agentscripts.Options{
		LogDir:      a.logDir,
		DataDirBase: a.scriptDataDir,
		Logger:      a.logger,
		SSHServer:   sshSrv,
		Filesystem:  a.filesystem,
		GetScriptLogger: func(logSourceID uuid.UUID) agentscripts.ScriptLogger {
			return a.logSender.GetScriptLogger(logSourceID)
		},
	})
	// Register runner metrics. If the prom registry is nil, the metrics
	// will not report anywhere.
	a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
	go a.runLoop()
}

// runLoop attempts to start the agent in a retry loop.
// Coder may be offline temporarily, or a connection issue may occur, but
// regardless of the intermittent failure, we want the agent to reconnect.
func (a *agent) runLoop() {
	go a.manageProcessPriorityUntilGracefulShutdown()

	// need to keep retrying up to the hardCtx so that we can send graceful shutdown-related
	// messages.
	ctx := a.hardCtx
	for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
		a.logger.Info(ctx, "connecting to coderd")
		err := a.run()
		if err == nil {
			continue
		}
		if ctx.Err() != nil {
			// Context canceled errors may come from websocket pings, so we
			// don't want to use `errors.Is(err, context.Canceled)` here.
			return
		}
		if a.isClosed() {
			return
		}
		if errors.Is(err, io.EOF) {
			a.logger.Info(ctx, "disconnected from coderd")
			continue
		}
		a.logger.Warn(ctx, "run exited with error", slog.Error(err))
	}
}

func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentMetadataDescription, now time.Time) *codersdk.WorkspaceAgentMetadataResult {
	var out bytes.Buffer
	result := &codersdk.WorkspaceAgentMetadataResult{
		// CollectedAt is set here for testing purposes and overridden by
		// coderd to the time of server receipt to solve clock skew.
		//
		// In the future, the server may accept the timestamp from the agent
		// if it can guarantee the clocks are synchronized.
		CollectedAt: now,
	}
	cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
	if err != nil {
		result.Error = fmt.Sprintf("create cmd: %+v", err)
		return result
	}
	cmd := cmdPty.AsExec()

	cmd.Stdout = &out
	cmd.Stderr = &out
	cmd.Stdin = io.LimitReader(nil, 0)

	// We split up Start and Wait instead of calling Run so that we can return a more precise error.
	err = cmd.Start()
	if err != nil {
		result.Error = fmt.Sprintf("start cmd: %+v", err)
		return result
	}

	// This error isn't mutually exclusive with useful output.
	err = cmd.Wait()

	const bufLimit = 10 << 10
	if out.Len() > bufLimit {
		err = errors.Join(
			err,
			xerrors.Errorf("output truncated from %v to %v bytes", out.Len(), bufLimit),
		)
		out.Truncate(bufLimit)
	}

	// Important: if the command times out, we may see a misleading error like
	// "exit status 1", so it's important to include the context error.
	err = errors.Join(err, ctx.Err())

	if err != nil {
		result.Error = fmt.Sprintf("run cmd: %+v", err)
	}
	result.Value = out.String()
	return result
}

type metadataResultAndKey struct {
	result *codersdk.WorkspaceAgentMetadataResult
	key    string
}

type trySingleflight struct {
	mu sync.Mutex
	m  map[string]struct{}
}

func (t *trySingleflight) Do(key string, fn func()) {
	t.mu.Lock()
	_, ok := t.m[key]
	if ok {
		t.mu.Unlock()
		return
	}

	t.m[key] = struct{}{}
	t.mu.Unlock()
	defer func() {
		t.mu.Lock()
		delete(t.m, key)
		t.mu.Unlock()
	}()

	fn()
}

func (a *agent) reportMetadata(ctx context.Context, conn drpc.Conn) error {
	tickerDone := make(chan struct{})
	collectDone := make(chan struct{})
	ctx, cancel := context.WithCancel(ctx)
	defer func() {
		cancel()
		<-collectDone
		<-tickerDone
	}()

	var (
		logger          = a.logger.Named("metadata")
		report          = make(chan struct{}, 1)
		collect         = make(chan struct{}, 1)
		metadataResults = make(chan metadataResultAndKey, 1)
	)

	// Set up collect and report as a single ticker with two channels; this
	// allows collection and reporting to be triggered independently of each
	// other.
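	// The wake helper below is the standard non-blocking send on a buffered
	// channel: a tick that fires while a wake-up is still pending is coalesced
	// rather than queued.
	//
	//	select {
	//	case c <- struct{}{}:
	//	default: // a wake-up is already pending; drop this one
	//	}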
	go func() {
		t := time.NewTicker(a.reportMetadataInterval)
		defer func() {
			t.Stop()
			close(report)
			close(collect)
			close(tickerDone)
		}()
		wake := func(c chan<- struct{}) {
			select {
			case c <- struct{}{}:
			default:
			}
		}
		wake(collect) // Start immediately.

		for {
			select {
			case <-ctx.Done():
				return
			case <-t.C:
				wake(report)
				wake(collect)
			}
		}
	}()

	go func() {
		defer close(collectDone)

		var (
			// We use a custom singleflight that immediately returns if there is already
			// a goroutine running for a given key. This is to prevent a build-up of
			// goroutines waiting on Do when the script takes many multiples of
			// baseInterval to run.
			flight            = trySingleflight{m: map[string]struct{}{}}
			lastCollectedAtMu sync.RWMutex
			lastCollectedAts  = make(map[string]time.Time)
		)
		for {
			select {
			case <-ctx.Done():
				return
			case <-collect:
			}

			manifest := a.manifest.Load()
			if manifest == nil {
				continue
			}

			// If the manifest changes (e.g. on agent reconnect) we need to
			// purge old cache values to prevent lastCollectedAt from growing
			// boundlessly.
			lastCollectedAtMu.Lock()
			for key := range lastCollectedAts {
				if slices.IndexFunc(manifest.Metadata, func(md codersdk.WorkspaceAgentMetadataDescription) bool {
					return md.Key == key
				}) < 0 {
					logger.Debug(ctx, "deleting lastCollected key, missing from manifest",
						slog.F("key", key),
					)
					delete(lastCollectedAts, key)
				}
			}
			lastCollectedAtMu.Unlock()

			// Spawn a goroutine for each metadata collection, and use a
			// channel to synchronize the results and avoid both messy
			// mutex logic and overloading the API.
			for _, md := range manifest.Metadata {
				md := md
				// We send the result to the channel in the goroutine to avoid
				// sending the same result multiple times. So, we don't care about
				// the return values.
				go flight.Do(md.Key, func() {
					ctx := slog.With(ctx, slog.F("key", md.Key))
					lastCollectedAtMu.RLock()
					collectedAt, ok := lastCollectedAts[md.Key]
					lastCollectedAtMu.RUnlock()
					if ok {
						// If the interval is zero, we assume the user just wants
						// a single collection at startup, not a spinning loop.
						if md.Interval == 0 {
							return
						}
						intervalUnit := time.Second
						// reportMetadataInterval is only less than a second in tests,
						// so adjust the interval unit for them.
						if a.reportMetadataInterval < time.Second {
							intervalUnit = 100 * time.Millisecond
						}
						// The last collected value isn't quite stale yet, so we skip it.
						if collectedAt.Add(time.Duration(md.Interval) * intervalUnit).After(time.Now()) {
							return
						}
					}

					timeout := md.Timeout
					if timeout == 0 {
						if md.Interval != 0 {
							timeout = md.Interval
						} else if interval := int64(a.reportMetadataInterval.Seconds()); interval != 0 {
							// Fallback to the report interval
							timeout = interval * 3
						} else {
							// If the interval is still 0 (possible if the interval
							// is less than a second), default to 5. This was
							// randomly picked.
							timeout = 5
						}
					}
					ctxTimeout := time.Duration(timeout) * time.Second
					ctx, cancel := context.WithTimeout(ctx, ctxTimeout)
					defer cancel()

					now := time.Now()
					select {
					case <-ctx.Done():
						logger.Warn(ctx, "metadata collection timed out", slog.F("timeout", ctxTimeout))
					case metadataResults <- metadataResultAndKey{
						key:    md.Key,
						result: a.collectMetadata(ctx, md, now),
					}:
						lastCollectedAtMu.Lock()
						lastCollectedAts[md.Key] = now
						lastCollectedAtMu.Unlock()
					}
				})
			}
		}
	}()

	// Gather metadata updates and report them once every interval. If a
	// previous report is in flight, wait for it to complete before
	// sending a new one. If the network conditions are bad, we won't
	// benefit from canceling the previous send and starting a new one.
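	// For reference, the effective collection timeout above works out to (in
	// seconds): md.Timeout if set; otherwise md.Interval if set; otherwise
	// three report intervals when the report interval is at least a second;
	// otherwise the arbitrary default of 5.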
	var (
		updatedMetadata = make(map[string]*codersdk.WorkspaceAgentMetadataResult)
		reportTimeout   = 30 * time.Second
		reportError     = make(chan error, 1)
		reportInFlight  = false
		aAPI            = proto.NewDRPCAgentClient(conn)
	)

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case mr := <-metadataResults:
			// This can overwrite unsent values, but that's fine because
			// we're only interested in up-to-date values.
			updatedMetadata[mr.key] = mr.result
			continue
		case err := <-reportError:
			a.logger.Debug(ctx, "batch update metadata complete", slog.Error(err))
			if err != nil {
				return xerrors.Errorf("failed to report metadata: %w", err)
			}
			reportInFlight = false
		case <-report:
			if len(updatedMetadata) == 0 {
				continue
			}
			if reportInFlight {
				// If there's already a report in flight, don't send
				// another one; wait for the next tick instead.
				a.logger.Debug(ctx, "skipped metadata report tick because report is in flight")
				continue
			}
			metadata := make([]*proto.Metadata, 0, len(updatedMetadata))
			for key, result := range updatedMetadata {
				pr := agentsdk.ProtoFromMetadataResult(*result)
				metadata = append(metadata, &proto.Metadata{
					Key:    key,
					Result: pr,
				})
				delete(updatedMetadata, key)
			}

			reportInFlight = true
			go func() {
				a.logger.Debug(ctx, "batch updating metadata")
				ctx, cancel := context.WithTimeout(ctx, reportTimeout)
				defer cancel()

				_, err := aAPI.BatchUpdateMetadata(ctx, &proto.BatchUpdateMetadataRequest{Metadata: metadata})
				reportError <- err
			}()
		}
	}
}

// reportLifecycle reports lifecycle state changes to the API server. All
// changes are reported in order, one at a time.
func (a *agent) reportLifecycle(ctx context.Context, conn drpc.Conn) error {
	aAPI := proto.NewDRPCAgentClient(conn)
	for {
		select {
		case <-a.lifecycleUpdate:
		case <-ctx.Done():
			return ctx.Err()
		}

		for {
			a.lifecycleMu.RLock()
			lastIndex := len(a.lifecycleStates) - 1
			report := a.lifecycleStates[a.lifecycleLastReportedIndex]
			if len(a.lifecycleStates) > a.lifecycleLastReportedIndex+1 {
				report = a.lifecycleStates[a.lifecycleLastReportedIndex+1]
			}
			a.lifecycleMu.RUnlock()

			if lastIndex == a.lifecycleLastReportedIndex {
				break
			}
			l, err := agentsdk.ProtoFromLifecycle(report)
			if err != nil {
				a.logger.Critical(ctx, "failed to convert lifecycle state", slog.F("report", report))
				// Skip this report; there is no point retrying. Maybe we can successfully convert the next one?
				a.lifecycleLastReportedIndex++
				continue
			}
			payload := &proto.UpdateLifecycleRequest{Lifecycle: l}
			logger := a.logger.With(slog.F("payload", payload))
			logger.Debug(ctx, "reporting lifecycle state")

			_, err = aAPI.UpdateLifecycle(ctx, payload)
			if err != nil {
				return xerrors.Errorf("failed to update lifecycle: %w", err)
			}

			logger.Debug(ctx, "successfully reported lifecycle state")
			a.lifecycleLastReportedIndex++
			select {
			case a.lifecycleReported <- report.State:
			case <-a.lifecycleReported:
				a.lifecycleReported <- report.State
			}
			if a.lifecycleLastReportedIndex < lastIndex {
				// Keep reporting until we've sent all messages, we can't
				// rely on the channel triggering us before the backlog is
				// consumed.
				continue
			}
			break
		}
	}
}
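// Note: lifecycle states are ordered by codersdk.WorkspaceAgentLifecycleOrder,
// and setLifecycle below only accepts transitions that move strictly forward
// in that order; for example, once "ready" has been recorded, an attempt to
// go back to "starting" is dropped with a warning.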
// setLifecycle sets the lifecycle state and notifies the lifecycle loop.
// The state is only updated if it's a valid state transition.
func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
	report := agentsdk.PostLifecycleRequest{
		State:     state,
		ChangedAt: dbtime.Now(),
	}

	a.lifecycleMu.Lock()
	lastReport := a.lifecycleStates[len(a.lifecycleStates)-1]
	if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastReport.State) >= slices.Index(codersdk.WorkspaceAgentLifecycleOrder, report.State) {
		a.logger.Warn(context.Background(), "attempted to set lifecycle state to a previous state",
			slog.F("last", lastReport), slog.F("current", report))
		a.lifecycleMu.Unlock()
		return
	}
	a.lifecycleStates = append(a.lifecycleStates, report)
	a.logger.Debug(context.Background(), "set lifecycle state",
		slog.F("current", report), slog.F("last", lastReport))
	a.lifecycleMu.Unlock()

	select {
	case a.lifecycleUpdate <- struct{}{}:
	default:
	}
}

// fetchServiceBannerLoop fetches the service banner on an interval. It will
// not be fetched immediately; the expectation is that it is primed elsewhere
// (and must be done before the session actually starts).
func (a *agent) fetchServiceBannerLoop(ctx context.Context, conn drpc.Conn) error {
	aAPI := proto.NewDRPCAgentClient(conn)
	ticker := time.NewTicker(a.notificationBannersRefreshInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			bannersProto, err := aAPI.GetNotificationBanners(ctx, &proto.GetNotificationBannersRequest{})
			if err != nil {
				if ctx.Err() != nil {
					return ctx.Err()
				}
				a.logger.Error(ctx, "failed to update notification banners", slog.Error(err))
				return err
			}
			banners := make([]codersdk.BannerConfig, 0, len(bannersProto.NotificationBanners))
			for _, bannerProto := range bannersProto.NotificationBanners {
				banners = append(banners, agentsdk.BannerConfigFromProto(bannerProto))
			}
			a.notificationBanners.Store(&banners)
		}
	}
}

func (a *agent) run() (retErr error) {
	// This allows the agent to refresh its token if necessary.
	// For instance identity this is required, since the instance may not have
	// been re-provisioned, but a new agent ID was created.
	sessionToken, err := a.exchangeToken(a.hardCtx)
	if err != nil {
		return xerrors.Errorf("exchange token: %w", err)
	}
	a.sessionToken.Store(&sessionToken)

	// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
	conn, err := a.client.ConnectRPC(a.hardCtx)
	if err != nil {
		return err
	}
	defer func() {
		cErr := conn.Close()
		if cErr != nil {
			a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(cErr))
		}
	}()

	// A lot of routines need the agent API / tailnet API connection. We run them in their own
	// goroutines in parallel, but errors in any routine will cause them all to exit so we can
	// redial the coder server and retry.
	connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger, conn)

	connMan.start("init notification banners", gracefulShutdownBehaviorStop,
		func(ctx context.Context, conn drpc.Conn) error {
			aAPI := proto.NewDRPCAgentClient(conn)
			bannersProto, err := aAPI.GetNotificationBanners(ctx, &proto.GetNotificationBannersRequest{})
			if err != nil {
				return xerrors.Errorf("fetch service banner: %w", err)
			}
			banners := make([]codersdk.BannerConfig, 0, len(bannersProto.NotificationBanners))
			for _, bannerProto := range bannersProto.NotificationBanners {
				banners = append(banners, agentsdk.BannerConfigFromProto(bannerProto))
			}
			a.notificationBanners.Store(&banners)
			return nil
		},
	)

	// sending logs gets gracefulShutdownBehaviorRemain because we want to send logs generated by
	// shutdown scripts.
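	// (Routines started with gracefulShutdownBehaviorStop are canceled as soon
	// as graceful shutdown begins, while gracefulShutdownBehaviorRemain
	// routines keep running until the hard context or the connection itself
	// errors; see newAPIConnRoutineManager below.)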
	connMan.start("send logs", gracefulShutdownBehaviorRemain,
		func(ctx context.Context, conn drpc.Conn) error {
			err := a.logSender.SendLoop(ctx, proto.NewDRPCAgentClient(conn))
			if xerrors.Is(err, agentsdk.LogLimitExceededError) {
				// we don't want this error to tear down the API connection and propagate to the
				// other routines that use the API. The LogSender has already dropped a warning
				// log, so just return nil here.
				return nil
			}
			return err
		})

	// part of graceful shut down is reporting the final lifecycle states, e.g. "ShuttingDown", so the
	// lifecycle reporting has to be via gracefulShutdownBehaviorRemain
	connMan.start("report lifecycle", gracefulShutdownBehaviorRemain, a.reportLifecycle)

	// metadata reporting can cease as soon as we start gracefully shutting down
	connMan.start("report metadata", gracefulShutdownBehaviorStop, a.reportMetadata)

	// channels to sync goroutines below
	//  handle manifest
	//       |
	//    manifestOK
	//      |      |
	//      |      +----------------------+
	//      V                             |
	//  app health reporter               |
	//                                    V
	//                      create or update network
	//                                    |
	//                                networkOK
	//                                    |
	//  coordination <--------------------------+
	//  derp map subscriber <----------------+
	//  stats report loop <---------------+
	networkOK := newCheckpoint(a.logger)
	manifestOK := newCheckpoint(a.logger)

	connMan.start("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK))

	connMan.start("app health reporter", gracefulShutdownBehaviorStop,
		func(ctx context.Context, conn drpc.Conn) error {
			if err := manifestOK.wait(ctx); err != nil {
				return xerrors.Errorf("no manifest: %w", err)
			}
			manifest := a.manifest.Load()
			NewWorkspaceAppHealthReporter(
				a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)),
			)(ctx)
			return nil
		})

	connMan.start("create or update network", gracefulShutdownBehaviorStop,
		a.createOrUpdateNetwork(manifestOK, networkOK))

	connMan.start("coordination", gracefulShutdownBehaviorStop,
		func(ctx context.Context, conn drpc.Conn) error {
			if err := networkOK.wait(ctx); err != nil {
				return xerrors.Errorf("no network: %w", err)
			}
			return a.runCoordinator(ctx, conn, a.network)
		},
	)

	connMan.start("derp map subscriber", gracefulShutdownBehaviorStop,
		func(ctx context.Context, conn drpc.Conn) error {
			if err := networkOK.wait(ctx); err != nil {
				return xerrors.Errorf("no network: %w", err)
			}
			return a.runDERPMapSubscriber(ctx, conn, a.network)
		})

	connMan.start("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop)

	connMan.start("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error {
		if err := networkOK.wait(ctx); err != nil {
			return xerrors.Errorf("no network: %w", err)
		}
		return a.statsReporter.reportLoop(ctx, proto.NewDRPCAgentClient(conn))
	})

	return connMan.wait()
}

// handleManifest returns a function that fetches and processes the manifest
func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, conn drpc.Conn) error {
	return func(ctx context.Context, conn drpc.Conn) error {
		var (
			sentResult = false
			err        error
		)
		defer func() {
			if !sentResult {
				manifestOK.complete(err)
			}
		}()
		aAPI := proto.NewDRPCAgentClient(conn)
		mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{})
		if err != nil {
			return xerrors.Errorf("fetch metadata: %w", err)
		}
		a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp))
		manifest, err := agentsdk.ManifestFromProto(mp)
		if err != nil {
			a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err))
			return xerrors.Errorf("convert manifest: %w", err)
		}
		if manifest.AgentID == uuid.Nil {
			return xerrors.New("nil agentID returned by manifest")
		}
		a.client.RewriteDERPMap(manifest.DERPMap)

		// Expand the directory and send it back to coderd so external
		// applications that rely on the directory can use it.
		//
		// An example is VS Code Remote, which must know the directory
		// before initializing a connection.
		manifest.Directory, err = expandDirectory(manifest.Directory)
		if err != nil {
			return xerrors.Errorf("expand directory: %w", err)
		}
		subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems)
		if err != nil {
			a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err))
			return xerrors.Errorf("failed to convert subsystems: %w", err)
		}
		_, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{
			Version:           buildinfo.Version(),
			ExpandedDirectory: manifest.Directory,
			Subsystems:        subsys,
		}})
		if err != nil {
			return xerrors.Errorf("update workspace agent startup: %w", err)
		}

		oldManifest := a.manifest.Swap(&manifest)
		manifestOK.complete(nil)
		sentResult = true

		// The startup script should only execute on the first run!
		if oldManifest == nil {
			a.setLifecycle(codersdk.WorkspaceAgentLifecycleStarting)

			// Perform overrides early so that Git auth can work even if users
			// connect to a workspace that is not yet ready. We don't run this
			// concurrently with the startup script to avoid conflicts between
			// them.
			if manifest.GitAuthConfigs > 0 {
				// If this fails, we should consider surfacing the error in the
				// startup log and setting the lifecycle state to be "start_error"
				// (after startup script completion), but for now we'll just log it.
				err := gitauth.OverrideVSCodeConfigs(a.filesystem)
				if err != nil {
					a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(err))
				}
			}

			err = a.scriptRunner.Init(manifest.Scripts)
			if err != nil {
				return xerrors.Errorf("init script runner: %w", err)
			}
			err = a.trackGoroutine(func() {
				start := time.Now()
				// here we use the graceful context because the script runner is not directly tied
				// to the agent API.
				err := a.scriptRunner.Execute(a.gracefulCtx, func(script codersdk.WorkspaceAgentScript) bool {
					return script.RunOnStart
				})
				// Measure the time immediately after the script has finished
				dur := time.Since(start).Seconds()
				if err != nil {
					a.logger.Warn(ctx, "startup script(s) failed", slog.Error(err))
					if errors.Is(err, agentscripts.ErrTimeout) {
						a.setLifecycle(codersdk.WorkspaceAgentLifecycleStartTimeout)
					} else {
						a.setLifecycle(codersdk.WorkspaceAgentLifecycleStartError)
					}
				} else {
					a.setLifecycle(codersdk.WorkspaceAgentLifecycleReady)
				}

				label := "false"
				if err == nil {
					label = "true"
				}
				a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur)
				a.scriptRunner.StartCron()
			})
			if err != nil {
				return xerrors.Errorf("track conn goroutine: %w", err)
			}
		}
		return nil
	}
}

// createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates
// the tailnet using the information in the manifest
func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(context.Context, drpc.Conn) error {
	return func(ctx context.Context, _ drpc.Conn) (retErr error) {
		if err := manifestOK.wait(ctx); err != nil {
			return xerrors.Errorf("no manifest: %w", err)
		}
		var err error
		defer func() {
			networkOK.complete(retErr)
		}()
		manifest := a.manifest.Load()
		a.closeMutex.Lock()
		network := a.network
		a.closeMutex.Unlock()
		if network == nil {
			// use the graceful context here, because creating the tailnet is not itself tied to the
			// agent API.
			network, err = a.createTailnet(a.gracefulCtx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections)
			if err != nil {
				return xerrors.Errorf("create tailnet: %w", err)
			}
			a.closeMutex.Lock()
			// Re-check if agent was closed while initializing the network.
			closed := a.isClosed()
			if !closed {
				a.network = network
				a.statsReporter = newStatsReporter(a.logger, network, a)
			}
			a.closeMutex.Unlock()
			if closed {
				_ = network.Close()
				return xerrors.New("agent is closed")
			}
		} else {
			// Update the wireguard IPs if the agent ID changed.
			err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID))
			if err != nil {
				a.logger.Error(a.gracefulCtx, "update tailnet addresses", slog.Error(err))
			}
			// Update the DERP map, force WebSocket setting and allow/disallow
			// direct connections.
			network.SetDERPMap(manifest.DERPMap)
			network.SetDERPForceWebSockets(manifest.DERPForceWebSockets)
			network.SetBlockEndpoints(manifest.DisableDirectConnections)
		}
		return nil
	}
}

// updateCommandEnv updates the provided command environment with the
// following set of environment variables:
// - Predefined workspace environment variables
// - Environment variables currently set (these do not override the predefined ones)
// - Environment variables passed via the agent manifest (overriding predefined and current)
// - Agent-level environment variables (overriding all)
func (a *agent) updateCommandEnv(current []string) (updated []string, err error) {
	manifest := a.manifest.Load()
	if manifest == nil {
		return nil, xerrors.Errorf("no manifest")
	}

	executablePath, err := os.Executable()
	if err != nil {
		return nil, xerrors.Errorf("getting os executable: %w", err)
	}
	unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")

	// Define environment variables that should be set for all commands,
	// and then merge them with the current environment.
	envs := map[string]string{
		// Set env vars indicating we're inside a Coder workspace.
		"CODER":                      "true",
		"CODER_WORKSPACE_NAME":       manifest.WorkspaceName,
		"CODER_WORKSPACE_AGENT_NAME": manifest.AgentName,

		// Specific Coder subcommands require the agent token exposed!
		"CODER_AGENT_TOKEN": *a.sessionToken.Load(),

		// Git on Windows resolves with UNIX-style paths.
		// If using backslashes, it's unable to find the executable.
		"GIT_SSH_COMMAND": fmt.Sprintf("%s gitssh --", unixExecutablePath),
		// Hide Coder message on code-server's "Getting Started" page
		"CS_DISABLE_GETTING_STARTED_OVERRIDE": "true",
	}

	// This adds the ports dialog to code-server that enables
	// proxying a port dynamically.
	// If this is an empty string, do not set anything: code-server falls back
	// to using its base path to construct a path-based port proxy.
	if manifest.VSCodePortProxyURI != "" {
		envs["VSCODE_PROXY_URI"] = manifest.VSCodePortProxyURI
	}

	// Carry over the current environment, but don't let it override any of
	// the variables we defined above.
	for _, env := range current {
		parts := strings.SplitN(env, "=", 2)
		if len(parts) != 2 {
			continue
		}
		if _, ok := envs[parts[0]]; !ok {
			envs[parts[0]] = parts[1]
		}
	}

	// Load environment variables passed via the agent manifest.
	// These override all variables we manually specify.
	for k, v := range manifest.EnvironmentVariables {
		// Expanding environment variables allows for customization
		// of the $PATH, among other variables. Customers can prepend
		// or append to the $PATH, so allowing expand is required!
		envs[k] = os.ExpandEnv(v)
	}

	// Agent-level environment variables should take over all. This is
	// used for setting agent-specific variables like CODER_AGENT_TOKEN
	// and GIT_ASKPASS.
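	// A precedence illustration with hypothetical values (highest wins):
	//
	//	predefined: CODER=true                  // kept even if current env differs
	//	current:    EDITOR=vim                  // carried over, key not predefined
	//	manifest:   PATH=$PATH:/opt/bin         // expanded, replaces current PATH
	//	agent:      CODER_AGENT_TOKEN=<secret>  // set last, overrides everything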
	for k, v := range a.environmentVariables {
		envs[k] = v
	}

	// Prepend the agent script bin directory to the PATH
	// (this is where Coder modules place their binaries).
	if _, ok := envs["PATH"]; !ok {
		envs["PATH"] = os.Getenv("PATH")
	}
	envs["PATH"] = fmt.Sprintf("%s%c%s", a.scriptRunner.ScriptBinDir(), filepath.ListSeparator, envs["PATH"])

	for k, v := range envs {
		updated = append(updated, fmt.Sprintf("%s=%s", k, v))
	}
	return updated, nil
}

func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
	if len(a.addresses) == 0 {
		return []netip.Prefix{
			// This is the IP that should be used primarily.
			netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
			// We also listen on the legacy codersdk.WorkspaceAgentIP. This
			// allows for a transition away from wsconncache.
			netip.PrefixFrom(workspacesdk.AgentIP, 128),
		}
	}

	return a.addresses
}

func (a *agent) trackGoroutine(fn func()) error {
	a.closeMutex.Lock()
	defer a.closeMutex.Unlock()
	if a.isClosed() {
		return xerrors.New("track conn goroutine: agent is closed")
	}
	a.closeWaitGroup.Add(1)
	go func() {
		defer a.closeWaitGroup.Done()
		fn()
	}()
	return nil
}

func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, derpForceWebSockets, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
	network, err := tailnet.NewConn(&tailnet.Options{
		ID:                  agentID,
		Addresses:           a.wireguardAddresses(agentID),
		DERPMap:             derpMap,
		DERPForceWebSockets: derpForceWebSockets,
		Logger:              a.logger.Named("net.tailnet"),
		ListenPort:          a.tailnetListenPort,
		BlockEndpoints:      disableDirectConnections,
	})
	if err != nil {
		return nil, xerrors.Errorf("create tailnet: %w", err)
	}
	defer func() {
		if err != nil {
			network.Close()
		}
	}()

	sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
	if err != nil {
		return nil, xerrors.Errorf("listen on the ssh port: %w", err)
	}
	defer func() {
		if err != nil {
			_ = sshListener.Close()
		}
	}()
	if err = a.trackGoroutine(func() {
		_ = a.sshServer.Serve(sshListener)
	}); err != nil {
		return nil, err
	}

	reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort))
	if err != nil {
		return nil, xerrors.Errorf("listen for reconnecting pty: %w", err)
	}
	defer func() {
		if err != nil {
			_ = reconnectingPTYListener.Close()
		}
	}()
	if err = a.trackGoroutine(func() {
		logger := a.logger.Named("reconnecting-pty")
		var wg sync.WaitGroup
		for {
			conn, err := reconnectingPTYListener.Accept()
			if err != nil {
				if !a.isClosed() {
					logger.Debug(ctx, "accept pty failed", slog.Error(err))
				}
				break
			}
			clog := logger.With(
				slog.F("remote", conn.RemoteAddr().String()),
				slog.F("local", conn.LocalAddr().String()))
			clog.Info(ctx, "accepted conn")
			wg.Add(1)
			closed := make(chan struct{})
			go func() {
				select {
				case <-closed:
				case <-a.hardCtx.Done():
					_ = conn.Close()
				}
				wg.Done()
			}()
			go func() {
				defer close(closed)
				// This cannot use a JSON decoder, since that can
				// buffer additional data that is required for the PTY.
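				// The expected wire format is a 2-byte little-endian length
				// prefix followed by that many bytes of JSON encoding
				// workspacesdk.AgentReconnectingPTYInit. A client-side sketch
				// that mirrors the reads below (hypothetical, for
				// illustration):
				//
				//	payload, _ := json.Marshal(msg)
				//	hdr := binary.LittleEndian.AppendUint16(nil, uint16(len(payload)))
				//	_, _ = conn.Write(append(hdr, payload...))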
				rawLen := make([]byte, 2)
				_, err = conn.Read(rawLen)
				if err != nil {
					return
				}
				length := binary.LittleEndian.Uint16(rawLen)
				data := make([]byte, length)
				_, err = conn.Read(data)
				if err != nil {
					return
				}
				var msg workspacesdk.AgentReconnectingPTYInit
				err = json.Unmarshal(data, &msg)
				if err != nil {
					logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
					return
				}
				_ = a.handleReconnectingPTY(ctx, clog, msg, conn)
			}()
		}
		wg.Wait()
	}); err != nil {
		return nil, err
	}

	speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSpeedtestPort))
	if err != nil {
		return nil, xerrors.Errorf("listen for speedtest: %w", err)
	}
	defer func() {
		if err != nil {
			_ = speedtestListener.Close()
		}
	}()
	if err = a.trackGoroutine(func() {
		var wg sync.WaitGroup
		for {
			conn, err := speedtestListener.Accept()
			if err != nil {
				if !a.isClosed() {
					a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
				}
				break
			}
			clog := a.logger.Named("speedtest").With(
				slog.F("remote", conn.RemoteAddr().String()),
				slog.F("local", conn.LocalAddr().String()))
			clog.Info(ctx, "accepted conn")
			wg.Add(1)
			closed := make(chan struct{})
			go func() {
				select {
				case <-closed:
				case <-a.hardCtx.Done():
					_ = conn.Close()
				}
				wg.Done()
			}()
			go func() {
				defer close(closed)
				sErr := speedtest.ServeConn(conn)
				if sErr != nil {
					clog.Error(ctx, "test ended with error", slog.Error(sErr))
					return
				}
				clog.Info(ctx, "test ended")
			}()
		}
		wg.Wait()
	}); err != nil {
		return nil, err
	}

	apiListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort))
	if err != nil {
		return nil, xerrors.Errorf("api listener: %w", err)
	}
	defer func() {
		if err != nil {
			_ = apiListener.Close()
		}
	}()
	if err = a.trackGoroutine(func() {
		defer apiListener.Close()
		server := &http.Server{
			Handler:           a.apiHandler(),
			ReadTimeout:       20 * time.Second,
			ReadHeaderTimeout: 20 * time.Second,
			WriteTimeout:      20 * time.Second,
			ErrorLog:          slog.Stdlib(ctx, a.logger.Named("http_api_server"), slog.LevelInfo),
		}
		go func() {
			select {
			case <-ctx.Done():
			case <-a.hardCtx.Done():
			}
			_ = server.Close()
		}()

		err := server.Serve(apiListener)
		if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
			a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
		}
	}); err != nil {
		return nil, err
	}

	return network, nil
}

// runCoordinator runs a coordinator, returning an error when a reconnect
// should occur.
func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error {
	defer a.logger.Debug(ctx, "disconnected from coordination RPC")
	tClient := tailnetproto.NewDRPCTailnetClient(conn)
	// we run the RPC on the hardCtx so that we have a chance to send the disconnect message if we
	// gracefully shut down.
	coordinate, err := tClient.Coordinate(a.hardCtx)
	if err != nil {
		return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err)
	}
	defer func() {
		cErr := coordinate.Close()
		if cErr != nil {
			a.logger.Debug(ctx, "error closing Coordinate client", slog.Error(cErr))
		}
	}()
	a.logger.Info(ctx, "connected to coordination RPC")

	// This allows the Close() routine to wait for the coordinator to gracefully disconnect.
	a.closeMutex.Lock()
	if a.isClosed() {
		return nil
	}
	disconnected := make(chan struct{})
	a.coordDisconnected = disconnected
	defer close(disconnected)
	a.closeMutex.Unlock()

	coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
	errCh := make(chan error, 1)
	go func() {
		defer close(errCh)
		select {
		case <-ctx.Done():
			err := coordination.Close()
			if err != nil {
				a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
			}
			return
		case err := <-coordination.Error():
			errCh <- err
		}
	}()
	return <-errCh
}

// runDERPMapSubscriber subscribes to DERP map updates and applies them to the
// tailnet, returning an error when a reconnect should occur.
func (a *agent) runDERPMapSubscriber(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error {
	defer a.logger.Debug(ctx, "disconnected from derp map RPC")
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	tClient := tailnetproto.NewDRPCTailnetClient(conn)
	stream, err := tClient.StreamDERPMaps(ctx, &tailnetproto.StreamDERPMapsRequest{})
	if err != nil {
		return xerrors.Errorf("stream DERP Maps: %w", err)
	}
	defer func() {
		cErr := stream.Close()
		if cErr != nil {
			a.logger.Debug(ctx, "error closing DERPMap stream", slog.Error(cErr))
		}
	}()
	a.logger.Info(ctx, "connected to derp map RPC")
	for {
		dmp, err := stream.Recv()
		if err != nil {
			return xerrors.Errorf("recv DERPMap error: %w", err)
		}
		dm := tailnet.DERPMapFromProto(dmp)
		a.client.RewriteDERPMap(dm)
		network.SetDERPMap(dm)
	}
}

func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg workspacesdk.AgentReconnectingPTYInit, conn net.Conn) (retErr error) {
	defer conn.Close()
	a.metrics.connectionsTotal.Add(1)
	a.connCountReconnectingPTY.Add(1)
	defer a.connCountReconnectingPTY.Add(-1)
	connectionID := uuid.NewString()
	connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
	connLogger.Debug(ctx, "starting handler")

	defer func() {
		if err := retErr; err != nil {
			a.closeMutex.Lock()
			closed := a.isClosed()
			a.closeMutex.Unlock()

			// If the agent is closed, we don't want to
			// log this as an error since it's expected.
			if closed {
				connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
			} else {
				connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
			}
		}
		connLogger.Info(ctx, "reconnecting pty connection closed")
	}()

	var rpty reconnectingpty.ReconnectingPTY
	sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
	// On store, reserve this ID to prevent multiple concurrent new connections.
	waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
	if ok {
		close(sendConnected) // Unused.
		connLogger.Debug(ctx, "connecting to existing reconnecting pty")
		c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
		if !ok {
			return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
		}
		rpty, ok = <-c
		if !ok || rpty == nil {
			return xerrors.Errorf("reconnecting pty closed before connection")
		}
		c <- rpty // Put it back for the next reconnect.
	} else {
		connLogger.Debug(ctx, "creating new reconnecting pty")

		connected := false
		defer func() {
			if !connected && retErr != nil {
				a.reconnectingPTYs.Delete(msg.ID)
				close(sendConnected)
			}
		}()

		// Empty command will default to the user's shell!
		cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
		if err != nil {
			a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
			return xerrors.Errorf("create command: %w", err)
		}

		rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
			Timeout: a.reconnectingPTYTimeout,
			Metrics: a.metrics.reconnectingPTYErrors,
		}, logger.With(slog.F("message_id", msg.ID)))

		if err = a.trackGoroutine(func() {
			rpty.Wait()
			a.reconnectingPTYs.Delete(msg.ID)
		}); err != nil {
			rpty.Close(err)
			return xerrors.Errorf("start routine: %w", err)
		}

		connected = true
		sendConnected <- rpty
	}
	return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
}

// Collect collects additional stats from the agent
func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
	a.logger.Debug(context.Background(), "computing stats report")
	stats := &proto.Stats{
		ConnectionCount:    int64(len(networkStats)),
		ConnectionsByProto: map[string]int64{},
	}
	for conn, counts := range networkStats {
		stats.ConnectionsByProto[conn.Proto.String()]++
		stats.RxBytes += int64(counts.RxBytes)
		stats.RxPackets += int64(counts.RxPackets)
		stats.TxBytes += int64(counts.TxBytes)
		stats.TxPackets += int64(counts.TxPackets)
	}

	// The count of active sessions.
	sshStats := a.sshServer.ConnStats()
	stats.SessionCountSsh = sshStats.Sessions
	stats.SessionCountVscode = sshStats.VSCode
	stats.SessionCountJetbrains = sshStats.JetBrains

	stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()

	// Compute the median connection latency!
	a.logger.Debug(ctx, "starting peer latency measurement for stats")
	var wg sync.WaitGroup
	var mu sync.Mutex
	status := a.network.Status()
	durations := []float64{}
	pingCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
	defer cancelFunc()
	for nodeID, peer := range status.Peer {
		if !peer.Active {
			continue
		}
		addresses, found := a.network.NodeAddresses(nodeID)
		if !found {
			continue
		}
		if len(addresses) == 0 {
			continue
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			duration, _, _, err := a.network.Ping(pingCtx, addresses[0].Addr())
			if err != nil {
				return
			}
			mu.Lock()
			defer mu.Unlock()
			durations = append(durations, float64(duration.Microseconds()))
		}()
	}
	wg.Wait()
	sort.Float64s(durations)
	durationsLength := len(durations)
	if durationsLength == 0 {
		stats.ConnectionMedianLatencyMs = -1
	} else if durationsLength%2 == 0 {
		stats.ConnectionMedianLatencyMs = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2
	} else {
		stats.ConnectionMedianLatencyMs = durations[durationsLength/2]
	}
	// Convert from microseconds to milliseconds.
	stats.ConnectionMedianLatencyMs /= 1000
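	// e.g. sorted durations of [900, 1100] µs yield (900+1100)/2 = 1000 µs,
	// i.e. 1 ms, while an odd count simply takes the middle element.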
	// Collect agent metrics.
	// Agent metrics are changing all the time, so there is no need to perform
	// reflect.DeepEqual to see if stats should be transferred.

	metricsCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
	defer cancelFunc()
	a.logger.Debug(ctx, "collecting agent metrics for stats")
	stats.Metrics = a.collectMetrics(metricsCtx)

	return stats
}

var prioritizedProcs = []string{"coder agent"}

func (a *agent) manageProcessPriorityUntilGracefulShutdown() {
	// process priority can stop as soon as we are gracefully shutting down
	ctx := a.gracefulCtx
	defer func() {
		if r := recover(); r != nil {
			a.logger.Critical(ctx, "recovered from panic",
				slog.F("panic", r),
				slog.F("stack", string(debug.Stack())),
			)
		}
	}()

	if val := a.environmentVariables[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" {
		a.logger.Debug(ctx, "process priority not enabled, agent will not manage process niceness/oom_score_adj",
			slog.F("env_var", EnvProcPrioMgmt),
			slog.F("value", val),
			slog.F("goos", runtime.GOOS),
		)
		return
	}

	if a.processManagementTick == nil {
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		a.processManagementTick = ticker.C
	}

	oomScore := unsetOOMScore
	if scoreStr, ok := a.environmentVariables[EnvProcOOMScore]; ok {
		score, err := strconv.Atoi(strings.TrimSpace(scoreStr))
		if err == nil && score >= -1000 && score <= 1000 {
			oomScore = score
		} else {
			a.logger.Error(ctx, "invalid oom score",
				slog.F("min_value", -1000),
				slog.F("max_value", 1000),
				slog.F("value", scoreStr),
			)
		}
	}

	debouncer := &logDebouncer{
		logger:   a.logger,
		messages: map[string]time.Time{},
		interval: time.Minute,
	}

	for {
		procs, err := a.manageProcessPriority(ctx, debouncer, oomScore)
		// Avoid spamming the logs too often.
		if err != nil {
			debouncer.Error(ctx, "manage process priority",
				slog.Error(err),
			)
		}
		if a.modifiedProcs != nil {
			a.modifiedProcs <- procs
		}

		select {
		case <-a.processManagementTick:
		case <-ctx.Done():
			return
		}
	}
}

// unsetOOMScore is set to an invalid OOM score to imply an unset value.
const unsetOOMScore = 1001

func (a *agent) manageProcessPriority(ctx context.Context, debouncer *logDebouncer, oomScore int) ([]*agentproc.Process, error) {
	const niceness = 10

	// We fetch the agent score each time because it's possible someone updates the
	// value after it is started.
	agentScore, err := a.getAgentOOMScore()
	if err != nil {
		agentScore = unsetOOMScore
	}
	if oomScore == unsetOOMScore && agentScore != unsetOOMScore {
		// If the child score has not been explicitly specified we should
		// set it to a score relative to the agent score.
		oomScore = childOOMScore(agentScore)
	}

	procs, err := agentproc.List(a.filesystem, a.syscaller)
	if err != nil {
		return nil, xerrors.Errorf("list: %w", err)
	}

	modProcs := []*agentproc.Process{}

	for _, proc := range procs {
		containsFn := func(e string) bool {
			contains := strings.Contains(proc.Cmd(), e)
			return contains
		}

		// If the process is prioritized we skip it entirely: we don't lower
		// its niceness and we leave its oom_score_adj alone.
		if slices.ContainsFunc(prioritizedProcs, containsFn) {
			continue
		}

		score, niceErr := proc.Niceness(a.syscaller)
		if niceErr != nil && !xerrors.Is(niceErr, os.ErrPermission) {
			debouncer.Warn(ctx, "unable to get proc niceness",
				slog.F("cmd", proc.Cmd()),
				slog.F("pid", proc.PID),
				slog.Error(niceErr),
			)
			continue
		}

		// We only want processes that don't have a nice value set
		// so we don't override user nice values.
		// Getpriority actually returns priority for the nice value
		// which is niceness + 20, so here 20 = a niceness of 0 (aka unset).
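		// e.g. a process with niceness 0 reports a score of 20 here, while one
		// already reniced to 10 reports 30 and is skipped below.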
		if score != 20 {
			// We don't log here since it can get spammy
			continue
		}

		if niceErr == nil {
			err := proc.SetNiceness(a.syscaller, niceness)
			if err != nil && !xerrors.Is(err, os.ErrPermission) {
				debouncer.Warn(ctx, "unable to set proc niceness",
					slog.F("cmd", proc.Cmd()),
					slog.F("pid", proc.PID),
					slog.F("niceness", niceness),
					slog.Error(err),
				)
			}
		}

		// If the oom score is valid, not already applied, and not a custom
		// value set by another process, then it's ok to update it.
		if oomScore != unsetOOMScore && oomScore != proc.OOMScoreAdj && !isCustomOOMScore(agentScore, proc) {
			oomScoreStr := strconv.Itoa(oomScore)
			err := afero.WriteFile(a.filesystem, fmt.Sprintf("/proc/%d/oom_score_adj", proc.PID), []byte(oomScoreStr), 0o644)
			if err != nil && !xerrors.Is(err, os.ErrPermission) {
				debouncer.Warn(ctx, "unable to set oom_score_adj",
					slog.F("cmd", proc.Cmd()),
					slog.F("pid", proc.PID),
					slog.F("score", oomScoreStr),
					slog.Error(err),
				)
			}
		}
		modProcs = append(modProcs, proc)
	}
	return modProcs, nil
}

// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
	return a.hardCtx.Err() != nil
}

func (a *agent) requireNetwork() (*tailnet.Conn, bool) {
	a.closeMutex.Lock()
	defer a.closeMutex.Unlock()
	return a.network, a.network != nil
}

func (a *agent) HandleHTTPDebugMagicsock(w http.ResponseWriter, r *http.Request) {
	network, ok := a.requireNetwork()
	if !ok {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte("network is not ready yet"))
		return
	}
	network.MagicsockServeHTTPDebug(w, r)
}

func (a *agent) HandleHTTPMagicsockDebugLoggingState(w http.ResponseWriter, r *http.Request) {
	state := chi.URLParam(r, "state")
	stateBool, err := strconv.ParseBool(state)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		_, _ = fmt.Fprintf(w, "invalid state %q, must be a boolean", state)
		return
	}

	network, ok := a.requireNetwork()
	if !ok {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte("network is not ready yet"))
		return
	}

	network.MagicsockSetDebugLoggingEnabled(stateBool)
	a.logger.Info(r.Context(), "updated magicsock debug logging due to debug request", slog.F("new_state", stateBool))

	w.WriteHeader(http.StatusOK)
	_, _ = fmt.Fprintf(w, "updated magicsock debug logging to %v", stateBool)
}

func (a *agent) HandleHTTPDebugManifest(w http.ResponseWriter, r *http.Request) {
	sdkManifest := a.manifest.Load()
	if sdkManifest == nil {
		a.logger.Error(r.Context(), "no manifest in-memory")
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = fmt.Fprintf(w, "no manifest in-memory")
		return
	}

	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(sdkManifest); err != nil {
		a.logger.Error(a.hardCtx, "write debug manifest", slog.Error(err))
	}
}

func (a *agent) HandleHTTPDebugLogs(w http.ResponseWriter, r *http.Request) {
	logPath := filepath.Join(a.logDir, "coder-agent.log")
	f, err := os.Open(logPath)
	if err != nil {
		a.logger.Error(r.Context(), "open agent log file", slog.Error(err), slog.F("path", logPath))
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = fmt.Fprintf(w, "could not open log file: %s", err)
		return
	}
	defer f.Close()

	// Limit to 10MB.
	w.WriteHeader(http.StatusOK)
	_, err = io.Copy(w, io.LimitReader(f, 10*1024*1024))
	if err != nil && !errors.Is(err, io.EOF) {
		a.logger.Error(r.Context(), "read agent log file", slog.Error(err))
		return
	}
}

func (a *agent) HTTPDebug() http.Handler {
	r := chi.NewRouter()
	r.Get("/debug/logs", a.HandleHTTPDebugLogs)
	r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
	r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)
	r.Get("/debug/manifest", a.HandleHTTPDebugManifest)
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		_, _ = w.Write([]byte("404 not found"))
	})

	return r
}

func (a *agent) Close() error {
	a.closeMutex.Lock()
	defer a.closeMutex.Unlock()
	if a.isClosed() {
		return nil
	}

	a.logger.Info(a.hardCtx, "shutting down agent")
	a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)

	// Attempt to gracefully shut down all active SSH connections and
	// stop accepting new ones.
	err := a.sshServer.Shutdown(a.hardCtx)
	if err != nil {
		a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err))
	}
	err = a.sshServer.Close()
	if err != nil {
		a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err))
	}

	// wait for SSH to shut down before the general graceful cancel, because
	// this triggers a disconnect in the tailnet layer, telling all clients to
	// shut down their wireguard tunnels to us. If SSH sessions are still up,
	// they might hang instead of being closed.
	a.gracefulCancel()

	lifecycleState := codersdk.WorkspaceAgentLifecycleOff
	err = a.scriptRunner.Execute(a.hardCtx, func(script codersdk.WorkspaceAgentScript) bool {
		return script.RunOnStop
	})
	if err != nil {
		a.logger.Warn(a.hardCtx, "shutdown script(s) failed", slog.Error(err))
		if errors.Is(err, agentscripts.ErrTimeout) {
			lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownTimeout
		} else {
			lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
		}
	}
	a.setLifecycle(lifecycleState)

	err = a.scriptRunner.Close()
	if err != nil {
		a.logger.Error(a.hardCtx, "script runner close", slog.Error(err))
	}

	// Wait for the graceful shutdown to complete, but don't wait forever so
	// that we don't break user expectations.
	go func() {
		defer a.hardCancel()
		select {
		case <-a.hardCtx.Done():
		case <-time.After(5 * time.Second):
		}
	}()

	// Wait for lifecycle to be reported
lifecycleWaitLoop:
	for {
		select {
		case <-a.hardCtx.Done():
			a.logger.Warn(context.Background(), "failed to report final lifecycle state")
			break lifecycleWaitLoop
		case s := <-a.lifecycleReported:
			if s == lifecycleState {
				a.logger.Debug(context.Background(), "reported final lifecycle state")
				break lifecycleWaitLoop
			}
		}
	}

	// Wait for graceful disconnect from the Coordinator RPC
	select {
	case <-a.hardCtx.Done():
		a.logger.Warn(context.Background(), "timed out waiting for Coordinator RPC disconnect")
	case <-a.coordDisconnected:
		a.logger.Debug(context.Background(), "coordinator RPC disconnected")
	}

	// Wait for logs to be sent
	err = a.logSender.WaitUntilEmpty(a.hardCtx)
	if err != nil {
		a.logger.Warn(context.Background(), "timed out waiting for all logs to be sent", slog.Error(err))
	}

	a.hardCancel()
	if a.network != nil {
		_ = a.network.Close()
	}
	a.closeWaitGroup.Wait()

	return nil
}

// userHomeDir returns the home directory of the current user, giving
// priority to the $HOME environment variable.
func userHomeDir() (string, error) {
	// First we check the environment.
	homedir, err := os.UserHomeDir()
	if err == nil {
		return homedir, nil
	}

	// As a fallback, we try the user information.
	u, err := user.Current()
	if err != nil {
		return "", xerrors.Errorf("current user: %w", err)
	}
	return u.HomeDir, nil
}

// expandDirectory converts a directory path to an absolute path.
// It primarily resolves the home directory and any environment
// variables that may be set.
func expandDirectory(dir string) (string, error) {
	if dir == "" {
		return "", nil
	}
	if dir[0] == '~' {
		home, err := userHomeDir()
		if err != nil {
			return "", err
		}
		dir = filepath.Join(home, dir[1:])
	}
	dir = os.ExpandEnv(dir)

	if !filepath.IsAbs(dir) {
		home, err := userHomeDir()
		if err != nil {
			return "", err
		}
		dir = filepath.Join(home, dir)
	}
	return dir, nil
}

// EnvAgentSubsystem is the environment variable used to denote the
// specialized environment in which the agent is running
// (e.g. envbox, envbuilder).
const EnvAgentSubsystem = "CODER_AGENT_SUBSYSTEM"

// eitherContext returns a context that is canceled when either context ends.
func eitherContext(a, b context.Context) context.Context {
	ctx, cancel := context.WithCancel(a)
	go func() {
		defer cancel()
		select {
		case <-a.Done():
		case <-b.Done():
		}
	}()
	return ctx
}

type gracefulShutdownBehavior int

const (
	gracefulShutdownBehaviorStop gracefulShutdownBehavior = iota
	gracefulShutdownBehaviorRemain
)

type apiConnRoutineManager struct {
	logger    slog.Logger
	conn      drpc.Conn
	eg        *errgroup.Group
	stopCtx   context.Context
	remainCtx context.Context
}

func newAPIConnRoutineManager(gracefulCtx, hardCtx context.Context, logger slog.Logger, conn drpc.Conn) *apiConnRoutineManager {
	// routines that remain in operation during graceful shutdown use the remainCtx. They'll still
	// exit if the errgroup hits an error, which usually means a problem with the conn.
	eg, remainCtx := errgroup.WithContext(hardCtx)

	// routines that stop operation during graceful shutdown use the stopCtx, which ends when the
	// first of remainCtx or gracefulCtx ends (an error or the start of graceful shutdown).
	//
	// +------------------------------------------+
	// | hardCtx                                  |
	// |  +------------------------------------+  |
	// |  | stopCtx                            |  |
	// |  | +--------------+ +--------------+  |  |
	// |  | |  remainCtx   | |  gracefulCtx |  |  |
	// |  | +--------------+ +--------------+  |  |
	// |  +------------------------------------+  |
	// +------------------------------------------+
	stopCtx := eitherContext(remainCtx, gracefulCtx)
	return &apiConnRoutineManager{
		logger:    logger,
		conn:      conn,
		eg:        eg,
		stopCtx:   stopCtx,
		remainCtx: remainCtx,
	}
}

func (a *apiConnRoutineManager) start(name string, b gracefulShutdownBehavior, f func(context.Context, drpc.Conn) error) {
	logger := a.logger.With(slog.F("name", name))
	var ctx context.Context
	switch b {
	case gracefulShutdownBehaviorStop:
		ctx = a.stopCtx
	case gracefulShutdownBehaviorRemain:
		ctx = a.remainCtx
	default:
		panic("unknown behavior")
	}
	a.eg.Go(func() error {
		logger.Debug(ctx, "starting routine")
		err := f(ctx, a.conn)
		if xerrors.Is(err, context.Canceled) && ctx.Err() != nil {
			logger.Debug(ctx, "swallowing context canceled")
			// Don't propagate context canceled errors to the error group, because we don't want the
			// graceful context being canceled to halt the work of routines with
			// gracefulShutdownBehaviorRemain. Note that we check both that the error is
			// context.Canceled and that *our* context is currently canceled, because when Coderd
			// unilaterally closes the API connection (for example if the build is outdated), it can
			// sometimes show up as context.Canceled in our RPC calls.
			return nil
		}
		logger.Debug(ctx, "routine exited", slog.Error(err))
		if err != nil {
			return xerrors.Errorf("error in routine %s: %w", name, err)
		}
		return nil
	})
}

func (a *apiConnRoutineManager) wait() error {
	return a.eg.Wait()
}

func PrometheusMetricsHandler(prometheusRegistry *prometheus.Registry, logger slog.Logger) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain")

		// Based on: https://github.com/tailscale/tailscale/blob/280255acae604796a1113861f5a84e6fa2dc6121/ipn/localapi/localapi.go#L489
		clientmetric.WritePrometheusExpositionFormat(w)

		metricFamilies, err := prometheusRegistry.Gather()
		if err != nil {
			logger.Error(context.Background(), "prometheus handler failed to gather metric families", slog.Error(err))
			return
		}

		for _, metricFamily := range metricFamilies {
			_, err = expfmt.MetricFamilyToText(w, metricFamily)
			if err != nil {
				logger.Error(context.Background(), "expfmt.MetricFamilyToText failed", slog.Error(err))
				return
			}
		}
	})
}

// childOOMScore returns the oom_score_adj for a child process. It is based
// on the oom_score_adj of the agent process.
func childOOMScore(agentScore int) int {
	// If the agent has a negative oom_score_adj, we set the child to 0
	// so it's treated like every other process.
	if agentScore < 0 {
		return 0
	}

	// If the agent is already almost at the maximum then set it to the max.
	if agentScore >= 998 {
		return 1000
	}

	// If the agent oom_score_adj is >=0, we set the child to slightly
	// less than the maximum. If users want a different score they set it
	// directly.
	return 998
}

func (a *agent) getAgentOOMScore() (int, error) {
	scoreStr, err := afero.ReadFile(a.filesystem, "/proc/self/oom_score_adj")
	if err != nil {
		return 0, xerrors.Errorf("read file: %w", err)
	}

	score, err := strconv.Atoi(strings.TrimSpace(string(scoreStr)))
	if err != nil {
		return 0, xerrors.Errorf("parse int: %w", err)
	}

	return score, nil
}

// isCustomOOMScore checks to see if the oom_score_adj is not a value that would
// originate from an agent-spawned process.
func isCustomOOMScore(agentScore int, process *agentproc.Process) bool {
	score := process.OOMScoreAdj
	return agentScore != score && score != 1000 && score != 0 && score != 998
}

// logDebouncer skips writing a log for a particular message if
// it's been emitted within the given interval duration.
// It's a shoddy implementation used in one spot that should be replaced at
// some point.
type logDebouncer struct {
	logger   slog.Logger
	messages map[string]time.Time
	interval time.Duration
}

func (l *logDebouncer) Warn(ctx context.Context, msg string, fields ...any) {
	l.log(ctx, slog.LevelWarn, msg, fields...)
}

func (l *logDebouncer) Error(ctx context.Context, msg string, fields ...any) {
	l.log(ctx, slog.LevelError, msg, fields...)
}

func (l *logDebouncer) log(ctx context.Context, level slog.Level, msg string, fields ...any) {
	// This (bad) implementation assumes you wouldn't reuse the same msg
	// for different levels.
	if last, ok := l.messages[msg]; ok && time.Since(last) < l.interval {
		return
	}
	switch level {
	case slog.LevelWarn:
		l.logger.Warn(ctx, msg, fields...)
	case slog.LevelError:
		l.logger.Error(ctx, msg, fields...)
	}
	l.messages[msg] = time.Now()
}
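// A usage sketch of logDebouncer (hypothetical call site):
//
//	d := &logDebouncer{logger: logger, messages: map[string]time.Time{}, interval: time.Minute}
//	d.Error(ctx, "manage process priority", slog.Error(err)) // logged
//	d.Error(ctx, "manage process priority", slog.Error(err)) // suppressed for the next minute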