From c63f569174de311c49a21ea0c2d351172ceb9f07 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 19 Feb 2024 16:30:00 +0200 Subject: [PATCH] refactor(agent/agentssh): move envs to agent and add agentssh config struct (#12204) This commit refactors where custom environment variables are set in the workspace and decouples agent specific configs from the `agentssh.Server`. To reproduce all functionality, `agentssh.Config` is introduced. The custom environment variables are now configured in `agent/agent.go` and the agent retains control of the final state. This will allow for easier extension in the future and keep other modules decoupled. --- agent/agent.go | 97 +++++++++++++-- agent/agent_test.go | 85 ++++++++++++- agent/agentscripts/agentscripts_test.go | 5 +- agent/agentssh/agentssh.go | 150 +++++++++++------------ agent/agentssh/agentssh_internal_test.go | 2 +- agent/agentssh/agentssh_test.go | 30 +---- agent/agentssh/x11.go | 6 +- agent/agentssh/x11_test.go | 10 +- cli/agent.go | 20 +-- 9 files changed, 268 insertions(+), 137 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 48c8f66694..a369432c03 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -146,7 +146,7 @@ func New(options Options) Agent { logger: options.Logger, closeCancel: cancelFunc, closed: make(chan struct{}), - envVars: options.EnvironmentVariables, + environmentVariables: options.EnvironmentVariables, client: options.Client, exchangeToken: options.ExchangeToken, filesystem: options.Filesystem, @@ -169,6 +169,8 @@ func New(options Options) Agent { prometheusRegistry: prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), } + a.serviceBanner.Store(new(codersdk.ServiceBannerConfig)) + a.sessionToken.Store(new(string)) a.init(ctx) return a } @@ -196,7 +198,7 @@ type agent struct { closeMutex sync.Mutex closed chan struct{} - envVars map[string]string + environmentVariables map[string]string manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection. reportMetadataInterval time.Duration @@ -235,14 +237,16 @@ func (a *agent) TailnetConn() *tailnet.Conn { } func (a *agent) init(ctx context.Context) { - sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.sshMaxTimeout, "") + sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{ + MaxTimeout: a.sshMaxTimeout, + MOTDFile: func() string { return a.manifest.Load().MOTDFile }, + ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() }, + UpdateEnv: a.updateCommandEnv, + WorkingDirectory: func() string { return a.manifest.Load().Directory }, + }) if err != nil { panic(err) } - sshSrv.Env = a.envVars - sshSrv.AgentToken = func() string { return *a.sessionToken.Load() } - sshSrv.Manifest = &a.manifest - sshSrv.ServiceBanner = &a.serviceBanner a.sshServer = sshSrv a.scriptRunner = agentscripts.New(agentscripts.Options{ LogDir: a.logDir, @@ -879,6 +883,83 @@ func (a *agent) run(ctx context.Context) error { return eg.Wait() } +// updateCommandEnv updates the provided command environment with the +// following set of environment variables: +// - Predefined workspace environment variables +// - Environment variables currently set (overriding predefined) +// - Environment variables passed via the agent manifest (overriding predefined and current) +// - Agent-level environment variables (overriding all) +func (a *agent) updateCommandEnv(current []string) (updated []string, err error) { + manifest := a.manifest.Load() + if manifest == nil { + return nil, xerrors.Errorf("no manifest") + } + + executablePath, err := os.Executable() + if err != nil { + return nil, xerrors.Errorf("getting os executable: %w", err) + } + unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") + + // Define environment variables that should be set for all commands, + // and then merge them with the current environment. + envs := map[string]string{ + // Set env vars indicating we're inside a Coder workspace. + "CODER": "true", + "CODER_WORKSPACE_NAME": manifest.WorkspaceName, + "CODER_WORKSPACE_AGENT_NAME": manifest.AgentName, + + // Specific Coder subcommands require the agent token exposed! + "CODER_AGENT_TOKEN": *a.sessionToken.Load(), + + // Git on Windows resolves with UNIX-style paths. + // If using backslashes, it's unable to find the executable. + "GIT_SSH_COMMAND": fmt.Sprintf("%s gitssh --", unixExecutablePath), + // Hide Coder message on code-server's "Getting Started" page + "CS_DISABLE_GETTING_STARTED_OVERRIDE": "true", + } + + // This adds the ports dialog to code-server that enables + // proxying a port dynamically. + // If this is empty string, do not set anything. Code-server auto defaults + // using its basepath to construct a path based port proxy. + if manifest.VSCodePortProxyURI != "" { + envs["VSCODE_PROXY_URI"] = manifest.VSCodePortProxyURI + } + + // Allow any of the current env to override what we defined above. + for _, env := range current { + parts := strings.SplitN(env, "=", 2) + if len(parts) != 2 { + continue + } + if _, ok := envs[parts[0]]; !ok { + envs[parts[0]] = parts[1] + } + } + + // Load environment variables passed via the agent manifest. + // These override all variables we manually specify. + for k, v := range manifest.EnvironmentVariables { + // Expanding environment variables allows for customization + // of the $PATH, among other variables. Customers can prepend + // or append to the $PATH, so allowing expand is required! + envs[k] = os.ExpandEnv(v) + } + + // Agent-level environment variables should take over all. This is + // used for setting agent-specific variables like CODER_AGENT_TOKEN + // and GIT_ASKPASS. + for k, v := range a.environmentVariables { + envs[k] = v + } + + for k, v := range envs { + updated = append(updated, fmt.Sprintf("%s=%s", k, v)) + } + return updated, nil +} + func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix { if len(a.addresses) == 0 { return []netip.Prefix{ @@ -1314,7 +1395,7 @@ func (a *agent) manageProcessPriorityLoop(ctx context.Context) { } }() - if val := a.envVars[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" { + if val := a.environmentVariables[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" { a.logger.Debug(ctx, "process priority not enabled, agent will not manage process niceness/oom_score_adj ", slog.F("env_var", EnvProcPrioMgmt), slog.F("value", val), diff --git a/agent/agent_test.go b/agent/agent_test.go index f30dc430ad..b894beeca9 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "math/rand" @@ -281,6 +282,83 @@ func TestAgent_SessionExec(t *testing.T) { require.Equal(t, "test", strings.TrimSpace(string(output))) } +//nolint:tparallel // Sub tests need to run sequentially. +func TestAgent_Session_EnvironmentVariables(t *testing.T) { + t.Parallel() + + manifest := agentsdk.Manifest{ + EnvironmentVariables: map[string]string{ + "MY_MANIFEST": "true", + "MY_OVERRIDE": "false", + "MY_SESSION_MANIFEST": "false", + }, + } + banner := codersdk.ServiceBannerConfig{} + session := setupSSHSession(t, manifest, banner, nil, func(_ *agenttest.Client, opts *agent.Options) { + opts.EnvironmentVariables["MY_OVERRIDE"] = "true" + }) + + err := session.Setenv("MY_SESSION_MANIFEST", "true") + require.NoError(t, err) + err = session.Setenv("MY_SESSION", "true") + require.NoError(t, err) + + command := "sh" + echoEnv := func(t *testing.T, w io.Writer, env string) { + if runtime.GOOS == "windows" { + _, err := fmt.Fprintf(w, "echo %%%s%%\r\n", env) + require.NoError(t, err) + } else { + _, err := fmt.Fprintf(w, "echo $%s\n", env) + require.NoError(t, err) + } + } + if runtime.GOOS == "windows" { + command = "cmd.exe" + } + stdin, err := session.StdinPipe() + require.NoError(t, err) + defer stdin.Close() + stdout, err := session.StdoutPipe() + require.NoError(t, err) + + err = session.Start(command) + require.NoError(t, err) + + // Context is fine here since we're not doing a parallel subtest. + ctx := testutil.Context(t, testutil.WaitLong) + go func() { + <-ctx.Done() + _ = session.Close() + }() + + s := bufio.NewScanner(stdout) + + //nolint:paralleltest // These tests need to run sequentially. + for k, partialV := range map[string]string{ + "CODER": "true", // From the agent. + "MY_MANIFEST": "true", // From the manifest. + "MY_OVERRIDE": "true", // From the agent environment variables option, overrides manifest. + "MY_SESSION_MANIFEST": "false", // From the manifest, overrides session env. + "MY_SESSION": "true", // From the session. + } { + t.Run(k, func(t *testing.T) { + echoEnv(t, stdin, k) + // Windows is unreliable, so keep scanning until we find a match. + for s.Scan() { + got := strings.TrimSpace(s.Text()) + t.Logf("%s=%s", k, got) + if strings.Contains(got, partialV) { + break + } + } + if err := s.Err(); !errors.Is(err, io.EOF) { + require.NoError(t, err) + } + }) + } +} + func TestAgent_GitSSH(t *testing.T) { t.Parallel() session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil) @@ -1991,15 +2069,17 @@ func setupSSHSession( manifest agentsdk.Manifest, serviceBanner codersdk.ServiceBannerConfig, prepareFS func(fs afero.Fs), + opts ...func(*agenttest.Client, *agent.Options), ) *ssh.Session { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - //nolint:dogsled - conn, _, _, fs, _ := setupAgent(t, manifest, 0, func(c *agenttest.Client, _ *agent.Options) { + opts = append(opts, func(c *agenttest.Client, o *agent.Options) { c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) { return serviceBanner, nil }) }) + //nolint:dogsled + conn, _, _, fs, _ := setupAgent(t, manifest, 0, opts...) if prepareFS != nil { prepareFS(fs) } @@ -2057,6 +2137,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati Filesystem: fs, Logger: logger.Named("agent"), ReconnectingPTYTimeout: ptyTimeout, + EnvironmentVariables: map[string]string{}, } for _, opt := range opts { diff --git a/agent/agentscripts/agentscripts_test.go b/agent/agentscripts/agentscripts_test.go index 9957b8833b..bb3f842a45 100644 --- a/agent/agentscripts/agentscripts_test.go +++ b/agent/agentscripts/agentscripts_test.go @@ -8,7 +8,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/spf13/afero" "github.com/stretchr/testify/require" - "go.uber.org/atomic" "go.uber.org/goleak" "cdr.dev/slog/sloggers/slogtest" @@ -72,10 +71,8 @@ func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchL } fs := afero.NewMemMapFs() logger := slogtest.Make(t, nil) - s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, 0, "") + s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil) require.NoError(t, err) - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) t.Cleanup(func() { _ = s.Close() }) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index f05bbaf7c8..48da6aa029 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -32,7 +32,6 @@ import ( "github.com/coder/coder/v2/agent/usershell" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/pty" ) @@ -55,6 +54,28 @@ const ( MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains" ) +// Config sets configuration parameters for the agent SSH server. +type Config struct { + // MaxTimeout sets the absolute connection timeout, none if empty. If set to + // 3 seconds or more, keep alive will be used instead. + MaxTimeout time.Duration + // MOTDFile returns the path to the message of the day file. If set, the + // file will be displayed to the user upon login. + MOTDFile func() string + // ServiceBanner returns the configuration for the Coder service banner. + ServiceBanner func() *codersdk.ServiceBannerConfig + // UpdateEnv updates the environment variables for the command to be + // executed. It can be used to add, modify or replace environment variables. + UpdateEnv func(current []string) (updated []string, err error) + // WorkingDirectory sets the working directory for commands and defines + // where users will land when they connect via SSH. Default is the home + // directory of the user. + WorkingDirectory func() string + // X11SocketDir is the directory where X11 sockets are created. Default is + // /tmp/.X11-unix. + X11SocketDir string +} + type Server struct { mu sync.RWMutex // Protects following. fs afero.Fs @@ -66,14 +87,10 @@ type Server struct { // a lock on mu but protected by closing. wg sync.WaitGroup - logger slog.Logger - srv *ssh.Server - x11SocketDir string + logger slog.Logger + srv *ssh.Server - Env map[string]string - AgentToken func() string - Manifest *atomic.Pointer[agentsdk.Manifest] - ServiceBanner *atomic.Pointer[codersdk.ServiceBannerConfig] + config *Config connCountVSCode atomic.Int64 connCountJetBrains atomic.Int64 @@ -82,7 +99,7 @@ type Server struct { metrics *sshServerMetrics } -func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, maxTimeout time.Duration, x11SocketDir string) (*Server, error) { +func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, config *Config) (*Server, error) { // Clients' should ignore the host key when connecting. // The agent needs to authenticate with coderd to SSH, // so SSH authentication doesn't improve security. @@ -94,8 +111,29 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom if err != nil { return nil, err } - if x11SocketDir == "" { - x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix") + if config == nil { + config = &Config{} + } + if config.X11SocketDir == "" { + config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix") + } + if config.UpdateEnv == nil { + config.UpdateEnv = func(current []string) ([]string, error) { return current, nil } + } + if config.MOTDFile == nil { + config.MOTDFile = func() string { return "" } + } + if config.ServiceBanner == nil { + config.ServiceBanner = func() *codersdk.ServiceBannerConfig { return &codersdk.ServiceBannerConfig{} } + } + if config.WorkingDirectory == nil { + config.WorkingDirectory = func() string { + home, err := userHomeDir() + if err != nil { + return "" + } + return home + } } forwardHandler := &ssh.ForwardedTCPHandler{} @@ -103,12 +141,13 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom metrics := newSSHServerMetrics(prometheusRegistry) s := &Server{ - listeners: make(map[net.Listener]struct{}), - fs: fs, - conns: make(map[net.Conn]struct{}), - sessions: make(map[ssh.Session]struct{}), - logger: logger, - x11SocketDir: x11SocketDir, + listeners: make(map[net.Listener]struct{}), + fs: fs, + conns: make(map[net.Conn]struct{}), + sessions: make(map[ssh.Session]struct{}), + logger: logger, + + config: config, metrics: metrics, } @@ -172,14 +211,16 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom }, } - // The MaxTimeout functionality has been substituted with the introduction of the KeepAlive feature. - // In cases where very short timeouts are set, the SSH server will automatically switch to the connection timeout for both read and write operations. - if maxTimeout >= 3*time.Second { + // The MaxTimeout functionality has been substituted with the introduction + // of the KeepAlive feature. In cases where very short timeouts are set, the + // SSH server will automatically switch to the connection timeout for both + // read and write operations. + if config.MaxTimeout >= 3*time.Second { srv.ClientAliveCountMax = 3 - srv.ClientAliveInterval = maxTimeout / time.Duration(srv.ClientAliveCountMax) + srv.ClientAliveInterval = config.MaxTimeout / time.Duration(srv.ClientAliveCountMax) srv.MaxTimeout = 0 } else { - srv.MaxTimeout = maxTimeout + srv.MaxTimeout = config.MaxTimeout } s.srv = srv @@ -400,7 +441,7 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy session.DisablePTYEmulation() if isLoginShell(session.RawCommand()) { - serviceBanner := s.ServiceBanner.Load() + serviceBanner := s.config.ServiceBanner() if serviceBanner != nil { err := showServiceBanner(session, serviceBanner) if err != nil { @@ -411,15 +452,10 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy } if !isQuietLogin(s.fs, session.RawCommand()) { - manifest := s.Manifest.Load() - if manifest != nil { - err := showMOTD(s.fs, session, manifest.MOTDFile) - if err != nil { - logger.Error(ctx, "agent failed to show MOTD", slog.Error(err)) - s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "motd").Add(1) - } - } else { - logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") + err := showMOTD(s.fs, session, s.config.MOTDFile()) + if err != nil { + logger.Error(ctx, "agent failed to show MOTD", slog.Error(err)) + s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "motd").Add(1) } } @@ -589,11 +625,6 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) return nil, xerrors.Errorf("get user shell: %w", err) } - manifest := s.Manifest.Load() - if manifest == nil { - return nil, xerrors.Errorf("no metadata was provided") - } - // OpenSSH executes all commands with the users current shell. // We replicate that behavior for IDE support. caller := "-c" @@ -638,7 +669,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) } cmd := pty.CommandContext(ctx, name, args...) - cmd.Dir = manifest.Directory + cmd.Dir = s.config.WorkingDirectory() // If the metadata directory doesn't exist, we run the command // in the users home directory. @@ -652,23 +683,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) cmd.Dir = homedir } cmd.Env = append(os.Environ(), env...) - executablePath, err := os.Executable() - if err != nil { - return nil, xerrors.Errorf("getting os executable: %w", err) - } - // Set environment variables reliable detection of being inside a - // Coder workspace. - cmd.Env = append(cmd.Env, "CODER=true") - cmd.Env = append(cmd.Env, "CODER_WORKSPACE_NAME="+manifest.WorkspaceName) - cmd.Env = append(cmd.Env, "CODER_WORKSPACE_AGENT_NAME="+manifest.AgentName) cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) - // Git on Windows resolves with UNIX-style paths. - // If using backslashes, it's unable to find the executable. - unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") - cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) - - // Specific Coder subcommands require the agent token exposed! - cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", s.AgentToken())) // Set SSH connection environment variables (these are also set by OpenSSH // and thus expected to be present by SSH clients). Since the agent does @@ -679,30 +694,9 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort)) cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort)) - // This adds the ports dialog to code-server that enables - // proxying a port dynamically. - // If this is empty string, do not set anything. Code-server auto defaults - // using its basepath to construct a path based port proxy. - if manifest.VSCodePortProxyURI != "" { - cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI)) - } - - // Hide Coder message on code-server's "Getting Started" page - cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true") - - // Load environment variables passed via the agent. - // These should override all variables we manually specify. - for envKey, value := range manifest.EnvironmentVariables { - // Expanding environment variables allows for customization - // of the $PATH, among other variables. Customers can prepend - // or append to the $PATH, so allowing expand is required! - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value))) - } - - // Agent-level environment variables should take over all! - // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". - for envKey, value := range s.Env { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) + cmd.Env, err = s.config.UpdateEnv(cmd.Env) + if err != nil { + return nil, xerrors.Errorf("apply env: %w", err) } return cmd, nil diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go index 1bdc3541a7..703b228c58 100644 --- a/agent/agentssh/agentssh_internal_test.go +++ b/agent/agentssh/agentssh_internal_test.go @@ -37,7 +37,7 @@ func Test_sessionStart_orphan(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() logger := slogtest.Make(t, nil) - s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "") + s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) require.NoError(t, err) defer s.Close() diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index 49d07a11bd..4404d21b5d 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -17,14 +17,12 @@ import ( "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" "go.uber.org/goleak" "golang.org/x/crypto/ssh" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/agent/agentssh" - "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/testutil" ) @@ -38,14 +36,10 @@ func TestNewServer_ServeClient(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, nil) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "") + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) require.NoError(t, err) defer s.Close() - // The assumption is that these are set before serving SSH connections. - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) - ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -83,13 +77,11 @@ func TestNewServer_ExecuteShebang(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, nil) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "") + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) require.NoError(t, err) t.Cleanup(func() { _ = s.Close() }) - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) t.Run("Basic", func(t *testing.T) { t.Parallel() @@ -116,14 +108,10 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "") + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) require.NoError(t, err) defer s.Close() - // The assumption is that these are set before serving SSH connections. - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) - ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -171,14 +159,10 @@ func TestNewServer_Signal(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, nil) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "") + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) require.NoError(t, err) defer s.Close() - // The assumption is that these are set before serving SSH connections. - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) - ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -240,14 +224,10 @@ func TestNewServer_Signal(t *testing.T) { ctx := context.Background() logger := slogtest.Make(t, nil) - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "") + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) require.NoError(t, err) defer s.Close() - // The assumption is that these are set before serving SSH connections. - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) - ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index 462bc1042b..2b083fbf04 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -32,9 +32,9 @@ func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool { return false } - err = s.fs.MkdirAll(s.x11SocketDir, 0o700) + err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700) if err != nil { - s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.x11SocketDir), slog.Error(err)) + s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err)) s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1) return false } @@ -57,7 +57,7 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool { return false } // We want to overwrite the socket so that subsequent connections will succeed. - socketPath := filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)) + socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)) err := os.Remove(socketPath) if err != nil && !errors.Is(err, os.ErrNotExist) { s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err)) diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go index e5f3f62ddc..da3c68c3e5 100644 --- a/agent/agentssh/x11_test.go +++ b/agent/agentssh/x11_test.go @@ -14,13 +14,11 @@ import ( "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/agent/agentssh" - "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/testutil" ) @@ -34,14 +32,12 @@ func TestServer_X11(t *testing.T) { logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) fs := afero.NewOsFs() dir := t.TempDir() - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, 0, dir) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{ + X11SocketDir: dir, + }) require.NoError(t, err) defer s.Close() - // The assumption is that these are set before serving SSH connections. - s.AgentToken = func() string { return "" } - s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) - ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) diff --git a/cli/agent.go b/cli/agent.go index 533065ff62..c951ec7509 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -278,8 +278,13 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { subsystems = append(subsystems, subsystem) } - procTicker := time.NewTicker(time.Second) - defer procTicker.Stop() + environmentVariables := map[string]string{ + "GIT_ASKPASS": executablePath, + } + if v, ok := os.LookupEnv(agent.EnvProcPrioMgmt); ok { + environmentVariables[agent.EnvProcPrioMgmt] = v + } + agnt := agent.New(agent.Options{ Client: client, Logger: logger, @@ -296,13 +301,10 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { client.SetSessionToken(resp.SessionToken) return resp.SessionToken, nil }, - EnvironmentVariables: map[string]string{ - "GIT_ASKPASS": executablePath, - agent.EnvProcPrioMgmt: os.Getenv(agent.EnvProcPrioMgmt), - }, - IgnorePorts: ignorePorts, - SSHMaxTimeout: sshMaxTimeout, - Subsystems: subsystems, + EnvironmentVariables: environmentVariables, + IgnorePorts: ignorePorts, + SSHMaxTimeout: sshMaxTimeout, + Subsystems: subsystems, PrometheusRegistry: prometheusRegistry, Syscaller: agentproc.NewSyscaller(),