mirror of https://github.com/coder/coder.git
feat(agent/agentssh): handle session signals (#10842)
This commit is contained in:
parent
a7c27cad26
commit
2c6e0f7d0a
|
@ -311,10 +311,10 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
|
|||
if isPty {
|
||||
return s.startPTYSession(logger, session, magicTypeLabel, cmd, sshPty, windowSize)
|
||||
}
|
||||
return s.startNonPTYSession(session, magicTypeLabel, cmd.AsExec())
|
||||
return s.startNonPTYSession(logger, session, magicTypeLabel, cmd.AsExec())
|
||||
}
|
||||
|
||||
func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
|
||||
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
|
||||
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
|
||||
|
||||
cmd.Stdout = session
|
||||
|
@ -338,6 +338,17 @@ func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string,
|
|||
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1)
|
||||
return xerrors.Errorf("start: %w", err)
|
||||
}
|
||||
sigs := make(chan ssh.Signal, 1)
|
||||
session.Signals(sigs)
|
||||
defer func() {
|
||||
session.Signals(nil)
|
||||
close(sigs)
|
||||
}()
|
||||
go func() {
|
||||
for sig := range sigs {
|
||||
s.handleSignal(logger, sig, cmd.Process, magicTypeLabel)
|
||||
}
|
||||
}()
|
||||
return cmd.Wait()
|
||||
}
|
||||
|
||||
|
@ -348,6 +359,7 @@ type ptySession interface {
|
|||
Context() ssh.Context
|
||||
DisablePTYEmulation()
|
||||
RawCommand() string
|
||||
Signals(chan<- ssh.Signal)
|
||||
}
|
||||
|
||||
func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTypeLabel string, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
|
||||
|
@ -403,8 +415,30 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
|
|||
}
|
||||
}
|
||||
}()
|
||||
sigs := make(chan ssh.Signal, 1)
|
||||
session.Signals(sigs)
|
||||
defer func() {
|
||||
session.Signals(nil)
|
||||
close(sigs)
|
||||
}()
|
||||
go func() {
|
||||
for win := range windowSize {
|
||||
for {
|
||||
if sigs == nil && windowSize == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sig, ok := <-sigs:
|
||||
if !ok {
|
||||
sigs = nil
|
||||
continue
|
||||
}
|
||||
s.handleSignal(logger, sig, process, magicTypeLabel)
|
||||
case win, ok := <-windowSize:
|
||||
if !ok {
|
||||
windowSize = nil
|
||||
continue
|
||||
}
|
||||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||
// If the pty is closed, then command has exited, no need to log.
|
||||
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
|
||||
|
@ -412,6 +446,7 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
|
|||
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
@ -452,6 +487,18 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) {
|
||||
ctx := context.Background()
|
||||
sig := osSignalFrom(ssig)
|
||||
logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String()))
|
||||
logger.Info(ctx, "received signal from client")
|
||||
err := signaler.Signal(sig)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "signaling the process failed", slog.Error(err))
|
||||
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
|
||||
s.metrics.sftpConnectionsTotal.Add(1)
|
||||
|
||||
|
|
|
@ -114,6 +114,11 @@ type testSSHContext struct {
|
|||
context.Context
|
||||
}
|
||||
|
||||
var (
|
||||
_ gliderssh.Context = testSSHContext{}
|
||||
_ ptySession = &testSession{}
|
||||
)
|
||||
|
||||
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
|
||||
toClient, fromPty := io.Pipe()
|
||||
toPty, fromClient := io.Pipe()
|
||||
|
@ -144,6 +149,10 @@ func (s *testSession) Write(p []byte) (n int, err error) {
|
|||
return s.fromPty.Write(p)
|
||||
}
|
||||
|
||||
func (*testSession) Signals(_ chan<- gliderssh.Signal) {
|
||||
// Not implemented, but will be called.
|
||||
}
|
||||
|
||||
func (testSSHContext) Lock() {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
|
|
@ -3,8 +3,10 @@
|
|||
package agentssh_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
@ -24,6 +26,7 @@ import (
|
|||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -57,8 +60,8 @@ func TestNewServer_ServeClient(t *testing.T) {
|
|||
|
||||
var b bytes.Buffer
|
||||
sess, err := c.NewSession()
|
||||
sess.Stdout = &b
|
||||
require.NoError(t, err)
|
||||
sess.Stdout = &b
|
||||
err = sess.Start("echo hello")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -139,6 +142,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
|
|||
defer wg.Done()
|
||||
c := sshClient(t, ln.Addr().String())
|
||||
sess, err := c.NewSession()
|
||||
assert.NoError(t, err)
|
||||
sess.Stdin = pty.Input()
|
||||
sess.Stdout = pty.Output()
|
||||
sess.Stderr = pty.Output()
|
||||
|
@ -159,6 +163,147 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestNewServer_Signal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Stdout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// The assumption is that these are set before serving SSH connections.
|
||||
s.AgentToken = func() string { return "" }
|
||||
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
err := s.Serve(ln)
|
||||
assert.Error(t, err) // Server is closed.
|
||||
}()
|
||||
defer func() {
|
||||
err := s.Close()
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
}()
|
||||
|
||||
c := sshClient(t, ln.Addr().String())
|
||||
|
||||
sess, err := c.NewSession()
|
||||
require.NoError(t, err)
|
||||
r, err := sess.StdoutPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Perform multiple sleeps since the interrupt signal doesn't propagate to
|
||||
// the process group, this lets us exit early.
|
||||
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
|
||||
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
|
||||
require.NoError(t, err)
|
||||
|
||||
sc := bufio.NewScanner(r)
|
||||
for sc.Scan() {
|
||||
t.Log(sc.Text())
|
||||
if strings.Contains(sc.Text(), "hello") {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NoError(t, sc.Err())
|
||||
|
||||
err = sess.Signal(ssh.SIGINT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assumption, signal propagates and the command exists, closing stdout.
|
||||
for sc.Scan() {
|
||||
t.Log(sc.Text())
|
||||
require.NotContains(t, sc.Text(), "bye")
|
||||
}
|
||||
require.NoError(t, sc.Err())
|
||||
|
||||
err = sess.Wait()
|
||||
require.Error(t, err)
|
||||
})
|
||||
t.Run("PTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// The assumption is that these are set before serving SSH connections.
|
||||
s.AgentToken = func() string { return "" }
|
||||
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
err := s.Serve(ln)
|
||||
assert.Error(t, err) // Server is closed.
|
||||
}()
|
||||
defer func() {
|
||||
err := s.Close()
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
}()
|
||||
|
||||
c := sshClient(t, ln.Addr().String())
|
||||
|
||||
pty := ptytest.New(t)
|
||||
|
||||
sess, err := c.NewSession()
|
||||
require.NoError(t, err)
|
||||
r, err := sess.StdoutPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note, we request pty but don't use ptytest here because we can't
|
||||
// easily test for no text before EOF.
|
||||
sess.Stdin = pty.Input()
|
||||
sess.Stderr = pty.Output()
|
||||
|
||||
err = sess.RequestPty("xterm", 80, 80, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Perform multiple sleeps since the interrupt signal doesn't propagate to
|
||||
// the process group, this lets us exit early.
|
||||
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
|
||||
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
|
||||
require.NoError(t, err)
|
||||
|
||||
sc := bufio.NewScanner(r)
|
||||
for sc.Scan() {
|
||||
t.Log(sc.Text())
|
||||
if strings.Contains(sc.Text(), "hello") {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NoError(t, sc.Err())
|
||||
|
||||
err = sess.Signal(ssh.SIGINT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assumption, signal propagates and the command exists, closing stdout.
|
||||
for sc.Scan() {
|
||||
t.Log(sc.Text())
|
||||
require.NotContains(t, sc.Text(), "bye")
|
||||
}
|
||||
require.NoError(t, sc.Err())
|
||||
|
||||
err = sess.Wait()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func sshClient(t *testing.T, addr string) *ssh.Client {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
//go:build !windows
|
||||
|
||||
package agentssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func osSignalFrom(sig ssh.Signal) os.Signal {
|
||||
switch sig {
|
||||
case ssh.SIGABRT:
|
||||
return unix.SIGABRT
|
||||
case ssh.SIGALRM:
|
||||
return unix.SIGALRM
|
||||
case ssh.SIGFPE:
|
||||
return unix.SIGFPE
|
||||
case ssh.SIGHUP:
|
||||
return unix.SIGHUP
|
||||
case ssh.SIGILL:
|
||||
return unix.SIGILL
|
||||
case ssh.SIGINT:
|
||||
return unix.SIGINT
|
||||
case ssh.SIGKILL:
|
||||
return unix.SIGKILL
|
||||
case ssh.SIGPIPE:
|
||||
return unix.SIGPIPE
|
||||
case ssh.SIGQUIT:
|
||||
return unix.SIGQUIT
|
||||
case ssh.SIGSEGV:
|
||||
return unix.SIGSEGV
|
||||
case ssh.SIGTERM:
|
||||
return unix.SIGTERM
|
||||
case ssh.SIGUSR1:
|
||||
return unix.SIGUSR1
|
||||
case ssh.SIGUSR2:
|
||||
return unix.SIGUSR2
|
||||
|
||||
// Unhandled, use sane fallback.
|
||||
default:
|
||||
return unix.SIGKILL
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package agentssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
)
|
||||
|
||||
func osSignalFrom(sig ssh.Signal) os.Signal {
|
||||
switch sig {
|
||||
// Signals are not supported on Windows.
|
||||
default:
|
||||
return os.Kill
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ package pty
|
|||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -69,6 +70,11 @@ type Process interface {
|
|||
|
||||
// Kill the command process. Returned error is as for os.Process.Kill()
|
||||
Kill() error
|
||||
|
||||
// Signal sends a signal to the command process. On non-windows systems, the
|
||||
// returned error is as for os.Process.Signal(), on Windows it's
|
||||
// as for os.Process.Kill().
|
||||
Signal(sig os.Signal) error
|
||||
}
|
||||
|
||||
// WithFlags represents a PTY whose flags can be inspected, in particular
|
||||
|
|
|
@ -170,6 +170,10 @@ func (p *otherProcess) Kill() error {
|
|||
return p.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
func (p *otherProcess) Signal(sig os.Signal) error {
|
||||
return p.cmd.Process.Signal(sig)
|
||||
}
|
||||
|
||||
func (p *otherProcess) waitInternal() {
|
||||
// The GC can garbage collect the TTY FD before the command
|
||||
// has finished running. See:
|
||||
|
|
|
@ -243,6 +243,11 @@ func (p *windowsProcess) Kill() error {
|
|||
return p.proc.Kill()
|
||||
}
|
||||
|
||||
func (p *windowsProcess) Signal(sig os.Signal) error {
|
||||
// Windows doesn't support signals.
|
||||
return p.Kill()
|
||||
}
|
||||
|
||||
// killOnContext waits for the context to be done and kills the process, unless it exits on its own first.
|
||||
func (p *windowsProcess) killOnContext(ctx context.Context) {
|
||||
select {
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package pty_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
|
@ -46,6 +47,19 @@ func TestStart(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Interrupt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pty, ps := ptytest.Start(t, pty.Command("sleep", "30"))
|
||||
err := ps.Signal(os.Interrupt)
|
||||
assert.NoError(t, err)
|
||||
err = ps.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, xerrors.As(err, &exitErr))
|
||||
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||
err = pty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SSH_TTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := pty.WithPTYOption(pty.WithSSHRequest(ssh.Pty{
|
||||
|
|
|
@ -5,6 +5,7 @@ package pty_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
|
@ -51,6 +52,18 @@ func TestStart(t *testing.T) {
|
|||
err = ptty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("Interrupt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ptty, ps := ptytest.Start(t, pty.Command("cmd.exe"))
|
||||
err := ps.Signal(os.Interrupt) // Actually does kill.
|
||||
assert.NoError(t, err)
|
||||
err = ps.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, xerrors.As(err, &exitErr))
|
||||
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||
err = ptty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// these constants/vars are used by Test_Start_copy
|
||||
|
|
Loading…
Reference in New Issue