refactor(agent/agentssh): move envs to agent and add agentssh config struct (#12204)

This commit refactors where custom environment variables are set in the
workspace and decouples agent specific configs from the `agentssh.Server`.
To reproduce all functionality, `agentssh.Config` is introduced.

The custom environment variables are now configured in `agent/agent.go`
and the agent retains control of the final state. This will allow for
easier extension in the future and keep other modules decoupled.
This commit is contained in:
Mathias Fredriksson 2024-02-19 16:30:00 +02:00 committed by GitHub
parent 817cc78b94
commit c63f569174
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 268 additions and 137 deletions

View File

@ -146,7 +146,7 @@ func New(options Options) Agent {
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
envVars: options.EnvironmentVariables,
environmentVariables: options.EnvironmentVariables,
client: options.Client,
exchangeToken: options.ExchangeToken,
filesystem: options.Filesystem,
@ -169,6 +169,8 @@ func New(options Options) Agent {
prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
}
a.serviceBanner.Store(new(codersdk.ServiceBannerConfig))
a.sessionToken.Store(new(string))
a.init(ctx)
return a
}
@ -196,7 +198,7 @@ type agent struct {
closeMutex sync.Mutex
closed chan struct{}
envVars map[string]string
environmentVariables map[string]string
manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
reportMetadataInterval time.Duration
@ -235,14 +237,16 @@ func (a *agent) TailnetConn() *tailnet.Conn {
}
func (a *agent) init(ctx context.Context) {
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.sshMaxTimeout, "")
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() },
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
})
if err != nil {
panic(err)
}
sshSrv.Env = a.envVars
sshSrv.AgentToken = func() string { return *a.sessionToken.Load() }
sshSrv.Manifest = &a.manifest
sshSrv.ServiceBanner = &a.serviceBanner
a.sshServer = sshSrv
a.scriptRunner = agentscripts.New(agentscripts.Options{
LogDir: a.logDir,
@ -879,6 +883,83 @@ func (a *agent) run(ctx context.Context) error {
return eg.Wait()
}
// updateCommandEnv updates the provided command environment with the
// following set of environment variables:
// - Predefined workspace environment variables
// - Environment variables currently set (overriding predefined)
// - Environment variables passed via the agent manifest (overriding predefined and current)
// - Agent-level environment variables (overriding all)
func (a *agent) updateCommandEnv(current []string) (updated []string, err error) {
manifest := a.manifest.Load()
if manifest == nil {
return nil, xerrors.Errorf("no manifest")
}
executablePath, err := os.Executable()
if err != nil {
return nil, xerrors.Errorf("getting os executable: %w", err)
}
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
// Define environment variables that should be set for all commands,
// and then merge them with the current environment.
envs := map[string]string{
// Set env vars indicating we're inside a Coder workspace.
"CODER": "true",
"CODER_WORKSPACE_NAME": manifest.WorkspaceName,
"CODER_WORKSPACE_AGENT_NAME": manifest.AgentName,
// Specific Coder subcommands require the agent token exposed!
"CODER_AGENT_TOKEN": *a.sessionToken.Load(),
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
"GIT_SSH_COMMAND": fmt.Sprintf("%s gitssh --", unixExecutablePath),
// Hide Coder message on code-server's "Getting Started" page
"CS_DISABLE_GETTING_STARTED_OVERRIDE": "true",
}
// This adds the ports dialog to code-server that enables
// proxying a port dynamically.
// If this is empty string, do not set anything. Code-server auto defaults
// using its basepath to construct a path based port proxy.
if manifest.VSCodePortProxyURI != "" {
envs["VSCODE_PROXY_URI"] = manifest.VSCodePortProxyURI
}
// Allow any of the current env to override what we defined above.
for _, env := range current {
parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 {
continue
}
if _, ok := envs[parts[0]]; !ok {
envs[parts[0]] = parts[1]
}
}
// Load environment variables passed via the agent manifest.
// These override all variables we manually specify.
for k, v := range manifest.EnvironmentVariables {
// Expanding environment variables allows for customization
// of the $PATH, among other variables. Customers can prepend
// or append to the $PATH, so allowing expand is required!
envs[k] = os.ExpandEnv(v)
}
// Agent-level environment variables should take over all. This is
// used for setting agent-specific variables like CODER_AGENT_TOKEN
// and GIT_ASKPASS.
for k, v := range a.environmentVariables {
envs[k] = v
}
for k, v := range envs {
updated = append(updated, fmt.Sprintf("%s=%s", k, v))
}
return updated, nil
}
func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
if len(a.addresses) == 0 {
return []netip.Prefix{
@ -1314,7 +1395,7 @@ func (a *agent) manageProcessPriorityLoop(ctx context.Context) {
}
}()
if val := a.envVars[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" {
if val := a.environmentVariables[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" {
a.logger.Debug(ctx, "process priority not enabled, agent will not manage process niceness/oom_score_adj ",
slog.F("env_var", EnvProcPrioMgmt),
slog.F("value", val),

View File

@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
@ -281,6 +282,83 @@ func TestAgent_SessionExec(t *testing.T) {
require.Equal(t, "test", strings.TrimSpace(string(output)))
}
//nolint:tparallel // Sub tests need to run sequentially.
func TestAgent_Session_EnvironmentVariables(t *testing.T) {
t.Parallel()
manifest := agentsdk.Manifest{
EnvironmentVariables: map[string]string{
"MY_MANIFEST": "true",
"MY_OVERRIDE": "false",
"MY_SESSION_MANIFEST": "false",
},
}
banner := codersdk.ServiceBannerConfig{}
session := setupSSHSession(t, manifest, banner, nil, func(_ *agenttest.Client, opts *agent.Options) {
opts.EnvironmentVariables["MY_OVERRIDE"] = "true"
})
err := session.Setenv("MY_SESSION_MANIFEST", "true")
require.NoError(t, err)
err = session.Setenv("MY_SESSION", "true")
require.NoError(t, err)
command := "sh"
echoEnv := func(t *testing.T, w io.Writer, env string) {
if runtime.GOOS == "windows" {
_, err := fmt.Fprintf(w, "echo %%%s%%\r\n", env)
require.NoError(t, err)
} else {
_, err := fmt.Fprintf(w, "echo $%s\n", env)
require.NoError(t, err)
}
}
if runtime.GOOS == "windows" {
command = "cmd.exe"
}
stdin, err := session.StdinPipe()
require.NoError(t, err)
defer stdin.Close()
stdout, err := session.StdoutPipe()
require.NoError(t, err)
err = session.Start(command)
require.NoError(t, err)
// Context is fine here since we're not doing a parallel subtest.
ctx := testutil.Context(t, testutil.WaitLong)
go func() {
<-ctx.Done()
_ = session.Close()
}()
s := bufio.NewScanner(stdout)
//nolint:paralleltest // These tests need to run sequentially.
for k, partialV := range map[string]string{
"CODER": "true", // From the agent.
"MY_MANIFEST": "true", // From the manifest.
"MY_OVERRIDE": "true", // From the agent environment variables option, overrides manifest.
"MY_SESSION_MANIFEST": "false", // From the manifest, overrides session env.
"MY_SESSION": "true", // From the session.
} {
t.Run(k, func(t *testing.T) {
echoEnv(t, stdin, k)
// Windows is unreliable, so keep scanning until we find a match.
for s.Scan() {
got := strings.TrimSpace(s.Text())
t.Logf("%s=%s", k, got)
if strings.Contains(got, partialV) {
break
}
}
if err := s.Err(); !errors.Is(err, io.EOF) {
require.NoError(t, err)
}
})
}
}
func TestAgent_GitSSH(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
@ -1991,15 +2069,17 @@ func setupSSHSession(
manifest agentsdk.Manifest,
serviceBanner codersdk.ServiceBannerConfig,
prepareFS func(fs afero.Fs),
opts ...func(*agenttest.Client, *agent.Options),
) *ssh.Session {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, manifest, 0, func(c *agenttest.Client, _ *agent.Options) {
opts = append(opts, func(c *agenttest.Client, o *agent.Options) {
c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
return serviceBanner, nil
})
})
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, manifest, 0, opts...)
if prepareFS != nil {
prepareFS(fs)
}
@ -2057,6 +2137,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
Filesystem: fs,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: ptyTimeout,
EnvironmentVariables: map[string]string{},
}
for _, opt := range opts {

View File

@ -8,7 +8,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
@ -72,10 +71,8 @@ func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchL
}
fs := afero.NewMemMapFs()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, 0, "")
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil)
require.NoError(t, err)
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
t.Cleanup(func() {
_ = s.Close()
})

View File

@ -32,7 +32,6 @@ import (
"github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty"
)
@ -55,6 +54,28 @@ const (
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
)
// Config sets configuration parameters for the agent SSH server.
type Config struct {
// MaxTimeout sets the absolute connection timeout, none if empty. If set to
// 3 seconds or more, keep alive will be used instead.
MaxTimeout time.Duration
// MOTDFile returns the path to the message of the day file. If set, the
// file will be displayed to the user upon login.
MOTDFile func() string
// ServiceBanner returns the configuration for the Coder service banner.
ServiceBanner func() *codersdk.ServiceBannerConfig
// UpdateEnv updates the environment variables for the command to be
// executed. It can be used to add, modify or replace environment variables.
UpdateEnv func(current []string) (updated []string, err error)
// WorkingDirectory sets the working directory for commands and defines
// where users will land when they connect via SSH. Default is the home
// directory of the user.
WorkingDirectory func() string
// X11SocketDir is the directory where X11 sockets are created. Default is
// /tmp/.X11-unix.
X11SocketDir string
}
type Server struct {
mu sync.RWMutex // Protects following.
fs afero.Fs
@ -66,14 +87,10 @@ type Server struct {
// a lock on mu but protected by closing.
wg sync.WaitGroup
logger slog.Logger
srv *ssh.Server
x11SocketDir string
logger slog.Logger
srv *ssh.Server
Env map[string]string
AgentToken func() string
Manifest *atomic.Pointer[agentsdk.Manifest]
ServiceBanner *atomic.Pointer[codersdk.ServiceBannerConfig]
config *Config
connCountVSCode atomic.Int64
connCountJetBrains atomic.Int64
@ -82,7 +99,7 @@ type Server struct {
metrics *sshServerMetrics
}
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, maxTimeout time.Duration, x11SocketDir string) (*Server, error) {
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, config *Config) (*Server, error) {
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
@ -94,8 +111,29 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
if err != nil {
return nil, err
}
if x11SocketDir == "" {
x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
if config == nil {
config = &Config{}
}
if config.X11SocketDir == "" {
config.X11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
}
if config.UpdateEnv == nil {
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
}
if config.MOTDFile == nil {
config.MOTDFile = func() string { return "" }
}
if config.ServiceBanner == nil {
config.ServiceBanner = func() *codersdk.ServiceBannerConfig { return &codersdk.ServiceBannerConfig{} }
}
if config.WorkingDirectory == nil {
config.WorkingDirectory = func() string {
home, err := userHomeDir()
if err != nil {
return ""
}
return home
}
}
forwardHandler := &ssh.ForwardedTCPHandler{}
@ -103,12 +141,13 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
metrics := newSSHServerMetrics(prometheusRegistry)
s := &Server{
listeners: make(map[net.Listener]struct{}),
fs: fs,
conns: make(map[net.Conn]struct{}),
sessions: make(map[ssh.Session]struct{}),
logger: logger,
x11SocketDir: x11SocketDir,
listeners: make(map[net.Listener]struct{}),
fs: fs,
conns: make(map[net.Conn]struct{}),
sessions: make(map[ssh.Session]struct{}),
logger: logger,
config: config,
metrics: metrics,
}
@ -172,14 +211,16 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
},
}
// The MaxTimeout functionality has been substituted with the introduction of the KeepAlive feature.
// In cases where very short timeouts are set, the SSH server will automatically switch to the connection timeout for both read and write operations.
if maxTimeout >= 3*time.Second {
// The MaxTimeout functionality has been substituted with the introduction
// of the KeepAlive feature. In cases where very short timeouts are set, the
// SSH server will automatically switch to the connection timeout for both
// read and write operations.
if config.MaxTimeout >= 3*time.Second {
srv.ClientAliveCountMax = 3
srv.ClientAliveInterval = maxTimeout / time.Duration(srv.ClientAliveCountMax)
srv.ClientAliveInterval = config.MaxTimeout / time.Duration(srv.ClientAliveCountMax)
srv.MaxTimeout = 0
} else {
srv.MaxTimeout = maxTimeout
srv.MaxTimeout = config.MaxTimeout
}
s.srv = srv
@ -400,7 +441,7 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
session.DisablePTYEmulation()
if isLoginShell(session.RawCommand()) {
serviceBanner := s.ServiceBanner.Load()
serviceBanner := s.config.ServiceBanner()
if serviceBanner != nil {
err := showServiceBanner(session, serviceBanner)
if err != nil {
@ -411,15 +452,10 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
}
if !isQuietLogin(s.fs, session.RawCommand()) {
manifest := s.Manifest.Load()
if manifest != nil {
err := showMOTD(s.fs, session, manifest.MOTDFile)
if err != nil {
logger.Error(ctx, "agent failed to show MOTD", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "motd").Add(1)
}
} else {
logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
err := showMOTD(s.fs, session, s.config.MOTDFile())
if err != nil {
logger.Error(ctx, "agent failed to show MOTD", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "motd").Add(1)
}
}
@ -589,11 +625,6 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
return nil, xerrors.Errorf("get user shell: %w", err)
}
manifest := s.Manifest.Load()
if manifest == nil {
return nil, xerrors.Errorf("no metadata was provided")
}
// OpenSSH executes all commands with the users current shell.
// We replicate that behavior for IDE support.
caller := "-c"
@ -638,7 +669,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
}
cmd := pty.CommandContext(ctx, name, args...)
cmd.Dir = manifest.Directory
cmd.Dir = s.config.WorkingDirectory()
// If the metadata directory doesn't exist, we run the command
// in the users home directory.
@ -652,23 +683,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
cmd.Dir = homedir
}
cmd.Env = append(os.Environ(), env...)
executablePath, err := os.Executable()
if err != nil {
return nil, xerrors.Errorf("getting os executable: %w", err)
}
// Set environment variables reliable detection of being inside a
// Coder workspace.
cmd.Env = append(cmd.Env, "CODER=true")
cmd.Env = append(cmd.Env, "CODER_WORKSPACE_NAME="+manifest.WorkspaceName)
cmd.Env = append(cmd.Env, "CODER_WORKSPACE_AGENT_NAME="+manifest.AgentName)
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
// Specific Coder subcommands require the agent token exposed!
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", s.AgentToken()))
// Set SSH connection environment variables (these are also set by OpenSSH
// and thus expected to be present by SSH clients). Since the agent does
@ -679,30 +694,9 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort))
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort))
// This adds the ports dialog to code-server that enables
// proxying a port dynamically.
// If this is empty string, do not set anything. Code-server auto defaults
// using its basepath to construct a path based port proxy.
if manifest.VSCodePortProxyURI != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI))
}
// Hide Coder message on code-server's "Getting Started" page
cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true")
// Load environment variables passed via the agent.
// These should override all variables we manually specify.
for envKey, value := range manifest.EnvironmentVariables {
// Expanding environment variables allows for customization
// of the $PATH, among other variables. Customers can prepend
// or append to the $PATH, so allowing expand is required!
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value)))
}
// Agent-level environment variables should take over all!
// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
for envKey, value := range s.Env {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
cmd.Env, err = s.config.UpdateEnv(cmd.Env)
if err != nil {
return nil, xerrors.Errorf("apply env: %w", err)
}
return cmd, nil

View File

@ -37,7 +37,7 @@ func Test_sessionStart_orphan(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil)
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
require.NoError(t, err)
defer s.Close()

View File

@ -17,14 +17,12 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@ -38,14 +36,10 @@ func TestNewServer_ServeClient(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@ -83,13 +77,11 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = s.Close()
})
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
t.Run("Basic", func(t *testing.T) {
t.Parallel()
@ -116,14 +108,10 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@ -171,14 +159,10 @@ func TestNewServer_Signal(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@ -240,14 +224,10 @@ func TestNewServer_Signal(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

View File

@ -32,9 +32,9 @@ func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
return false
}
err = s.fs.MkdirAll(s.x11SocketDir, 0o700)
err = s.fs.MkdirAll(s.config.X11SocketDir, 0o700)
if err != nil {
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.x11SocketDir), slog.Error(err))
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.config.X11SocketDir), slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
return false
}
@ -57,7 +57,7 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
return false
}
// We want to overwrite the socket so that subsequent connections will succeed.
socketPath := filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
socketPath := filepath.Join(s.config.X11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
err := os.Remove(socketPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err))

View File

@ -14,13 +14,11 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
)
@ -34,14 +32,12 @@ func TestServer_X11(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fs := afero.NewOsFs()
dir := t.TempDir()
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, 0, dir)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{
X11SocketDir: dir,
})
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

View File

@ -278,8 +278,13 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
subsystems = append(subsystems, subsystem)
}
procTicker := time.NewTicker(time.Second)
defer procTicker.Stop()
environmentVariables := map[string]string{
"GIT_ASKPASS": executablePath,
}
if v, ok := os.LookupEnv(agent.EnvProcPrioMgmt); ok {
environmentVariables[agent.EnvProcPrioMgmt] = v
}
agnt := agent.New(agent.Options{
Client: client,
Logger: logger,
@ -296,13 +301,10 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
client.SetSessionToken(resp.SessionToken)
return resp.SessionToken, nil
},
EnvironmentVariables: map[string]string{
"GIT_ASKPASS": executablePath,
agent.EnvProcPrioMgmt: os.Getenv(agent.EnvProcPrioMgmt),
},
IgnorePorts: ignorePorts,
SSHMaxTimeout: sshMaxTimeout,
Subsystems: subsystems,
EnvironmentVariables: environmentVariables,
IgnorePorts: ignorePorts,
SSHMaxTimeout: sshMaxTimeout,
Subsystems: subsystems,
PrometheusRegistry: prometheusRegistry,
Syscaller: agentproc.NewSyscaller(),