mirror of https://github.com/coder/coder.git
feat(agent/agentssh): handle session signals (#10842)
This commit is contained in:
parent
a7c27cad26
commit
2c6e0f7d0a
|
@ -311,10 +311,10 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv
|
||||||
if isPty {
|
if isPty {
|
||||||
return s.startPTYSession(logger, session, magicTypeLabel, cmd, sshPty, windowSize)
|
return s.startPTYSession(logger, session, magicTypeLabel, cmd, sshPty, windowSize)
|
||||||
}
|
}
|
||||||
return s.startNonPTYSession(session, magicTypeLabel, cmd.AsExec())
|
return s.startNonPTYSession(logger, session, magicTypeLabel, cmd.AsExec())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
|
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
|
||||||
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
|
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
|
||||||
|
|
||||||
cmd.Stdout = session
|
cmd.Stdout = session
|
||||||
|
@ -338,6 +338,17 @@ func (s *Server) startNonPTYSession(session ssh.Session, magicTypeLabel string,
|
||||||
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1)
|
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1)
|
||||||
return xerrors.Errorf("start: %w", err)
|
return xerrors.Errorf("start: %w", err)
|
||||||
}
|
}
|
||||||
|
sigs := make(chan ssh.Signal, 1)
|
||||||
|
session.Signals(sigs)
|
||||||
|
defer func() {
|
||||||
|
session.Signals(nil)
|
||||||
|
close(sigs)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
for sig := range sigs {
|
||||||
|
s.handleSignal(logger, sig, cmd.Process, magicTypeLabel)
|
||||||
|
}
|
||||||
|
}()
|
||||||
return cmd.Wait()
|
return cmd.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,6 +359,7 @@ type ptySession interface {
|
||||||
Context() ssh.Context
|
Context() ssh.Context
|
||||||
DisablePTYEmulation()
|
DisablePTYEmulation()
|
||||||
RawCommand() string
|
RawCommand() string
|
||||||
|
Signals(chan<- ssh.Signal)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTypeLabel string, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
|
func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTypeLabel string, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
|
||||||
|
@ -403,13 +415,36 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
sigs := make(chan ssh.Signal, 1)
|
||||||
|
session.Signals(sigs)
|
||||||
|
defer func() {
|
||||||
|
session.Signals(nil)
|
||||||
|
close(sigs)
|
||||||
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
for win := range windowSize {
|
for {
|
||||||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
if sigs == nil && windowSize == nil {
|
||||||
// If the pty is closed, then command has exited, no need to log.
|
return
|
||||||
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
|
}
|
||||||
logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
|
|
||||||
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
|
select {
|
||||||
|
case sig, ok := <-sigs:
|
||||||
|
if !ok {
|
||||||
|
sigs = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.handleSignal(logger, sig, process, magicTypeLabel)
|
||||||
|
case win, ok := <-windowSize:
|
||||||
|
if !ok {
|
||||||
|
windowSize = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||||
|
// If the pty is closed, then command has exited, no need to log.
|
||||||
|
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
|
||||||
|
logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
|
||||||
|
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "resize").Add(1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -452,6 +487,18 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) {
|
||||||
|
ctx := context.Background()
|
||||||
|
sig := osSignalFrom(ssig)
|
||||||
|
logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String()))
|
||||||
|
logger.Info(ctx, "received signal from client")
|
||||||
|
err := signaler.Signal(sig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn(ctx, "signaling the process failed", slog.Error(err))
|
||||||
|
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
|
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
|
||||||
s.metrics.sftpConnectionsTotal.Add(1)
|
s.metrics.sftpConnectionsTotal.Add(1)
|
||||||
|
|
||||||
|
|
|
@ -114,6 +114,11 @@ type testSSHContext struct {
|
||||||
context.Context
|
context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ gliderssh.Context = testSSHContext{}
|
||||||
|
_ ptySession = &testSession{}
|
||||||
|
)
|
||||||
|
|
||||||
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
|
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
|
||||||
toClient, fromPty := io.Pipe()
|
toClient, fromPty := io.Pipe()
|
||||||
toPty, fromClient := io.Pipe()
|
toPty, fromClient := io.Pipe()
|
||||||
|
@ -144,6 +149,10 @@ func (s *testSession) Write(p []byte) (n int, err error) {
|
||||||
return s.fromPty.Write(p)
|
return s.fromPty.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*testSession) Signals(_ chan<- gliderssh.Signal) {
|
||||||
|
// Not implemented, but will be called.
|
||||||
|
}
|
||||||
|
|
||||||
func (testSSHContext) Lock() {
|
func (testSSHContext) Lock() {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,10 @@
|
||||||
package agentssh_test
|
package agentssh_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -24,6 +26,7 @@ import (
|
||||||
"github.com/coder/coder/v2/agent/agentssh"
|
"github.com/coder/coder/v2/agent/agentssh"
|
||||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||||
"github.com/coder/coder/v2/pty/ptytest"
|
"github.com/coder/coder/v2/pty/ptytest"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
@ -57,8 +60,8 @@ func TestNewServer_ServeClient(t *testing.T) {
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
sess, err := c.NewSession()
|
sess, err := c.NewSession()
|
||||||
sess.Stdout = &b
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
sess.Stdout = &b
|
||||||
err = sess.Start("echo hello")
|
err = sess.Start("echo hello")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -139,6 +142,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
c := sshClient(t, ln.Addr().String())
|
c := sshClient(t, ln.Addr().String())
|
||||||
sess, err := c.NewSession()
|
sess, err := c.NewSession()
|
||||||
|
assert.NoError(t, err)
|
||||||
sess.Stdin = pty.Input()
|
sess.Stdin = pty.Input()
|
||||||
sess.Stdout = pty.Output()
|
sess.Stdout = pty.Output()
|
||||||
sess.Stderr = pty.Output()
|
sess.Stderr = pty.Output()
|
||||||
|
@ -159,6 +163,147 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewServer_Signal(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("Stdout", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
logger := slogtest.Make(t, nil)
|
||||||
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
// The assumption is that these are set before serving SSH connections.
|
||||||
|
s.AgentToken = func() string { return "" }
|
||||||
|
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
err := s.Serve(ln)
|
||||||
|
assert.Error(t, err) // Server is closed.
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
err := s.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
<-done
|
||||||
|
}()
|
||||||
|
|
||||||
|
c := sshClient(t, ln.Addr().String())
|
||||||
|
|
||||||
|
sess, err := c.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
r, err := sess.StdoutPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Perform multiple sleeps since the interrupt signal doesn't propagate to
|
||||||
|
// the process group, this lets us exit early.
|
||||||
|
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
|
||||||
|
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sc := bufio.NewScanner(r)
|
||||||
|
for sc.Scan() {
|
||||||
|
t.Log(sc.Text())
|
||||||
|
if strings.Contains(sc.Text(), "hello") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NoError(t, sc.Err())
|
||||||
|
|
||||||
|
err = sess.Signal(ssh.SIGINT)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Assumption, signal propagates and the command exists, closing stdout.
|
||||||
|
for sc.Scan() {
|
||||||
|
t.Log(sc.Text())
|
||||||
|
require.NotContains(t, sc.Text(), "bye")
|
||||||
|
}
|
||||||
|
require.NoError(t, sc.Err())
|
||||||
|
|
||||||
|
err = sess.Wait()
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("PTY", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
logger := slogtest.Make(t, nil)
|
||||||
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
// The assumption is that these are set before serving SSH connections.
|
||||||
|
s.AgentToken = func() string { return "" }
|
||||||
|
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
err := s.Serve(ln)
|
||||||
|
assert.Error(t, err) // Server is closed.
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
err := s.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
<-done
|
||||||
|
}()
|
||||||
|
|
||||||
|
c := sshClient(t, ln.Addr().String())
|
||||||
|
|
||||||
|
pty := ptytest.New(t)
|
||||||
|
|
||||||
|
sess, err := c.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
r, err := sess.StdoutPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Note, we request pty but don't use ptytest here because we can't
|
||||||
|
// easily test for no text before EOF.
|
||||||
|
sess.Stdin = pty.Input()
|
||||||
|
sess.Stderr = pty.Output()
|
||||||
|
|
||||||
|
err = sess.RequestPty("xterm", 80, 80, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Perform multiple sleeps since the interrupt signal doesn't propagate to
|
||||||
|
// the process group, this lets us exit early.
|
||||||
|
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
|
||||||
|
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sc := bufio.NewScanner(r)
|
||||||
|
for sc.Scan() {
|
||||||
|
t.Log(sc.Text())
|
||||||
|
if strings.Contains(sc.Text(), "hello") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NoError(t, sc.Err())
|
||||||
|
|
||||||
|
err = sess.Signal(ssh.SIGINT)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Assumption, signal propagates and the command exists, closing stdout.
|
||||||
|
for sc.Scan() {
|
||||||
|
t.Log(sc.Text())
|
||||||
|
require.NotContains(t, sc.Text(), "bye")
|
||||||
|
}
|
||||||
|
require.NoError(t, sc.Err())
|
||||||
|
|
||||||
|
err = sess.Wait()
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func sshClient(t *testing.T, addr string) *ssh.Client {
|
func sshClient(t *testing.T, addr string) *ssh.Client {
|
||||||
conn, err := net.Dial("tcp", addr)
|
conn, err := net.Dial("tcp", addr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package agentssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func osSignalFrom(sig ssh.Signal) os.Signal {
|
||||||
|
switch sig {
|
||||||
|
case ssh.SIGABRT:
|
||||||
|
return unix.SIGABRT
|
||||||
|
case ssh.SIGALRM:
|
||||||
|
return unix.SIGALRM
|
||||||
|
case ssh.SIGFPE:
|
||||||
|
return unix.SIGFPE
|
||||||
|
case ssh.SIGHUP:
|
||||||
|
return unix.SIGHUP
|
||||||
|
case ssh.SIGILL:
|
||||||
|
return unix.SIGILL
|
||||||
|
case ssh.SIGINT:
|
||||||
|
return unix.SIGINT
|
||||||
|
case ssh.SIGKILL:
|
||||||
|
return unix.SIGKILL
|
||||||
|
case ssh.SIGPIPE:
|
||||||
|
return unix.SIGPIPE
|
||||||
|
case ssh.SIGQUIT:
|
||||||
|
return unix.SIGQUIT
|
||||||
|
case ssh.SIGSEGV:
|
||||||
|
return unix.SIGSEGV
|
||||||
|
case ssh.SIGTERM:
|
||||||
|
return unix.SIGTERM
|
||||||
|
case ssh.SIGUSR1:
|
||||||
|
return unix.SIGUSR1
|
||||||
|
case ssh.SIGUSR2:
|
||||||
|
return unix.SIGUSR2
|
||||||
|
|
||||||
|
// Unhandled, use sane fallback.
|
||||||
|
default:
|
||||||
|
return unix.SIGKILL
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package agentssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func osSignalFrom(sig ssh.Signal) os.Signal {
|
||||||
|
switch sig {
|
||||||
|
// Signals are not supported on Windows.
|
||||||
|
default:
|
||||||
|
return os.Kill
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,6 +3,7 @@ package pty
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
@ -69,6 +70,11 @@ type Process interface {
|
||||||
|
|
||||||
// Kill the command process. Returned error is as for os.Process.Kill()
|
// Kill the command process. Returned error is as for os.Process.Kill()
|
||||||
Kill() error
|
Kill() error
|
||||||
|
|
||||||
|
// Signal sends a signal to the command process. On non-windows systems, the
|
||||||
|
// returned error is as for os.Process.Signal(), on Windows it's
|
||||||
|
// as for os.Process.Kill().
|
||||||
|
Signal(sig os.Signal) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithFlags represents a PTY whose flags can be inspected, in particular
|
// WithFlags represents a PTY whose flags can be inspected, in particular
|
||||||
|
|
|
@ -170,6 +170,10 @@ func (p *otherProcess) Kill() error {
|
||||||
return p.cmd.Process.Kill()
|
return p.cmd.Process.Kill()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *otherProcess) Signal(sig os.Signal) error {
|
||||||
|
return p.cmd.Process.Signal(sig)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *otherProcess) waitInternal() {
|
func (p *otherProcess) waitInternal() {
|
||||||
// The GC can garbage collect the TTY FD before the command
|
// The GC can garbage collect the TTY FD before the command
|
||||||
// has finished running. See:
|
// has finished running. See:
|
||||||
|
|
|
@ -243,6 +243,11 @@ func (p *windowsProcess) Kill() error {
|
||||||
return p.proc.Kill()
|
return p.proc.Kill()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *windowsProcess) Signal(sig os.Signal) error {
|
||||||
|
// Windows doesn't support signals.
|
||||||
|
return p.Kill()
|
||||||
|
}
|
||||||
|
|
||||||
// killOnContext waits for the context to be done and kills the process, unless it exits on its own first.
|
// killOnContext waits for the context to be done and kills the process, unless it exits on its own first.
|
||||||
func (p *windowsProcess) killOnContext(ctx context.Context) {
|
func (p *windowsProcess) killOnContext(ctx context.Context) {
|
||||||
select {
|
select {
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
package pty_test
|
package pty_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -46,6 +47,19 @@ func TestStart(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Interrupt", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
pty, ps := ptytest.Start(t, pty.Command("sleep", "30"))
|
||||||
|
err := ps.Signal(os.Interrupt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = ps.Wait()
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.True(t, xerrors.As(err, &exitErr))
|
||||||
|
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||||
|
err = pty.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("SSH_TTY", func(t *testing.T) {
|
t.Run("SSH_TTY", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
opts := pty.WithPTYOption(pty.WithSSHRequest(ssh.Pty{
|
opts := pty.WithPTYOption(pty.WithSSHRequest(ssh.Pty{
|
||||||
|
|
|
@ -5,6 +5,7 @@ package pty_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -51,6 +52,18 @@ func TestStart(t *testing.T) {
|
||||||
err = ptty.Close()
|
err = ptty.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
t.Run("Interrupt", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ptty, ps := ptytest.Start(t, pty.Command("cmd.exe"))
|
||||||
|
err := ps.Signal(os.Interrupt) // Actually does kill.
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = ps.Wait()
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.True(t, xerrors.As(err, &exitErr))
|
||||||
|
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||||
|
err = ptty.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// these constants/vars are used by Test_Start_copy
|
// these constants/vars are used by Test_Start_copy
|
||||||
|
|
Loading…
Reference in New Issue