feat(agent/agentssh): handle session signals (#10842)

This commit is contained in:
Mathias Fredriksson 2023-11-23 19:55:36 +02:00 committed by GitHub
parent a7c27cad26
commit 2c6e0f7d0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 312 additions and 9 deletions

View File

@ -311,10 +311,10 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
if isPty {
return s.startPTYSession(logger, session, magicTypeLabel, cmd, sshPty, windowSize)
}
return s.startNonPTYSession(session, magicTypeLabel, cmd.AsExec())
return s.startNonPTYSession(logger, session, magicTypeLabel, cmd.AsExec())
}
func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
cmd.Stdout = session
@ -338,6 +338,17 @@ func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string,
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1)
return xerrors.Errorf("start: %w", err)
}
sigs := make(chan ssh.Signal, 1)
session.Signals(sigs)
defer func() {
session.Signals(nil)
close(sigs)
}()
go func() {
for sig := range sigs {
s.handleSignal(logger, sig, cmd.Process, magicTypeLabel)
}
}()
return cmd.Wait()
}
@ -348,6 +359,7 @@ type ptySession interface {
Context() ssh.Context
DisablePTYEmulation()
RawCommand() string
Signals(chan<- ssh.Signal)
}
func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTypeLabel string, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
@ -403,13 +415,36 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
}
}
}()
sigs := make(chan ssh.Signal, 1)
session.Signals(sigs)
defer func() {
session.Signals(nil)
close(sigs)
}()
go func() {
for win := range windowSize {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
// If the pty is closed, then command has exited, no need to log.
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
for {
if sigs == nil && windowSize == nil {
return
}
select {
case sig, ok := <-sigs:
if !ok {
sigs = nil
continue
}
s.handleSignal(logger, sig, process, magicTypeLabel)
case win, ok := <-windowSize:
if !ok {
windowSize = nil
continue
}
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
// If the pty is closed, then command has exited, no need to log.
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
}
}
}
}()
@ -452,6 +487,18 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
return nil
}
func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) {
ctx := context.Background()
sig := osSignalFrom(ssig)
logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String()))
logger.Info(ctx, "received signal from client")
err := signaler.Signal(sig)
if err != nil {
logger.Warn(ctx, "signaling the process failed", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
}
}
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
s.metrics.sftpConnectionsTotal.Add(1)

View File

@ -114,6 +114,11 @@ type testSSHContext struct {
context.Context
}
var (
_ gliderssh.Context = testSSHContext{}
_ ptySession = &testSession{}
)
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
toClient, fromPty := io.Pipe()
toPty, fromClient := io.Pipe()
@ -144,6 +149,10 @@ func (s *testSession) Write(p []byte) (n int, err error) {
return s.fromPty.Write(p)
}
func (*testSession) Signals(_ chan<- gliderssh.Signal) {
// Not implemented, but will be called.
}
func (testSSHContext) Lock() {
panic("not implemented")
}

View File

@ -3,8 +3,10 @@
package agentssh_test
import (
"bufio"
"bytes"
"context"
"fmt"
"net"
"runtime"
"strings"
@ -24,6 +26,7 @@ import (
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
func TestMain(m *testing.M) {
@ -57,8 +60,8 @@ func TestNewServer_ServeClient(t *testing.T) {
var b bytes.Buffer
sess, err := c.NewSession()
sess.Stdout = &b
require.NoError(t, err)
sess.Stdout = &b
err = sess.Start("echo hello")
require.NoError(t, err)
@ -139,6 +142,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
defer wg.Done()
c := sshClient(t, ln.Addr().String())
sess, err := c.NewSession()
assert.NoError(t, err)
sess.Stdin = pty.Input()
sess.Stdout = pty.Output()
sess.Stderr = pty.Output()
@ -159,6 +163,147 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
wg.Wait()
}
func TestNewServer_Signal(t *testing.T) {
t.Parallel()
t.Run("Stdout", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()
defer func() {
err := s.Close()
require.NoError(t, err)
<-done
}()
c := sshClient(t, ln.Addr().String())
sess, err := c.NewSession()
require.NoError(t, err)
r, err := sess.StdoutPipe()
require.NoError(t, err)
// Perform multiple sleeps since the interrupt signal doesn't propagate to
// the process group, this lets us exit early.
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
require.NoError(t, err)
sc := bufio.NewScanner(r)
for sc.Scan() {
t.Log(sc.Text())
if strings.Contains(sc.Text(), "hello") {
break
}
}
require.NoError(t, sc.Err())
err = sess.Signal(ssh.SIGINT)
require.NoError(t, err)
// Assumption, signal propagates and the command exists, closing stdout.
for sc.Scan() {
t.Log(sc.Text())
require.NotContains(t, sc.Text(), "bye")
}
require.NoError(t, sc.Err())
err = sess.Wait()
require.Error(t, err)
})
t.Run("PTY", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()
defer func() {
err := s.Close()
require.NoError(t, err)
<-done
}()
c := sshClient(t, ln.Addr().String())
pty := ptytest.New(t)
sess, err := c.NewSession()
require.NoError(t, err)
r, err := sess.StdoutPipe()
require.NoError(t, err)
// Note, we request pty but don't use ptytest here because we can't
// easily test for no text before EOF.
sess.Stdin = pty.Input()
sess.Stderr = pty.Output()
err = sess.RequestPty("xterm", 80, 80, nil)
require.NoError(t, err)
// Perform multiple sleeps since the interrupt signal doesn't propagate to
// the process group, this lets us exit early.
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
require.NoError(t, err)
sc := bufio.NewScanner(r)
for sc.Scan() {
t.Log(sc.Text())
if strings.Contains(sc.Text(), "hello") {
break
}
}
require.NoError(t, sc.Err())
err = sess.Signal(ssh.SIGINT)
require.NoError(t, err)
// Assumption, signal propagates and the command exists, closing stdout.
for sc.Scan() {
t.Log(sc.Text())
require.NotContains(t, sc.Text(), "bye")
}
require.NoError(t, sc.Err())
err = sess.Wait()
require.Error(t, err)
})
}
func sshClient(t *testing.T, addr string) *ssh.Client {
conn, err := net.Dial("tcp", addr)
require.NoError(t, err)

View File

@ -0,0 +1,45 @@
//go:build !windows
package agentssh
import (
"os"
"github.com/gliderlabs/ssh"
"golang.org/x/sys/unix"
)
func osSignalFrom(sig ssh.Signal) os.Signal {
switch sig {
case ssh.SIGABRT:
return unix.SIGABRT
case ssh.SIGALRM:
return unix.SIGALRM
case ssh.SIGFPE:
return unix.SIGFPE
case ssh.SIGHUP:
return unix.SIGHUP
case ssh.SIGILL:
return unix.SIGILL
case ssh.SIGINT:
return unix.SIGINT
case ssh.SIGKILL:
return unix.SIGKILL
case ssh.SIGPIPE:
return unix.SIGPIPE
case ssh.SIGQUIT:
return unix.SIGQUIT
case ssh.SIGSEGV:
return unix.SIGSEGV
case ssh.SIGTERM:
return unix.SIGTERM
case ssh.SIGUSR1:
return unix.SIGUSR1
case ssh.SIGUSR2:
return unix.SIGUSR2
// Unhandled, use sane fallback.
default:
return unix.SIGKILL
}
}

View File

@ -0,0 +1,15 @@
package agentssh
import (
"os"
"github.com/gliderlabs/ssh"
)
func osSignalFrom(sig ssh.Signal) os.Signal {
switch sig {
// Signals are not supported on Windows.
default:
return os.Kill
}
}

View File

@ -3,6 +3,7 @@ package pty
import (
"io"
"log"
"os"
"github.com/gliderlabs/ssh"
"golang.org/x/xerrors"
@ -69,6 +70,11 @@ type Process interface {
// Kill the command process. Returned error is as for os.Process.Kill()
Kill() error
// Signal sends a signal to the command process. On non-windows systems, the
// returned error is as for os.Process.Signal(), on Windows it's
// as for os.Process.Kill().
Signal(sig os.Signal) error
}
// WithFlags represents a PTY whose flags can be inspected, in particular

View File

@ -170,6 +170,10 @@ func (p *otherProcess) Kill() error {
return p.cmd.Process.Kill()
}
func (p *otherProcess) Signal(sig os.Signal) error {
return p.cmd.Process.Signal(sig)
}
func (p *otherProcess) waitInternal() {
// The GC can garbage collect the TTY FD before the command
// has finished running. See:

View File

@ -243,6 +243,11 @@ func (p *windowsProcess) Kill() error {
return p.proc.Kill()
}
func (p *windowsProcess) Signal(sig os.Signal) error {
// Windows doesn't support signals.
return p.Kill()
}
// killOnContext waits for the context to be done and kills the process, unless it exits on its own first.
func (p *windowsProcess) killOnContext(ctx context.Context) {
select {

View File

@ -3,6 +3,7 @@
package pty_test
import (
"os"
"os/exec"
"testing"
@ -46,6 +47,19 @@ func TestStart(t *testing.T) {
require.NoError(t, err)
})
t.Run("Interrupt", func(t *testing.T) {
t.Parallel()
pty, ps := ptytest.Start(t, pty.Command("sleep", "30"))
err := ps.Signal(os.Interrupt)
assert.NoError(t, err)
err = ps.Wait()
var exitErr *exec.ExitError
require.True(t, xerrors.As(err, &exitErr))
assert.NotEqual(t, 0, exitErr.ExitCode())
err = pty.Close()
require.NoError(t, err)
})
t.Run("SSH_TTY", func(t *testing.T) {
t.Parallel()
opts := pty.WithPTYOption(pty.WithSSHRequest(ssh.Pty{

View File

@ -5,6 +5,7 @@ package pty_test
import (
"fmt"
"os"
"os/exec"
"testing"
@ -51,6 +52,18 @@ func TestStart(t *testing.T) {
err = ptty.Close()
require.NoError(t, err)
})
t.Run("Interrupt", func(t *testing.T) {
t.Parallel()
ptty, ps := ptytest.Start(t, pty.Command("cmd.exe"))
err := ps.Signal(os.Interrupt) // Actually does kill.
assert.NoError(t, err)
err = ps.Wait()
var exitErr *exec.ExitError
require.True(t, xerrors.As(err, &exitErr))
assert.NotEqual(t, 0, exitErr.ExitCode())
err = ptty.Close()
require.NoError(t, err)
})
}
// these constants/vars are used by Test_Start_copy