mirror of https://github.com/coder/coder.git
refactor(agent): Move SSH server into agentssh package (#7004)
Refs: #6177
This commit is contained in:
parent
3ff2ae1b1a
commit
0224426e5b
568
agent/agent.go
568
agent/agent.go
|
@ -4,8 +4,6 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -16,11 +14,9 @@ import (
|
|||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -28,12 +24,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/armon/circbuf"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/spf13/afero"
|
||||
"go.uber.org/atomic"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/net/speedtest"
|
||||
|
@ -41,7 +34,7 @@ import (
|
|||
"tailscale.com/types/netlogtype"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/agent/usershell"
|
||||
"github.com/coder/coder/agent/agentssh"
|
||||
"github.com/coder/coder/buildinfo"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
|
@ -56,19 +49,6 @@ const (
|
|||
ProtocolReconnectingPTY = "reconnecting-pty"
|
||||
ProtocolSSH = "ssh"
|
||||
ProtocolDial = "dial"
|
||||
|
||||
// MagicSessionErrorCode indicates that something went wrong with the session, rather than the
|
||||
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
|
||||
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
|
||||
MagicSessionErrorCode = 229
|
||||
|
||||
// MagicSSHSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
|
||||
// This is stripped from any commands being executed, and is counted towards connection stats.
|
||||
MagicSSHSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
|
||||
// MagicSSHSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
|
||||
MagicSSHSessionTypeVSCode = "vscode"
|
||||
// MagicSSHSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself.
|
||||
MagicSSHSessionTypeJetBrains = "jetbrains"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
|
@ -165,7 +145,7 @@ type agent struct {
|
|||
// manifest is atomic because values can change after reconnection.
|
||||
manifest atomic.Pointer[agentsdk.Manifest]
|
||||
sessionToken atomic.Pointer[string]
|
||||
sshServer *ssh.Server
|
||||
sshServer *agentssh.Server
|
||||
sshMaxTimeout time.Duration
|
||||
|
||||
lifecycleUpdate chan struct{}
|
||||
|
@ -177,10 +157,20 @@ type agent struct {
|
|||
connStatsChan chan *agentsdk.Stats
|
||||
latestStat atomic.Pointer[agentsdk.Stats]
|
||||
|
||||
connCountVSCode atomic.Int64
|
||||
connCountJetBrains atomic.Int64
|
||||
connCountReconnectingPTY atomic.Int64
|
||||
connCountSSHSession atomic.Int64
|
||||
}
|
||||
|
||||
func (a *agent) init(ctx context.Context) {
|
||||
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sshSrv.Env = a.envVars
|
||||
sshSrv.AgentToken = func() string { return *a.sessionToken.Load() }
|
||||
sshSrv.Manifest = &a.manifest
|
||||
a.sshServer = sshSrv
|
||||
|
||||
go a.runLoop(ctx)
|
||||
}
|
||||
|
||||
// runLoop attempts to start the agent in a retry loop.
|
||||
|
@ -223,7 +213,7 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
|
|||
// if it is certain the clocks are in sync.
|
||||
CollectedAt: time.Now(),
|
||||
}
|
||||
cmd, err := a.createCommand(ctx, md.Script, nil)
|
||||
cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
|
||||
if err != nil {
|
||||
result.Error = err.Error()
|
||||
return result
|
||||
|
@ -633,28 +623,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
|
|||
}
|
||||
}()
|
||||
if err = a.trackConnGoroutine(func() {
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := sshListener.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-closed:
|
||||
case <-a.closed:
|
||||
_ = conn.Close()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
defer close(closed)
|
||||
a.sshServer.HandleConn(conn)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
_ = a.sshServer.Serve(sshListener)
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -857,7 +826,7 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
|
|||
}()
|
||||
}
|
||||
|
||||
cmd, err := a.createCommand(ctx, script, nil)
|
||||
cmd, err := a.sshServer.CreateCommand(ctx, script, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create command: %w", err)
|
||||
}
|
||||
|
@ -990,394 +959,6 @@ func (a *agent) trackScriptLogs(ctx context.Context, reader io.Reader) (chan str
|
|||
return logsFinished, nil
|
||||
}
|
||||
|
||||
func (a *agent) init(ctx context.Context) {
|
||||
// Clients' should ignore the host key when connecting.
|
||||
// The agent needs to authenticate with coderd to SSH,
|
||||
// so SSH authentication doesn't improve security.
|
||||
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sshLogger := a.logger.Named("ssh-server")
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := &forwardedUnixHandler{log: a.logger}
|
||||
|
||||
a.sshServer = &ssh.Server{
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"direct-tcpip": ssh.DirectTCPIPHandler,
|
||||
"direct-streamlocal@openssh.com": directStreamLocalHandler,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
ConnectionFailedCallback: func(conn net.Conn, err error) {
|
||||
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
|
||||
},
|
||||
Handler: func(session ssh.Session) {
|
||||
err := a.handleSSHSession(session)
|
||||
var exitError *exec.ExitError
|
||||
if xerrors.As(err, &exitError) {
|
||||
a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
|
||||
_ = session.Exit(exitError.ExitCode())
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "ssh session failed", slog.Error(err))
|
||||
// This exit code is designed to be unlikely to be confused for a legit exit code
|
||||
// from the process.
|
||||
_ = session.Exit(MagicSessionErrorCode)
|
||||
return
|
||||
}
|
||||
_ = session.Exit(0)
|
||||
},
|
||||
HostSigners: []ssh.Signer{randomSigner},
|
||||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
// Allow local port forwarding all!
|
||||
sshLogger.Debug(ctx, "local port forward",
|
||||
slog.F("destination-host", destinationHost),
|
||||
slog.F("destination-port", destinationPort))
|
||||
return true
|
||||
},
|
||||
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
|
||||
return true
|
||||
},
|
||||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
// Allow reverse port forwarding all!
|
||||
sshLogger.Debug(ctx, "local port forward",
|
||||
slog.F("bind-host", bindHost),
|
||||
slog.F("bind-port", bindPort))
|
||||
return true
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": forwardHandler.HandleSSHRequest,
|
||||
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
|
||||
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
|
||||
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
|
||||
},
|
||||
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
|
||||
return &gossh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
}
|
||||
},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": func(session ssh.Session) {
|
||||
ctx := session.Context()
|
||||
|
||||
// Typically sftp sessions don't request a TTY, but if they do,
|
||||
// we must ensure the gliderlabs/ssh CRLF emulation is disabled.
|
||||
// Otherwise sftp will be broken. This can happen if a user sets
|
||||
// `RequestTTY force` in their SSH config.
|
||||
session.DisablePTYEmulation()
|
||||
|
||||
var opts []sftp.ServerOption
|
||||
// Change current working directory to the users home
|
||||
// directory so that SFTP connections land there.
|
||||
homedir, err := userHomeDir()
|
||||
if err != nil {
|
||||
sshLogger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err))
|
||||
} else {
|
||||
opts = append(opts, sftp.WithServerWorkingDirectory(homedir))
|
||||
}
|
||||
|
||||
server, err := sftp.NewServer(session, opts...)
|
||||
if err != nil {
|
||||
sshLogger.Debug(ctx, "initialize sftp server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
err = server.Serve()
|
||||
if errors.Is(err, io.EOF) {
|
||||
// Unless we call `session.Exit(0)` here, the client won't
|
||||
// receive `exit-status` because `(*sftp.Server).Close()`
|
||||
// calls `Close()` on the underlying connection (session),
|
||||
// which actually calls `channel.Close()` because it isn't
|
||||
// wrapped. This causes sftp clients to receive a non-zero
|
||||
// exit code. Typically sftp clients don't echo this exit
|
||||
// code but `scp` on macOS does (when using the default
|
||||
// SFTP backend).
|
||||
_ = session.Exit(0)
|
||||
return
|
||||
}
|
||||
sshLogger.Warn(ctx, "sftp server closed with error", slog.Error(err))
|
||||
_ = session.Exit(1)
|
||||
},
|
||||
},
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
}
|
||||
|
||||
go a.runLoop(ctx)
|
||||
}
|
||||
|
||||
// createCommand processes raw command input with OpenSSH-like behavior.
|
||||
// If the script provided is empty, it will default to the users shell.
|
||||
// This injects environment variables specified by the user at launch too.
|
||||
func (a *agent) createCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get current user: %w", err)
|
||||
}
|
||||
username := currentUser.Username
|
||||
|
||||
shell, err := usershell.Get(username)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get user shell: %w", err)
|
||||
}
|
||||
|
||||
manifest := a.manifest.Load()
|
||||
if manifest == nil {
|
||||
return nil, xerrors.Errorf("no metadata was provided")
|
||||
}
|
||||
|
||||
// OpenSSH executes all commands with the users current shell.
|
||||
// We replicate that behavior for IDE support.
|
||||
caller := "-c"
|
||||
if runtime.GOOS == "windows" {
|
||||
caller = "/c"
|
||||
}
|
||||
args := []string{caller, script}
|
||||
|
||||
// gliderlabs/ssh returns a command slice of zero
|
||||
// when a shell is requested.
|
||||
if len(script) == 0 {
|
||||
args = []string{}
|
||||
if runtime.GOOS != "windows" {
|
||||
// On Linux and macOS, we should start a login
|
||||
// shell to consume juicy environment variables!
|
||||
args = append(args, "-l")
|
||||
}
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, shell, args...)
|
||||
cmd.Dir = manifest.Directory
|
||||
|
||||
// If the metadata directory doesn't exist, we run the command
|
||||
// in the users home directory.
|
||||
_, err = os.Stat(cmd.Dir)
|
||||
if cmd.Dir == "" || err != nil {
|
||||
// Default to user home if a directory is not set.
|
||||
homedir, err := userHomeDir()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get home dir: %w", err)
|
||||
}
|
||||
cmd.Dir = homedir
|
||||
}
|
||||
cmd.Env = append(os.Environ(), env...)
|
||||
executablePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("getting os executable: %w", err)
|
||||
}
|
||||
// Set environment variables reliable detection of being inside a
|
||||
// Coder workspace.
|
||||
cmd.Env = append(cmd.Env, "CODER=true")
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
|
||||
// Git on Windows resolves with UNIX-style paths.
|
||||
// If using backslashes, it's unable to find the executable.
|
||||
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
|
||||
|
||||
// Specific Coder subcommands require the agent token exposed!
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", *a.sessionToken.Load()))
|
||||
|
||||
// Set SSH connection environment variables (these are also set by OpenSSH
|
||||
// and thus expected to be present by SSH clients). Since the agent does
|
||||
// networking in-memory, trying to provide accurate values here would be
|
||||
// nonsensical. For now, we hard code these values so that they're present.
|
||||
srcAddr, srcPort := "0.0.0.0", "0"
|
||||
dstAddr, dstPort := "0.0.0.0", "0"
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort))
|
||||
|
||||
// This adds the ports dialog to code-server that enables
|
||||
// proxying a port dynamically.
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI))
|
||||
|
||||
// Hide Coder message on code-server's "Getting Started" page
|
||||
cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true")
|
||||
|
||||
// Load environment variables passed via the agent.
|
||||
// These should override all variables we manually specify.
|
||||
for envKey, value := range manifest.EnvironmentVariables {
|
||||
// Expanding environment variables allows for customization
|
||||
// of the $PATH, among other variables. Customers can prepend
|
||||
// or append to the $PATH, so allowing expand is required!
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value)))
|
||||
}
|
||||
|
||||
// Agent-level environment variables should take over all!
|
||||
// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
|
||||
for envKey, value := range a.envVars {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
||||
ctx := session.Context()
|
||||
env := session.Environ()
|
||||
var magicType string
|
||||
for index, kv := range env {
|
||||
if !strings.HasPrefix(kv, MagicSSHSessionTypeEnvironmentVariable) {
|
||||
continue
|
||||
}
|
||||
magicType = strings.TrimPrefix(kv, MagicSSHSessionTypeEnvironmentVariable+"=")
|
||||
env = append(env[:index], env[index+1:]...)
|
||||
}
|
||||
switch magicType {
|
||||
case MagicSSHSessionTypeVSCode:
|
||||
a.connCountVSCode.Add(1)
|
||||
defer a.connCountVSCode.Add(-1)
|
||||
case MagicSSHSessionTypeJetBrains:
|
||||
a.connCountJetBrains.Add(1)
|
||||
defer a.connCountJetBrains.Add(-1)
|
||||
case "":
|
||||
a.connCountSSHSession.Add(1)
|
||||
defer a.connCountSSHSession.Add(-1)
|
||||
default:
|
||||
a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
|
||||
}
|
||||
|
||||
cmd, err := a.createCommand(ctx, session.RawCommand(), env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ssh.AgentRequested(session) {
|
||||
l, err := ssh.NewAgentListener()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("new agent listener: %w", err)
|
||||
}
|
||||
defer l.Close()
|
||||
go ssh.ForwardAgentConnections(l, session)
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
|
||||
}
|
||||
|
||||
sshPty, windowSize, isPty := session.Pty()
|
||||
if isPty {
|
||||
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
|
||||
// See https://github.com/coder/coder/issues/3371.
|
||||
session.DisablePTYEmulation()
|
||||
|
||||
if !isQuietLogin(session.RawCommand()) {
|
||||
manifest := a.manifest.Load()
|
||||
if manifest != nil {
|
||||
err = showMOTD(session, manifest.MOTDFile)
|
||||
if err != nil {
|
||||
a.logger.Error(ctx, "show MOTD", slog.Error(err))
|
||||
}
|
||||
} else {
|
||||
a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
|
||||
|
||||
// The pty package sets `SSH_TTY` on supported platforms.
|
||||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
|
||||
pty.WithSSHRequest(sshPty),
|
||||
pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)),
|
||||
))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("start command: %w", err)
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
defer func() {
|
||||
defer wg.Wait()
|
||||
closeErr := ptty.Close()
|
||||
if closeErr != nil {
|
||||
a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
|
||||
if retErr == nil {
|
||||
retErr = closeErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for win := range windowSize {
|
||||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||
// If the pty is closed, then command has exited, no need to log.
|
||||
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
|
||||
a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
|
||||
}
|
||||
}
|
||||
}()
|
||||
// We don't add input copy to wait group because
|
||||
// it won't return until the session is closed.
|
||||
go func() {
|
||||
_, _ = io.Copy(ptty.Input(), session)
|
||||
}()
|
||||
|
||||
// In low parallelism scenarios, the command may exit and we may close
|
||||
// the pty before the output copy has started. This can result in the
|
||||
// output being lost. To avoid this, we wait for the output copy to
|
||||
// start before waiting for the command to exit. This ensures that the
|
||||
// output copy goroutine will be scheduled before calling close on the
|
||||
// pty. This shouldn't be needed because of `pty.Dup()` below, but it
|
||||
// may not be supported on all platforms.
|
||||
outputCopyStarted := make(chan struct{})
|
||||
ptyOutput := func() io.ReadCloser {
|
||||
defer close(outputCopyStarted)
|
||||
// Try to dup so we can separate stdin and stdout closure.
|
||||
// Once the original pty is closed, the dup will return
|
||||
// input/output error once the buffered data has been read.
|
||||
stdout, err := ptty.Dup()
|
||||
if err == nil {
|
||||
return stdout
|
||||
}
|
||||
// If we can't dup, we shouldn't close
|
||||
// the fd since it's tied to stdin.
|
||||
return readNopCloser{ptty.Output()}
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
// Ensure data is flushed to session on command exit, if we
|
||||
// close the session too soon, we might lose data.
|
||||
defer wg.Done()
|
||||
|
||||
stdout := ptyOutput()
|
||||
defer stdout.Close()
|
||||
|
||||
_, _ = io.Copy(session, stdout)
|
||||
}()
|
||||
<-outputCopyStarted
|
||||
|
||||
err = process.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
|
||||
// and not something to be concerned about. But, if it's something else, we should log it.
|
||||
if err != nil && !xerrors.As(err, &exitErr) {
|
||||
a.logger.Warn(ctx, "wait error", slog.Error(err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Stdout = session
|
||||
cmd.Stderr = session.Stderr()
|
||||
// This blocks forever until stdin is received if we don't
|
||||
// use StdinPipe. It's unknown what causes this.
|
||||
stdinPipe, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create stdin pipe: %w", err)
|
||||
}
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinPipe, session)
|
||||
_ = stdinPipe.Close()
|
||||
}()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("start: %w", err)
|
||||
}
|
||||
return cmd.Wait()
|
||||
}
|
||||
|
||||
type readNopCloser struct{ io.Reader }
|
||||
|
||||
// Close implements io.Closer.
|
||||
func (readNopCloser) Close() error { return nil }
|
||||
|
||||
func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) {
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -1416,7 +997,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
|||
logger.Debug(ctx, "creating new session")
|
||||
|
||||
// Empty command will default to the users shell!
|
||||
cmd, err := a.createCommand(ctx, msg.Command, nil)
|
||||
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create command: %w", err)
|
||||
}
|
||||
|
@ -1590,9 +1171,11 @@ func (a *agent) startReportingConnectionStats(ctx context.Context) {
|
|||
}
|
||||
|
||||
// The count of active sessions.
|
||||
stats.SessionCountSSH = a.connCountSSHSession.Load()
|
||||
stats.SessionCountVSCode = a.connCountVSCode.Load()
|
||||
stats.SessionCountJetBrains = a.connCountJetBrains.Load()
|
||||
sshStats := a.sshServer.ConnStats()
|
||||
stats.SessionCountSSH = sshStats.Sessions
|
||||
stats.SessionCountVSCode = sshStats.VSCode
|
||||
stats.SessionCountJetBrains = sshStats.JetBrains
|
||||
|
||||
stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load()
|
||||
|
||||
// Compute the median connection latency!
|
||||
|
@ -1692,8 +1275,16 @@ func (a *agent) Close() error {
|
|||
}
|
||||
|
||||
ctx := context.Background()
|
||||
a.logger.Info(ctx, "shutting down agent")
|
||||
a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown)
|
||||
|
||||
// Attempt to gracefully shut down all active SSH connections and
|
||||
// stop accepting new ones.
|
||||
err := a.sshServer.Shutdown(ctx)
|
||||
if err != nil {
|
||||
a.logger.Error(ctx, "ssh server shutdown", slog.Error(err))
|
||||
}
|
||||
|
||||
lifecycleState := codersdk.WorkspaceAgentLifecycleOff
|
||||
if manifest := a.manifest.Load(); manifest != nil && manifest.ShutdownScript != "" {
|
||||
scriptDone := make(chan error, 1)
|
||||
|
@ -1785,101 +1376,6 @@ func (r *reconnectingPTY) Close() {
|
|||
r.timeout.Stop()
|
||||
}
|
||||
|
||||
// Bicopy copies all of the data between the two connections and will close them
|
||||
// after one or both of them are done writing. If the context is canceled, both
|
||||
// of the connections will be closed.
|
||||
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
defer func() {
|
||||
_ = c1.Close()
|
||||
_ = c2.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
copyFunc := func(dst io.WriteCloser, src io.Reader) {
|
||||
defer func() {
|
||||
wg.Done()
|
||||
// If one side of the copy fails, ensure the other one exits as
|
||||
// well.
|
||||
cancel()
|
||||
}()
|
||||
_, _ = io.Copy(dst, src)
|
||||
}
|
||||
|
||||
wg.Add(2)
|
||||
go copyFunc(c1, c2)
|
||||
go copyFunc(c2, c1)
|
||||
|
||||
// Convert waitgroup to a channel so we can also wait on the context.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
// isQuietLogin checks if the SSH server should perform a quiet login or not.
|
||||
//
|
||||
// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816
|
||||
func isQuietLogin(rawCommand string) bool {
|
||||
// We are always quiet unless this is a login shell.
|
||||
if len(rawCommand) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Best effort, if we can't get the home directory,
|
||||
// we can't lookup .hushlogin.
|
||||
homedir, err := userHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err = os.Stat(filepath.Join(homedir, ".hushlogin"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// showMOTD will output the message of the day from
|
||||
// the given filename to dest, if the file exists.
|
||||
//
|
||||
// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784
|
||||
func showMOTD(dest io.Writer, filename string) error {
|
||||
if filename == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, os.ErrNotExist) {
|
||||
// This is not an error, there simply isn't a MOTD to show.
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("open MOTD: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
// Carriage return ensures each line starts
|
||||
// at the beginning of the terminal.
|
||||
_, err = fmt.Fprint(dest, s.Text()+"\r\n")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write MOTD: %w", err)
|
||||
}
|
||||
}
|
||||
if err := s.Err(); err != nil {
|
||||
return xerrors.Errorf("read MOTD: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// userHomeDir returns the home directory of the current user, giving
|
||||
// priority to the $HOME environment variable.
|
||||
func userHomeDir() (string, error) {
|
||||
|
|
|
@ -41,6 +41,7 @@ import (
|
|||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/agent/agentssh"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
|
@ -131,13 +132,13 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
|||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode)
|
||||
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
|
||||
defer session.Close()
|
||||
|
||||
command := "sh -c 'echo $" + agent.MagicSSHSessionTypeEnvironmentVariable + "'"
|
||||
command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'"
|
||||
expected := ""
|
||||
if runtime.GOOS == "windows" {
|
||||
expected = "%" + agent.MagicSSHSessionTypeEnvironmentVariable + "%"
|
||||
expected = "%" + agentssh.MagicSessionTypeEnvironmentVariable + "%"
|
||||
command = "cmd.exe /c echo " + expected
|
||||
}
|
||||
output, err := session.Output(command)
|
||||
|
@ -158,7 +159,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
|||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
session.Setenv(agent.MagicSSHSessionTypeEnvironmentVariable, agent.MagicSSHSessionTypeVSCode)
|
||||
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
|
||||
defer session.Close()
|
||||
stdin, err := session.StdinPipe()
|
||||
require.NoError(t, err)
|
||||
|
@ -1651,7 +1652,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
|
|||
}
|
||||
waitGroup.Add(1)
|
||||
go func() {
|
||||
agent.Bicopy(context.Background(), conn, ssh)
|
||||
agentssh.Bicopy(context.Background(), conn, ssh)
|
||||
waitGroup.Done()
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,677 @@
|
|||
package agentssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/sftp"
|
||||
"go.uber.org/atomic"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/agent/usershell"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/pty"
|
||||
)
|
||||
|
||||
const (
|
||||
// MagicSessionErrorCode indicates that something went wrong with the session, rather than the
|
||||
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
|
||||
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
|
||||
MagicSessionErrorCode = 229
|
||||
|
||||
// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
|
||||
// This is stripped from any commands being executed, and is counted towards connection stats.
|
||||
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
|
||||
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
|
||||
MagicSessionTypeVSCode = "vscode"
|
||||
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself.
|
||||
MagicSessionTypeJetBrains = "jetbrains"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
mu sync.RWMutex // Protects following.
|
||||
listeners map[net.Listener]struct{}
|
||||
conns map[net.Conn]struct{}
|
||||
closing chan struct{}
|
||||
// Wait for goroutines to exit, waited without
|
||||
// a lock on mu but protected by closing.
|
||||
wg sync.WaitGroup
|
||||
|
||||
logger slog.Logger
|
||||
srv *ssh.Server
|
||||
|
||||
Env map[string]string
|
||||
AgentToken func() string
|
||||
Manifest *atomic.Pointer[agentsdk.Manifest]
|
||||
|
||||
connCountVSCode atomic.Int64
|
||||
connCountJetBrains atomic.Int64
|
||||
connCountSSHSession atomic.Int64
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) {
|
||||
// Clients' should ignore the host key when connecting.
|
||||
// The agent needs to authenticate with coderd to SSH,
|
||||
// so SSH authentication doesn't improve security.
|
||||
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := &forwardedUnixHandler{log: logger}
|
||||
|
||||
s := &Server{
|
||||
listeners: make(map[net.Listener]struct{}),
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
s.srv = &ssh.Server{
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"direct-tcpip": ssh.DirectTCPIPHandler,
|
||||
"direct-streamlocal@openssh.com": directStreamLocalHandler,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
ConnectionFailedCallback: func(_ net.Conn, err error) {
|
||||
s.logger.Info(ctx, "ssh connection ended", slog.Error(err))
|
||||
},
|
||||
Handler: s.sessionHandler,
|
||||
HostSigners: []ssh.Signer{randomSigner},
|
||||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
// Allow local port forwarding all!
|
||||
s.logger.Debug(ctx, "local port forward",
|
||||
slog.F("destination-host", destinationHost),
|
||||
slog.F("destination-port", destinationPort))
|
||||
return true
|
||||
},
|
||||
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
|
||||
return true
|
||||
},
|
||||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
// Allow reverse port forwarding all!
|
||||
s.logger.Debug(ctx, "local port forward",
|
||||
slog.F("bind-host", bindHost),
|
||||
slog.F("bind-port", bindPort))
|
||||
return true
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": forwardHandler.HandleSSHRequest,
|
||||
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
|
||||
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
|
||||
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
|
||||
},
|
||||
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
|
||||
return &gossh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
}
|
||||
},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": s.sftpHandler,
|
||||
},
|
||||
MaxTimeout: maxTimeout,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
type ConnStats struct {
|
||||
Sessions int64
|
||||
VSCode int64
|
||||
JetBrains int64
|
||||
}
|
||||
|
||||
func (s *Server) ConnStats() ConnStats {
|
||||
return ConnStats{
|
||||
Sessions: s.connCountSSHSession.Load(),
|
||||
VSCode: s.connCountVSCode.Load(),
|
||||
JetBrains: s.connCountJetBrains.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sessionHandler(session ssh.Session) {
|
||||
ctx := session.Context()
|
||||
err := s.sessionStart(session)
|
||||
var exitError *exec.ExitError
|
||||
if xerrors.As(err, &exitError) {
|
||||
s.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
|
||||
_ = session.Exit(exitError.ExitCode())
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "ssh session failed", slog.Error(err))
|
||||
// This exit code is designed to be unlikely to be confused for a legit exit code
|
||||
// from the process.
|
||||
_ = session.Exit(MagicSessionErrorCode)
|
||||
return
|
||||
}
|
||||
_ = session.Exit(0)
|
||||
}
|
||||
|
||||
func (s *Server) sessionStart(session ssh.Session) (retErr error) {
|
||||
ctx := session.Context()
|
||||
env := session.Environ()
|
||||
var magicType string
|
||||
for index, kv := range env {
|
||||
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
|
||||
continue
|
||||
}
|
||||
magicType = strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"=")
|
||||
env = append(env[:index], env[index+1:]...)
|
||||
}
|
||||
switch magicType {
|
||||
case MagicSessionTypeVSCode:
|
||||
s.connCountVSCode.Add(1)
|
||||
defer s.connCountVSCode.Add(-1)
|
||||
case MagicSessionTypeJetBrains:
|
||||
s.connCountJetBrains.Add(1)
|
||||
defer s.connCountJetBrains.Add(-1)
|
||||
case "":
|
||||
s.connCountSSHSession.Add(1)
|
||||
defer s.connCountSSHSession.Add(-1)
|
||||
default:
|
||||
s.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
|
||||
}
|
||||
|
||||
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ssh.AgentRequested(session) {
|
||||
l, err := ssh.NewAgentListener()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("new agent listener: %w", err)
|
||||
}
|
||||
defer l.Close()
|
||||
go ssh.ForwardAgentConnections(l, session)
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
|
||||
}
|
||||
|
||||
sshPty, windowSize, isPty := session.Pty()
|
||||
if isPty {
|
||||
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
|
||||
// See https://github.com/coder/coder/issues/3371.
|
||||
session.DisablePTYEmulation()
|
||||
|
||||
if !isQuietLogin(session.RawCommand()) {
|
||||
manifest := s.Manifest.Load()
|
||||
if manifest != nil {
|
||||
err = showMOTD(session, manifest.MOTDFile)
|
||||
if err != nil {
|
||||
s.logger.Error(ctx, "show MOTD", slog.Error(err))
|
||||
}
|
||||
} else {
|
||||
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
|
||||
|
||||
// The pty package sets `SSH_TTY` on supported platforms.
|
||||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
|
||||
pty.WithSSHRequest(sshPty),
|
||||
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
|
||||
))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("start command: %w", err)
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
defer func() {
|
||||
defer wg.Wait()
|
||||
closeErr := ptty.Close()
|
||||
if closeErr != nil {
|
||||
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
|
||||
if retErr == nil {
|
||||
retErr = closeErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for win := range windowSize {
|
||||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||
// If the pty is closed, then command has exited, no need to log.
|
||||
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
|
||||
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
|
||||
}
|
||||
}
|
||||
}()
|
||||
// We don't add input copy to wait group because
|
||||
// it won't return until the session is closed.
|
||||
go func() {
|
||||
_, _ = io.Copy(ptty.Input(), session)
|
||||
}()
|
||||
|
||||
// In low parallelism scenarios, the command may exit and we may close
|
||||
// the pty before the output copy has started. This can result in the
|
||||
// output being lost. To avoid this, we wait for the output copy to
|
||||
// start before waiting for the command to exit. This ensures that the
|
||||
// output copy goroutine will be scheduled before calling close on the
|
||||
// pty. This shouldn't be needed because of `pty.Dup()` below, but it
|
||||
// may not be supported on all platforms.
|
||||
outputCopyStarted := make(chan struct{})
|
||||
ptyOutput := func() io.ReadCloser {
|
||||
defer close(outputCopyStarted)
|
||||
// Try to dup so we can separate stdin and stdout closure.
|
||||
// Once the original pty is closed, the dup will return
|
||||
// input/output error once the buffered data has been read.
|
||||
stdout, err := ptty.Dup()
|
||||
if err == nil {
|
||||
return stdout
|
||||
}
|
||||
// If we can't dup, we shouldn't close
|
||||
// the fd since it's tied to stdin.
|
||||
return readNopCloser{ptty.Output()}
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
// Ensure data is flushed to session on command exit, if we
|
||||
// close the session too soon, we might lose data.
|
||||
defer wg.Done()
|
||||
|
||||
stdout := ptyOutput()
|
||||
defer stdout.Close()
|
||||
|
||||
_, _ = io.Copy(session, stdout)
|
||||
}()
|
||||
<-outputCopyStarted
|
||||
|
||||
err = process.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
|
||||
// and not something to be concerned about. But, if it's something else, we should log it.
|
||||
if err != nil && !xerrors.As(err, &exitErr) {
|
||||
s.logger.Warn(ctx, "wait error", slog.Error(err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Stdout = session
|
||||
cmd.Stderr = session.Stderr()
|
||||
// This blocks forever until stdin is received if we don't
|
||||
// use StdinPipe. It's unknown what causes this.
|
||||
stdinPipe, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create stdin pipe: %w", err)
|
||||
}
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinPipe, session)
|
||||
_ = stdinPipe.Close()
|
||||
}()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("start: %w", err)
|
||||
}
|
||||
return cmd.Wait()
|
||||
}
|
||||
|
||||
type readNopCloser struct{ io.Reader }
|
||||
|
||||
// Close implements io.Closer.
|
||||
func (readNopCloser) Close() error { return nil }
|
||||
|
||||
func (s *Server) sftpHandler(session ssh.Session) {
|
||||
ctx := session.Context()
|
||||
|
||||
// Typically sftp sessions don't request a TTY, but if they do,
|
||||
// we must ensure the gliderlabs/ssh CRLF emulation is disabled.
|
||||
// Otherwise sftp will be broken. This can happen if a user sets
|
||||
// `RequestTTY force` in their SSH config.
|
||||
session.DisablePTYEmulation()
|
||||
|
||||
var opts []sftp.ServerOption
|
||||
// Change current working directory to the users home
|
||||
// directory so that SFTP connections land there.
|
||||
homedir, err := userHomeDir()
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err))
|
||||
} else {
|
||||
opts = append(opts, sftp.WithServerWorkingDirectory(homedir))
|
||||
}
|
||||
|
||||
server, err := sftp.NewServer(session, opts...)
|
||||
if err != nil {
|
||||
s.logger.Debug(ctx, "initialize sftp server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
err = server.Serve()
|
||||
if errors.Is(err, io.EOF) {
|
||||
// Unless we call `session.Exit(0)` here, the client won't
|
||||
// receive `exit-status` because `(*sftp.Server).Close()`
|
||||
// calls `Close()` on the underlying connection (session),
|
||||
// which actually calls `channel.Close()` because it isn't
|
||||
// wrapped. This causes sftp clients to receive a non-zero
|
||||
// exit code. Typically sftp clients don't echo this exit
|
||||
// code but `scp` on macOS does (when using the default
|
||||
// SFTP backend).
|
||||
_ = session.Exit(0)
|
||||
return
|
||||
}
|
||||
s.logger.Warn(ctx, "sftp server closed with error", slog.Error(err))
|
||||
_ = session.Exit(1)
|
||||
}
|
||||
|
||||
// CreateCommand processes raw command input with OpenSSH-like behavior.
|
||||
// If the script provided is empty, it will default to the users shell.
|
||||
// This injects environment variables specified by the user at launch too.
|
||||
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get current user: %w", err)
|
||||
}
|
||||
username := currentUser.Username
|
||||
|
||||
shell, err := usershell.Get(username)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get user shell: %w", err)
|
||||
}
|
||||
|
||||
manifest := s.Manifest.Load()
|
||||
if manifest == nil {
|
||||
return nil, xerrors.Errorf("no metadata was provided")
|
||||
}
|
||||
|
||||
// OpenSSH executes all commands with the users current shell.
|
||||
// We replicate that behavior for IDE support.
|
||||
caller := "-c"
|
||||
if runtime.GOOS == "windows" {
|
||||
caller = "/c"
|
||||
}
|
||||
args := []string{caller, script}
|
||||
|
||||
// gliderlabs/ssh returns a command slice of zero
|
||||
// when a shell is requested.
|
||||
if len(script) == 0 {
|
||||
args = []string{}
|
||||
if runtime.GOOS != "windows" {
|
||||
// On Linux and macOS, we should start a login
|
||||
// shell to consume juicy environment variables!
|
||||
args = append(args, "-l")
|
||||
}
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, shell, args...)
|
||||
cmd.Dir = manifest.Directory
|
||||
|
||||
// If the metadata directory doesn't exist, we run the command
|
||||
// in the users home directory.
|
||||
_, err = os.Stat(cmd.Dir)
|
||||
if cmd.Dir == "" || err != nil {
|
||||
// Default to user home if a directory is not set.
|
||||
homedir, err := userHomeDir()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get home dir: %w", err)
|
||||
}
|
||||
cmd.Dir = homedir
|
||||
}
|
||||
cmd.Env = append(os.Environ(), env...)
|
||||
executablePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("getting os executable: %w", err)
|
||||
}
|
||||
// Set environment variables reliable detection of being inside a
|
||||
// Coder workspace.
|
||||
cmd.Env = append(cmd.Env, "CODER=true")
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
|
||||
// Git on Windows resolves with UNIX-style paths.
|
||||
// If using backslashes, it's unable to find the executable.
|
||||
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
|
||||
|
||||
// Specific Coder subcommands require the agent token exposed!
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", s.AgentToken()))
|
||||
|
||||
// Set SSH connection environment variables (these are also set by OpenSSH
|
||||
// and thus expected to be present by SSH clients). Since the agent does
|
||||
// networking in-memory, trying to provide accurate values here would be
|
||||
// nonsensical. For now, we hard code these values so that they're present.
|
||||
srcAddr, srcPort := "0.0.0.0", "0"
|
||||
dstAddr, dstPort := "0.0.0.0", "0"
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort))
|
||||
|
||||
// This adds the ports dialog to code-server that enables
|
||||
// proxying a port dynamically.
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI))
|
||||
|
||||
// Hide Coder message on code-server's "Getting Started" page
|
||||
cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true")
|
||||
|
||||
// Load environment variables passed via the agent.
|
||||
// These should override all variables we manually specify.
|
||||
for envKey, value := range manifest.EnvironmentVariables {
|
||||
// Expanding environment variables allows for customization
|
||||
// of the $PATH, among other variables. Customers can prepend
|
||||
// or append to the $PATH, so allowing expand is required!
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value)))
|
||||
}
|
||||
|
||||
// Agent-level environment variables should take over all!
|
||||
// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
|
||||
for envKey, value := range s.Env {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (s *Server) Serve(l net.Listener) error {
|
||||
defer l.Close()
|
||||
|
||||
s.trackListener(l, true)
|
||||
defer s.trackListener(l, false)
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handleConn(l, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(l net.Listener, c net.Conn) {
|
||||
defer c.Close()
|
||||
|
||||
if !s.trackConn(l, c, true) {
|
||||
// Server is closed or we no longer want
|
||||
// connections from this listener.
|
||||
s.logger.Debug(context.Background(), "received connection after server closed")
|
||||
return
|
||||
}
|
||||
defer s.trackConn(l, c, false)
|
||||
|
||||
s.srv.HandleConn(c)
|
||||
}
|
||||
|
||||
// trackListener registers the listener with the server. If the server is
|
||||
// closing, the function will block until the server is closed.
|
||||
//
|
||||
//nolint:revive
|
||||
func (s *Server) trackListener(l net.Listener, add bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if add {
|
||||
for s.closing != nil {
|
||||
closing := s.closing
|
||||
// Wait until close is complete before
|
||||
// serving a new listener.
|
||||
s.mu.Unlock()
|
||||
<-closing
|
||||
s.mu.Lock()
|
||||
}
|
||||
s.wg.Add(1)
|
||||
s.listeners[l] = struct{}{}
|
||||
return
|
||||
}
|
||||
s.wg.Done()
|
||||
delete(s.listeners, l)
|
||||
}
|
||||
|
||||
// trackConn registers the connection with the server. If the server is
|
||||
// closed or the listener is closed, the connection is not registered
|
||||
// and should be closed.
|
||||
//
|
||||
//nolint:revive
|
||||
func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if add {
|
||||
found := false
|
||||
for ll := range s.listeners {
|
||||
if l == ll {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if s.closing != nil || !found {
|
||||
// Server or listener closed.
|
||||
return false
|
||||
}
|
||||
s.wg.Add(1)
|
||||
s.conns[c] = struct{}{}
|
||||
return true
|
||||
}
|
||||
s.wg.Done()
|
||||
delete(s.conns, c)
|
||||
return true
|
||||
}
|
||||
|
||||
// Close the server and all active connections. Server can be re-used
|
||||
// after Close is done.
|
||||
func (s *Server) Close() error {
|
||||
s.mu.Lock()
|
||||
|
||||
// Guard against multiple calls to Close and
|
||||
// accepting new connections during close.
|
||||
if s.closing != nil {
|
||||
s.mu.Unlock()
|
||||
return xerrors.New("server is closing")
|
||||
}
|
||||
s.closing = make(chan struct{})
|
||||
|
||||
// Close all active listeners and connections.
|
||||
for l := range s.listeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
for c := range s.conns {
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
// Close the underlying SSH server.
|
||||
err := s.srv.Close()
|
||||
|
||||
s.mu.Unlock()
|
||||
s.wg.Wait() // Wait for all goroutines to exit.
|
||||
|
||||
s.mu.Lock()
|
||||
close(s.closing)
|
||||
s.closing = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Shutdown gracefully closes all active SSH connections and stops
|
||||
// accepting new connections.
|
||||
//
|
||||
// Shutdown is not implemented.
|
||||
func (*Server) Shutdown(_ context.Context) error {
|
||||
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
|
||||
return nil
|
||||
}
|
||||
|
||||
// isQuietLogin checks if the SSH server should perform a quiet login or not.
|
||||
//
|
||||
// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816
|
||||
func isQuietLogin(rawCommand string) bool {
|
||||
// We are always quiet unless this is a login shell.
|
||||
if len(rawCommand) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Best effort, if we can't get the home directory,
|
||||
// we can't lookup .hushlogin.
|
||||
homedir, err := userHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err = os.Stat(filepath.Join(homedir, ".hushlogin"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// showMOTD will output the message of the day from
|
||||
// the given filename to dest, if the file exists.
|
||||
//
|
||||
// https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784
|
||||
func showMOTD(dest io.Writer, filename string) error {
|
||||
if filename == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, os.ErrNotExist) {
|
||||
// This is not an error, there simply isn't a MOTD to show.
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("open MOTD: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
// Carriage return ensures each line starts
|
||||
// at the beginning of the terminal.
|
||||
_, err = fmt.Fprint(dest, s.Text()+"\r\n")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write MOTD: %w", err)
|
||||
}
|
||||
}
|
||||
if err := s.Err(); err != nil {
|
||||
return xerrors.Errorf("read MOTD: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// userHomeDir returns the home directory of the current user, giving
|
||||
// priority to the $HOME environment variable.
|
||||
func userHomeDir() (string, error) {
|
||||
// First we check the environment.
|
||||
homedir, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
return homedir, nil
|
||||
}
|
||||
|
||||
// As a fallback, we try the user information.
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("current user: %w", err)
|
||||
}
|
||||
return u.HomeDir, nil
|
||||
}
|
|
@ -0,0 +1,139 @@
|
|||
// Package agentssh_test provides tests for basic functinoality of the agentssh
|
||||
// package, more test coverage can be found in the `agent` and `cli` package(s).
|
||||
package agentssh_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/goleak"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/agent/agentssh"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestNewServer_ServeClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The assumption is that these are set before serving SSH connections.
|
||||
s.AgentToken = func() string { return "" }
|
||||
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
err := s.Serve(ln)
|
||||
assert.Error(t, err) // Server is closed.
|
||||
}()
|
||||
|
||||
c := sshClient(t, ln.Addr().String())
|
||||
var b bytes.Buffer
|
||||
sess, err := c.NewSession()
|
||||
sess.Stdout = &b
|
||||
require.NoError(t, err)
|
||||
err = sess.Start("echo hello")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sess.Wait()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "hello", strings.TrimSpace(b.String()))
|
||||
|
||||
err = s.Close()
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestNewServer_CloseActiveConnections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
s, err := agentssh.NewServer(ctx, logger, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The assumption is that these are set before serving SSH connections.
|
||||
s.AgentToken = func() string { return "" }
|
||||
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := s.Serve(ln)
|
||||
assert.Error(t, err) // Server is closed.
|
||||
}()
|
||||
|
||||
pty := ptytest.New(t)
|
||||
|
||||
doClose := make(chan struct{})
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c := sshClient(t, ln.Addr().String())
|
||||
sess, err := c.NewSession()
|
||||
sess.Stdin = pty.Input()
|
||||
sess.Stdout = pty.Output()
|
||||
sess.Stderr = pty.Output()
|
||||
|
||||
assert.NoError(t, err)
|
||||
err = sess.Start("")
|
||||
assert.NoError(t, err)
|
||||
|
||||
close(doClose)
|
||||
err = sess.Wait()
|
||||
assert.Error(t, err)
|
||||
}()
|
||||
|
||||
<-doClose
|
||||
err = s.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func sshClient(t *testing.T, addr string) *ssh.Client {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
})
|
||||
|
||||
sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test.
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = sshConn.Close()
|
||||
})
|
||||
c := ssh.NewClient(sshConn, channels, requests)
|
||||
t.Cleanup(func() {
|
||||
_ = c.Close()
|
||||
})
|
||||
return c
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package agentssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Bicopy copies all of the data between the two connections and will close them
|
||||
// after one or both of them are done writing. If the context is canceled, both
|
||||
// of the connections will be closed.
|
||||
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
defer func() {
|
||||
_ = c1.Close()
|
||||
_ = c2.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
copyFunc := func(dst io.WriteCloser, src io.Reader) {
|
||||
defer func() {
|
||||
wg.Done()
|
||||
// If one side of the copy fails, ensure the other one exits as
|
||||
// well.
|
||||
cancel()
|
||||
}()
|
||||
_, _ = io.Copy(dst, src)
|
||||
}
|
||||
|
||||
wg.Add(2)
|
||||
go copyFunc(c1, c2)
|
||||
go copyFunc(c2, c1)
|
||||
|
||||
// Convert waitgroup to a channel so we can also wait on the context.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package agent
|
||||
package agentssh
|
||||
|
||||
import (
|
||||
"context"
|
|
@ -14,7 +14,7 @@ import (
|
|||
"github.com/pion/udp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/agent/agentssh"
|
||||
"github.com/coder/coder/cli/clibase"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
@ -226,7 +226,7 @@ func listenAndPortForward(ctx context.Context, inv *clibase.Invocation, conn *co
|
|||
}
|
||||
defer remoteConn.Close()
|
||||
|
||||
agent.Bicopy(ctx, netConn, remoteConn)
|
||||
agentssh.Bicopy(ctx, netConn, remoteConn)
|
||||
}(netConn)
|
||||
}
|
||||
}(spec)
|
||||
|
|
|
@ -23,7 +23,7 @@ import (
|
|||
"golang.org/x/term"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/agent/agentssh"
|
||||
"github.com/coder/coder/cli/clibase"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/coderd/autobuild/notify"
|
||||
|
@ -574,7 +574,7 @@ func sshForwardRemote(ctx context.Context, stderr io.Writer, sshClient *gossh.Cl
|
|||
}
|
||||
}
|
||||
|
||||
agent.Bicopy(ctx, localConn, remoteConn)
|
||||
agentssh.Bicopy(ctx, localConn, remoteConn)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/agent/agentssh"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
|
@ -575,7 +575,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
defer ptNetConn.Close()
|
||||
agent.Bicopy(ctx, wsNetConn, ptNetConn)
|
||||
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
|
|
Loading…
Reference in New Issue