mirror of https://github.com/coder/coder.git
208 lines
6.5 KiB
Go
208 lines
6.5 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/pprof"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"time"
|
|
|
|
"cloud.google.com/go/compute/metadata"
|
|
"github.com/spf13/cobra"
|
|
"golang.org/x/xerrors"
|
|
"gopkg.in/natefinch/lumberjack.v2"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/sloghuman"
|
|
"github.com/coder/coder/agent"
|
|
"github.com/coder/coder/agent/reaper"
|
|
"github.com/coder/coder/buildinfo"
|
|
"github.com/coder/coder/cli/cliflag"
|
|
"github.com/coder/coder/codersdk"
|
|
)
|
|
|
|
func workspaceAgent() *cobra.Command {
|
|
var (
|
|
auth string
|
|
pprofAddress string
|
|
noReap bool
|
|
)
|
|
cmd := &cobra.Command{
|
|
Use: "agent",
|
|
// This command isn't useful to manually execute.
|
|
Hidden: true,
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
ctx, cancel := context.WithCancel(cmd.Context())
|
|
defer cancel()
|
|
|
|
go dumpHandler(ctx)
|
|
|
|
rawURL, err := cmd.Flags().GetString(varAgentURL)
|
|
if err != nil {
|
|
return xerrors.Errorf("CODER_AGENT_URL must be set: %w", err)
|
|
}
|
|
coderURL, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return xerrors.Errorf("parse %q: %w", rawURL, err)
|
|
}
|
|
|
|
logWriter := &lumberjack.Logger{
|
|
Filename: filepath.Join(os.TempDir(), "coder-agent.log"),
|
|
MaxSize: 5, // MB
|
|
}
|
|
defer logWriter.Close()
|
|
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug)
|
|
|
|
isLinux := runtime.GOOS == "linux"
|
|
|
|
// Spawn a reaper so that we don't accumulate a ton
|
|
// of zombie processes.
|
|
if reaper.IsInitProcess() && !noReap && isLinux {
|
|
logger.Info(ctx, "spawning reaper process")
|
|
// Do not start a reaper on the child process. It's important
|
|
// to do this else we fork bomb ourselves.
|
|
args := append(os.Args, "--no-reap")
|
|
err := reaper.ForkReap(reaper.WithExecArgs(args...))
|
|
if err != nil {
|
|
logger.Error(ctx, "failed to reap", slog.Error(err))
|
|
return xerrors.Errorf("fork reap: %w", err)
|
|
}
|
|
|
|
logger.Info(ctx, "reaper process exiting")
|
|
return nil
|
|
}
|
|
|
|
version := buildinfo.Version()
|
|
logger.Info(ctx, "starting agent",
|
|
slog.F("url", coderURL),
|
|
slog.F("auth", auth),
|
|
slog.F("version", version),
|
|
)
|
|
client := codersdk.New(coderURL)
|
|
client.Logger = logger
|
|
// Set a reasonable timeout so requests can't hang forever!
|
|
client.HTTPClient.Timeout = 10 * time.Second
|
|
|
|
// Enable pprof handler
|
|
// This prevents the pprof import from being accidentally deleted.
|
|
_ = pprof.Handler
|
|
pprofSrvClose := serveHandler(ctx, logger, nil, pprofAddress, "pprof")
|
|
defer pprofSrvClose()
|
|
|
|
// exchangeToken returns a session token.
|
|
// This is abstracted to allow for the same looping condition
|
|
// regardless of instance identity auth type.
|
|
var exchangeToken func(context.Context) (codersdk.WorkspaceAgentAuthenticateResponse, error)
|
|
switch auth {
|
|
case "token":
|
|
token, err := cmd.Flags().GetString(varAgentToken)
|
|
if err != nil {
|
|
return xerrors.Errorf("CODER_AGENT_TOKEN must be set for token auth: %w", err)
|
|
}
|
|
client.SetSessionToken(token)
|
|
case "google-instance-identity":
|
|
// This is *only* done for testing to mock client authentication.
|
|
// This will never be set in a production scenario.
|
|
var gcpClient *metadata.Client
|
|
gcpClientRaw := ctx.Value("gcp-client")
|
|
if gcpClientRaw != nil {
|
|
gcpClient, _ = gcpClientRaw.(*metadata.Client)
|
|
}
|
|
exchangeToken = func(ctx context.Context) (codersdk.WorkspaceAgentAuthenticateResponse, error) {
|
|
return client.AuthWorkspaceGoogleInstanceIdentity(ctx, "", gcpClient)
|
|
}
|
|
case "aws-instance-identity":
|
|
// This is *only* done for testing to mock client authentication.
|
|
// This will never be set in a production scenario.
|
|
var awsClient *http.Client
|
|
awsClientRaw := ctx.Value("aws-client")
|
|
if awsClientRaw != nil {
|
|
awsClient, _ = awsClientRaw.(*http.Client)
|
|
if awsClient != nil {
|
|
client.HTTPClient = awsClient
|
|
}
|
|
}
|
|
exchangeToken = func(ctx context.Context) (codersdk.WorkspaceAgentAuthenticateResponse, error) {
|
|
return client.AuthWorkspaceAWSInstanceIdentity(ctx)
|
|
}
|
|
case "azure-instance-identity":
|
|
// This is *only* done for testing to mock client authentication.
|
|
// This will never be set in a production scenario.
|
|
var azureClient *http.Client
|
|
azureClientRaw := ctx.Value("azure-client")
|
|
if azureClientRaw != nil {
|
|
azureClient, _ = azureClientRaw.(*http.Client)
|
|
if azureClient != nil {
|
|
client.HTTPClient = azureClient
|
|
}
|
|
}
|
|
exchangeToken = func(ctx context.Context) (codersdk.WorkspaceAgentAuthenticateResponse, error) {
|
|
return client.AuthWorkspaceAzureInstanceIdentity(ctx)
|
|
}
|
|
}
|
|
|
|
executablePath, err := os.Executable()
|
|
if err != nil {
|
|
return xerrors.Errorf("getting os executable: %w", err)
|
|
}
|
|
err = os.Setenv("PATH", fmt.Sprintf("%s%c%s", os.Getenv("PATH"), filepath.ListSeparator, filepath.Dir(executablePath)))
|
|
if err != nil {
|
|
return xerrors.Errorf("add executable to $PATH: %w", err)
|
|
}
|
|
|
|
closer := agent.New(agent.Options{
|
|
Client: client,
|
|
Logger: logger,
|
|
ExchangeToken: func(ctx context.Context) (string, error) {
|
|
if exchangeToken == nil {
|
|
return client.SessionToken(), nil
|
|
}
|
|
resp, err := exchangeToken(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
client.SetSessionToken(resp.SessionToken)
|
|
return resp.SessionToken, nil
|
|
},
|
|
EnvironmentVariables: map[string]string{
|
|
"GIT_ASKPASS": executablePath,
|
|
},
|
|
})
|
|
<-ctx.Done()
|
|
return closer.Close()
|
|
},
|
|
}
|
|
|
|
cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AGENT_AUTH", "token", "Specify the authentication type to use for the agent")
|
|
cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.")
|
|
cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.")
|
|
return cmd
|
|
}
|
|
|
|
func serveHandler(ctx context.Context, logger slog.Logger, handler http.Handler, addr, name string) (closeFunc func()) {
|
|
logger.Debug(ctx, "http server listening", slog.F("addr", addr), slog.F("name", name))
|
|
|
|
// ReadHeaderTimeout is purposefully not enabled. It caused some issues with
|
|
// websockets over the dev tunnel.
|
|
// See: https://github.com/coder/coder/pull/3730
|
|
//nolint:gosec
|
|
srv := &http.Server{
|
|
Addr: addr,
|
|
Handler: handler,
|
|
}
|
|
go func() {
|
|
err := srv.ListenAndServe()
|
|
if err != nil && !xerrors.Is(err, http.ErrServerClosed) {
|
|
logger.Error(ctx, "http server listen", slog.F("name", name), slog.Error(err))
|
|
}
|
|
}()
|
|
|
|
return func() {
|
|
_ = srv.Close()
|
|
}
|
|
}
|