package cli

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/pprof"
	"net/url"
	"os"
	"os/signal"
	"path/filepath"
	"runtime"
	"strconv"
	"sync"
	"time"

	"cloud.google.com/go/compute/metadata"
	"golang.org/x/xerrors"
	"gopkg.in/natefinch/lumberjack.v2"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/sloghuman"
	"github.com/coder/coder/agent"
	"github.com/coder/coder/agent/reaper"
	"github.com/coder/coder/buildinfo"
	"github.com/coder/coder/cli/clibase"
	"github.com/coder/coder/codersdk/agentsdk"
)

// workspaceAgent builds the hidden "agent" command.
//
// When the process is the init process on Linux (and --no-reap is not set),
// the handler forks a child through reaper.ForkReap and returns once that
// call returns. Otherwise it sets up logging, the pprof HTTP server and the
// selected authentication method, starts the agent, and blocks until the
// context is canceled before closing the agent.
func (r *RootCmd) workspaceAgent() *clibase.Cmd {
	var (
		auth          string
		logDir        string
		pprofAddress  string
		noReap        bool
		sshMaxTimeout time.Duration
	)
	cmd := &clibase.Cmd{
		Use:   "agent",
		Short: `Starts the Coder workspace agent.`,
		// This command isn't useful to manually execute.
		Hidden: true,
		Handler: func(inv *clibase.Invocation) error {
			ctx, cancel := context.WithCancel(inv.Context())
			defer cancel()

			// agentPorts maps ports served by the agent process itself to a
			// short name; it is handed to agent.New below.
			agentPorts := map[int]string{}

			isLinux := runtime.GOOS == "linux"

			// Spawn a reaper so that we don't accumulate a ton
			// of zombie processes.
			if reaper.IsInitProcess() && !noReap && isLinux {
				logWriter := &lumberjack.Logger{
					Filename: filepath.Join(logDir, "coder-agent-init.log"),
					MaxSize:  5, // MB
				}
				defer logWriter.Close()
				logger := slog.Make(sloghuman.Sink(inv.Stderr), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug)

				logger.Info(ctx, "spawning reaper process")
				// Do not start a reaper on the child process. It's important
				// to do this else we fork bomb ourselves.
				args := append(os.Args, "--no-reap")
				err := reaper.ForkReap(
					reaper.WithExecArgs(args...),
					reaper.WithCatchSignals(InterruptSignals...),
				)
				if err != nil {
					logger.Error(ctx, "failed to reap", slog.Error(err))
					return xerrors.Errorf("fork reap: %w", err)
				}

				logger.Info(ctx, "reaper process exiting")
				return nil
			}

			// Handle interrupt signals to allow for graceful shutdown,
			// note that calling stopNotify disables the signal handler
			// and the next interrupt will terminate the program (you
			// probably want cancel instead).
			//
			// Note that we don't want to handle these signals in the
			// process that runs as PID 1, that's why we do this after
			// the reaper forked.
			ctx, stopNotify := signal.NotifyContext(ctx, InterruptSignals...)
			defer stopNotify()

			// dumpHandler does signal handling, so we call it after the
			// reaper.
			go dumpHandler(ctx)

			ljLogger := &lumberjack.Logger{
				Filename: filepath.Join(logDir, "coder-agent.log"),
				MaxSize:  5, // MB
			}
			defer ljLogger.Close()
			// Wrap the file logger so that writes issued after Close are
			// rejected instead of reaching lumberjack again.
			logWriter := &closeWriter{w: ljLogger}
			defer logWriter.Close()

			logger := slog.Make(sloghuman.Sink(inv.Stderr), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug)

			version := buildinfo.Version()
			logger.Info(ctx, "starting agent",
				slog.F("url", r.agentURL),
				slog.F("auth", auth),
				slog.F("version", version),
			)
			client := agentsdk.New(r.agentURL)
			client.SDK.Logger = logger
			// Set a reasonable timeout so requests can't hang forever!
			// The timeout needs to be reasonably long, because requests
			// with large payloads can take a bit. e.g. startup scripts
			// may take a while to insert.
			client.SDK.HTTPClient.Timeout = 30 * time.Second

			// Enable pprof handler
			// This prevents the pprof import from being accidentally deleted.
			_ = pprof.Handler
			pprofSrvClose := serveHandler(ctx, logger, nil, pprofAddress, "pprof")
			defer pprofSrvClose()
			// Do a best effort here. If this fails, it's not a big deal.
			//
			// extractPort is used rather than urlPort because the address is
			// normally given without a scheme (the default is
			// "127.0.0.1:6060"), and extractPort retries with an "http://"
			// prefix when the bare form does not yield a port.
			if port, err := extractPort(pprofAddress); err == nil {
				agentPorts[port] = "pprof"
			}

			// exchangeToken returns a session token.
			// This is abstracted to allow for the same looping condition
			// regardless of instance identity auth type.
			var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error)
			switch auth {
			case "token":
				token, err := inv.ParsedFlags().GetString(varAgentToken)
				if err != nil {
					return xerrors.Errorf("CODER_AGENT_TOKEN must be set for token auth: %w", err)
				}
				client.SetSessionToken(token)
			case "google-instance-identity":
				// This is *only* done for testing to mock client authentication.
				// This will never be set in a production scenario.
				var gcpClient *metadata.Client
				gcpClientRaw := ctx.Value("gcp-client")
				if gcpClientRaw != nil {
					gcpClient, _ = gcpClientRaw.(*metadata.Client)
				}
				exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
					return client.AuthGoogleInstanceIdentity(ctx, "", gcpClient)
				}
			case "aws-instance-identity":
				// This is *only* done for testing to mock client authentication.
				// This will never be set in a production scenario.
				var awsClient *http.Client
				awsClientRaw := ctx.Value("aws-client")
				if awsClientRaw != nil {
					awsClient, _ = awsClientRaw.(*http.Client)
					if awsClient != nil {
						client.SDK.HTTPClient = awsClient
					}
				}
				exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
					return client.AuthAWSInstanceIdentity(ctx)
				}
			case "azure-instance-identity":
				// This is *only* done for testing to mock client authentication.
				// This will never be set in a production scenario.
				var azureClient *http.Client
				azureClientRaw := ctx.Value("azure-client")
				if azureClientRaw != nil {
					azureClient, _ = azureClientRaw.(*http.Client)
					if azureClient != nil {
						client.SDK.HTTPClient = azureClient
					}
				}
				exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
					return client.AuthAzureInstanceIdentity(ctx)
				}
			}

			executablePath, err := os.Executable()
			if err != nil {
				return xerrors.Errorf("getting os executable: %w", err)
			}
			// Append the directory containing this executable to $PATH.
			err = os.Setenv("PATH", fmt.Sprintf("%s%c%s", os.Getenv("PATH"), filepath.ListSeparator, filepath.Dir(executablePath)))
			if err != nil {
				return xerrors.Errorf("add executable to $PATH: %w", err)
			}

			closer := agent.New(agent.Options{
				Client: client,
				Logger: logger,
				LogDir: logDir,
				ExchangeToken: func(ctx context.Context) (string, error) {
					// With token auth no exchange function is set; reuse the
					// session token already stored on the client.
					if exchangeToken == nil {
						return client.SDK.SessionToken(), nil
					}
					resp, err := exchangeToken(ctx)
					if err != nil {
						return "", err
					}
					client.SetSessionToken(resp.SessionToken)
					return resp.SessionToken, nil
				},
				EnvironmentVariables: map[string]string{
					"GIT_ASKPASS": executablePath,
				},
				AgentPorts:    agentPorts,
				SSHMaxTimeout: sshMaxTimeout,
			})
			<-ctx.Done()
			return closer.Close()
		},
	}

	cmd.Options = clibase.OptionSet{
		{
			Flag:        "auth",
			Default:     "token",
			Description: "Specify the authentication type to use for the agent.",
			Env:         "CODER_AGENT_AUTH",
			Value:       clibase.StringOf(&auth),
		},
		{
			Flag:        "log-dir",
			Default:     os.TempDir(),
			Description: "Specify the location for the agent log files.",
			Env:         "CODER_AGENT_LOG_DIR",
			Value:       clibase.StringOf(&logDir),
		},
		{
			Flag:        "pprof-address",
			Default:     "127.0.0.1:6060",
			Env:         "CODER_AGENT_PPROF_ADDRESS",
			Value:       clibase.StringOf(&pprofAddress),
			Description: "The address to serve pprof.",
		},
		{
			Flag:        "no-reap",
			Env:         "",
			Description: "Do not start a process reaper.",
			Value:       clibase.BoolOf(&noReap),
		},
		{
			Flag:        "ssh-max-timeout",
			Default:     "0",
			Env:         "CODER_AGENT_SSH_MAX_TIMEOUT",
			Description: "Specify the max timeout for a SSH connection.",
			Value:       clibase.DurationOf(&sshMaxTimeout),
		},
	}

	return cmd
}

// serveHandler starts an HTTP server for handler on addr in a background
// goroutine and returns a function that closes the server. Listen errors
// other than http.ErrServerClosed are logged, tagged with name, rather than
// returned to the caller.
func serveHandler(ctx context.Context, logger slog.Logger, handler http.Handler, addr, name string) (closeFunc func()) {
	logger.Debug(ctx, "http server listening", slog.F("addr", addr), slog.F("name", name))

	// ReadHeaderTimeout is purposefully not enabled. It caused some issues with
	// websockets over the dev tunnel.
	// See: https://github.com/coder/coder/pull/3730
	//nolint:gosec
	srv := &http.Server{
		Addr:    addr,
		Handler: handler,
	}
	go func() {
		err := srv.ListenAndServe()
		if err != nil && !xerrors.Is(err, http.ErrServerClosed) {
			logger.Error(ctx, "http server listen", slog.F("name", name), slog.Error(err))
		}
	}()

	return func() {
		_ = srv.Close()
	}
}

// closeWriter is a wrapper around an io.WriteCloser that prevents
// writes after Close. This is necessary because lumberjack will
// re-open the file on write.
type closeWriter struct {
	w      io.WriteCloser
	mu     sync.Mutex // Protects following.
	closed bool
}

// Close marks the writer as closed and closes the wrapped writer. Every
// call forwards to the wrapped writer's Close.
func (c *closeWriter) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.closed = true
	return c.w.Close()
}

// Write forwards p to the wrapped writer, or returns io.ErrClosedPipe
// without writing anything once Close has been called.
func (c *closeWriter) Write(p []byte) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return 0, io.ErrClosedPipe
	}
	return c.w.Write(p)
}

// extractPort handles different url strings.
// - localhost:6060
// - http://localhost:6060
//
// It first tries u as given and, if that fails, retries with an "http://"
// prefix. When both attempts fail it returns -1 and an error wrapping the
// failure of the first attempt.
func extractPort(u string) (int, error) {
	port, firstError := urlPort(u)
	if firstError == nil {
		return port, nil
	}

	// Try with a scheme
	port, err := urlPort("http://" + u)
	if err == nil {
		return port, nil
	}
	return -1, xerrors.Errorf("invalid url %q: %w", u, firstError)
}

// urlPort extracts the port from a valid URL. It returns -1 and an error
// when u cannot be parsed, has no port, or has a port outside 1-65535.
func urlPort(u string) (int, error) {
	parsed, err := url.Parse(u)
	if err != nil {
		return -1, xerrors.Errorf("invalid url %q: %w", u, err)
	}
	if parsed.Port() != "" {
		// A bit size of 16 makes ParseUint reject anything above 65535, so
		// the conversion to int below cannot overflow or truncate.
		port, err := strconv.ParseUint(parsed.Port(), 10, 16)
		if err == nil && port > 0 {
			return int(port), nil
		}
	}
	return -1, xerrors.Errorf("invalid port: %s", u)
}