diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index b0f9c11806..1829449a85 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -311,10 +311,10 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv if isPty { return s.startPTYSession(logger, session, magicTypeLabel, cmd, sshPty, windowSize) } - return s.startNonPTYSession(session, magicTypeLabel, cmd.AsExec()) + return s.startNonPTYSession(logger, session, magicTypeLabel, cmd.AsExec()) } -func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error { +func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error { s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1) cmd.Stdout = session @@ -338,6 +338,17 @@ func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string, s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1) return xerrors.Errorf("start: %w", err) } + sigs := make(chan ssh.Signal, 1) + session.Signals(sigs) + defer func() { + session.Signals(nil) + close(sigs) + }() + go func() { + for sig := range sigs { + s.handleSignal(logger, sig, cmd.Process, magicTypeLabel) + } + }() return cmd.Wait() } @@ -348,6 +359,7 @@ type ptySession interface { Context() ssh.Context DisablePTYEmulation() RawCommand() string + Signals(chan<- ssh.Signal) } func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTypeLabel string, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) { @@ -403,13 +415,36 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy } } }() + sigs := make(chan ssh.Signal, 1) + session.Signals(sigs) + defer func() { + session.Signals(nil) + close(sigs) + }() go func() { - for win := range windowSize { - resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) - // If the pty is closed, then command has exited, no need to log. - if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { - logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) - s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1) + for { + if sigs == nil && windowSize == nil { + return + } + + select { + case sig, ok := <-sigs: + if !ok { + sigs = nil + continue + } + s.handleSignal(logger, sig, process, magicTypeLabel) + case win, ok := <-windowSize: + if !ok { + windowSize = nil + continue + } + resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) + // If the pty is closed, then command has exited, no need to log. + if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { + logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) + s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1) + } } } }() @@ -452,6 +487,18 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy return nil } +func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) { + ctx := context.Background() + sig := osSignalFrom(ssig) + logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String())) + logger.Info(ctx, "received signal from client") + err := signaler.Signal(sig) + if err != nil { + logger.Warn(ctx, "signaling the process failed", slog.Error(err)) + s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1) + } +} + func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) { s.metrics.sftpConnectionsTotal.Add(1) diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go index dd87be0503..1bdc3541a7 100644 --- a/agent/agentssh/agentssh_internal_test.go +++ b/agent/agentssh/agentssh_internal_test.go @@ -114,6 +114,11 @@ type testSSHContext struct { context.Context } +var ( + _ gliderssh.Context = testSSHContext{} + _ ptySession = &testSession{} +) + func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) { toClient, fromPty := io.Pipe() toPty, fromClient := io.Pipe() @@ -144,6 +149,10 @@ func (s *testSession) Write(p []byte) (n int, err error) { return s.fromPty.Write(p) } +func (*testSession) Signals(_ chan<- gliderssh.Signal) { + // Not implemented, but will be called. +} + func (testSSHContext) Lock() { panic("not implemented") } diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go index b72da96e4c..4cd9544019 100644 --- a/agent/agentssh/agentssh_test.go +++ b/agent/agentssh/agentssh_test.go @@ -3,8 +3,10 @@ package agentssh_test import ( + "bufio" "bytes" "context" + "fmt" "net" "runtime" "strings" @@ -24,6 +26,7 @@ import ( "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/pty/ptytest" + "github.com/coder/coder/v2/testutil" ) func TestMain(m *testing.M) { @@ -57,8 +60,8 @@ func TestNewServer_ServeClient(t *testing.T) { var b bytes.Buffer sess, err := c.NewSession() - sess.Stdout = &b require.NoError(t, err) + sess.Stdout = &b err = sess.Start("echo hello") require.NoError(t, err) @@ -139,6 +142,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { defer wg.Done() c := sshClient(t, ln.Addr().String()) sess, err := c.NewSession() + assert.NoError(t, err) sess.Stdin = pty.Input() sess.Stdout = pty.Output() sess.Stderr = pty.Output() @@ -159,6 +163,147 @@ func TestNewServer_CloseActiveConnections(t *testing.T) { wg.Wait() } +func TestNewServer_Signal(t *testing.T) { + t.Parallel() + + t.Run("Stdout", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + logger := slogtest.Make(t, nil) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "") + require.NoError(t, err) + defer s.Close() + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + defer func() { + err := s.Close() + require.NoError(t, err) + <-done + }() + + c := sshClient(t, ln.Addr().String()) + + sess, err := c.NewSession() + require.NoError(t, err) + r, err := sess.StdoutPipe() + require.NoError(t, err) + + // Perform multiple sleeps since the interrupt signal doesn't propagate to + // the process group, this lets us exit early. + sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds())) + err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps)) + require.NoError(t, err) + + sc := bufio.NewScanner(r) + for sc.Scan() { + t.Log(sc.Text()) + if strings.Contains(sc.Text(), "hello") { + break + } + } + require.NoError(t, sc.Err()) + + err = sess.Signal(ssh.SIGINT) + require.NoError(t, err) + + // Assumption, signal propagates and the command exists, closing stdout. + for sc.Scan() { + t.Log(sc.Text()) + require.NotContains(t, sc.Text(), "bye") + } + require.NoError(t, sc.Err()) + + err = sess.Wait() + require.Error(t, err) + }) + t.Run("PTY", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + logger := slogtest.Make(t, nil) + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "") + require.NoError(t, err) + defer s.Close() + + // The assumption is that these are set before serving SSH connections. + s.AgentToken = func() string { return "" } + s.Manifest = atomic.NewPointer(&agentsdk.Manifest{}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + err := s.Serve(ln) + assert.Error(t, err) // Server is closed. + }() + defer func() { + err := s.Close() + require.NoError(t, err) + <-done + }() + + c := sshClient(t, ln.Addr().String()) + + pty := ptytest.New(t) + + sess, err := c.NewSession() + require.NoError(t, err) + r, err := sess.StdoutPipe() + require.NoError(t, err) + + // Note, we request pty but don't use ptytest here because we can't + // easily test for no text before EOF. + sess.Stdin = pty.Input() + sess.Stderr = pty.Output() + + err = sess.RequestPty("xterm", 80, 80, nil) + require.NoError(t, err) + + // Perform multiple sleeps since the interrupt signal doesn't propagate to + // the process group, this lets us exit early. + sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds())) + err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps)) + require.NoError(t, err) + + sc := bufio.NewScanner(r) + for sc.Scan() { + t.Log(sc.Text()) + if strings.Contains(sc.Text(), "hello") { + break + } + } + require.NoError(t, sc.Err()) + + err = sess.Signal(ssh.SIGINT) + require.NoError(t, err) + + // Assumption, signal propagates and the command exists, closing stdout. + for sc.Scan() { + t.Log(sc.Text()) + require.NotContains(t, sc.Text(), "bye") + } + require.NoError(t, sc.Err()) + + err = sess.Wait() + require.Error(t, err) + }) +} + func sshClient(t *testing.T, addr string) *ssh.Client { conn, err := net.Dial("tcp", addr) require.NoError(t, err) diff --git a/agent/agentssh/signal_other.go b/agent/agentssh/signal_other.go new file mode 100644 index 0000000000..7e6f2a9937 --- /dev/null +++ b/agent/agentssh/signal_other.go @@ -0,0 +1,45 @@ +//go:build !windows + +package agentssh + +import ( + "os" + + "github.com/gliderlabs/ssh" + "golang.org/x/sys/unix" +) + +func osSignalFrom(sig ssh.Signal) os.Signal { + switch sig { + case ssh.SIGABRT: + return unix.SIGABRT + case ssh.SIGALRM: + return unix.SIGALRM + case ssh.SIGFPE: + return unix.SIGFPE + case ssh.SIGHUP: + return unix.SIGHUP + case ssh.SIGILL: + return unix.SIGILL + case ssh.SIGINT: + return unix.SIGINT + case ssh.SIGKILL: + return unix.SIGKILL + case ssh.SIGPIPE: + return unix.SIGPIPE + case ssh.SIGQUIT: + return unix.SIGQUIT + case ssh.SIGSEGV: + return unix.SIGSEGV + case ssh.SIGTERM: + return unix.SIGTERM + case ssh.SIGUSR1: + return unix.SIGUSR1 + case ssh.SIGUSR2: + return unix.SIGUSR2 + + // Unhandled, use sane fallback. + default: + return unix.SIGKILL + } +} diff --git a/agent/agentssh/signal_windows.go b/agent/agentssh/signal_windows.go new file mode 100644 index 0000000000..c7d5cae52a --- /dev/null +++ b/agent/agentssh/signal_windows.go @@ -0,0 +1,15 @@ +package agentssh + +import ( + "os" + + "github.com/gliderlabs/ssh" +) + +func osSignalFrom(sig ssh.Signal) os.Signal { + switch sig { + // Signals are not supported on Windows. + default: + return os.Kill + } +} diff --git a/pty/pty.go b/pty/pty.go index 507e9468e2..c51fcf003e 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -3,6 +3,7 @@ package pty import ( "io" "log" + "os" "github.com/gliderlabs/ssh" "golang.org/x/xerrors" @@ -69,6 +70,11 @@ type Process interface { // Kill the command process. Returned error is as for os.Process.Kill() Kill() error + + // Signal sends a signal to the command process. On non-windows systems, the + // returned error is as for os.Process.Signal(), on Windows it's + // as for os.Process.Kill(). + Signal(sig os.Signal) error } // WithFlags represents a PTY whose flags can be inspected, in particular diff --git a/pty/pty_other.go b/pty/pty_other.go index a5fa9d555d..67ca6ba6da 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -170,6 +170,10 @@ func (p *otherProcess) Kill() error { return p.cmd.Process.Kill() } +func (p *otherProcess) Signal(sig os.Signal) error { + return p.cmd.Process.Signal(sig) +} + func (p *otherProcess) waitInternal() { // The GC can garbage collect the TTY FD before the command // has finished running. See: diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 6d7ee60a89..93fea12019 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -243,6 +243,11 @@ func (p *windowsProcess) Kill() error { return p.proc.Kill() } +func (p *windowsProcess) Signal(sig os.Signal) error { + // Windows doesn't support signals. + return p.Kill() +} + // killOnContext waits for the context to be done and kills the process, unless it exits on its own first. func (p *windowsProcess) killOnContext(ctx context.Context) { select { diff --git a/pty/start_other_test.go b/pty/start_other_test.go index 63b6a36e8c..7cd874b7f6 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -3,6 +3,7 @@ package pty_test import ( + "os" "os/exec" "testing" @@ -46,6 +47,19 @@ func TestStart(t *testing.T) { require.NoError(t, err) }) + t.Run("Interrupt", func(t *testing.T) { + t.Parallel() + pty, ps := ptytest.Start(t, pty.Command("sleep", "30")) + err := ps.Signal(os.Interrupt) + assert.NoError(t, err) + err = ps.Wait() + var exitErr *exec.ExitError + require.True(t, xerrors.As(err, &exitErr)) + assert.NotEqual(t, 0, exitErr.ExitCode()) + err = pty.Close() + require.NoError(t, err) + }) + t.Run("SSH_TTY", func(t *testing.T) { t.Parallel() opts := pty.WithPTYOption(pty.WithSSHRequest(ssh.Pty{ diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index 280639cafe..094ba67f9d 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -5,6 +5,7 @@ package pty_test import ( "fmt" + "os" "os/exec" "testing" @@ -51,6 +52,18 @@ func TestStart(t *testing.T) { err = ptty.Close() require.NoError(t, err) }) + t.Run("Interrupt", func(t *testing.T) { + t.Parallel() + ptty, ps := ptytest.Start(t, pty.Command("cmd.exe")) + err := ps.Signal(os.Interrupt) // Actually does kill. + assert.NoError(t, err) + err = ps.Wait() + var exitErr *exec.ExitError + require.True(t, xerrors.As(err, &exitErr)) + assert.NotEqual(t, 0, exitErr.ExitCode()) + err = ptty.Close() + require.NoError(t, err) + }) } // these constants/vars are used by Test_Start_copy