mirror of https://github.com/coder/coder.git
Return proper exit code on ssh with TTY (#3192)
* Return proper exit code on ssh with TTY Signed-off-by: Spike Curtis <spike@coder.com> * Fix revive lint Signed-off-by: Spike Curtis <spike@coder.com> * Fix Windows exit code for missing command Signed-off-by: Spike Curtis <spike@coder.com> * Fix close error handling on agent TTY Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent
a37e61a099
commit
36ffdce065
|
@ -43,6 +43,11 @@ const (
|
|||
ProtocolReconnectingPTY = "reconnecting-pty"
|
||||
ProtocolSSH = "ssh"
|
||||
ProtocolDial = "dial"
|
||||
|
||||
// MagicSessionErrorCode indicates that something went wrong with the session, rather than the
|
||||
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
|
||||
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
|
||||
MagicSessionErrorCode = 229
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
|
@ -273,9 +278,17 @@ func (a *agent) init(ctx context.Context) {
|
|||
},
|
||||
Handler: func(session ssh.Session) {
|
||||
err := a.handleSSHSession(session)
|
||||
var exitError *exec.ExitError
|
||||
if xerrors.As(err, &exitError) {
|
||||
a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
|
||||
_ = session.Exit(exitError.ExitCode())
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "ssh session failed", slog.Error(err))
|
||||
_ = session.Exit(1)
|
||||
// This exit code is designed to be unlikely to be confused for a legit exit code
|
||||
// from the process.
|
||||
_ = session.Exit(MagicSessionErrorCode)
|
||||
return
|
||||
}
|
||||
},
|
||||
|
@ -403,7 +416,7 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
|
|||
return cmd, nil
|
||||
}
|
||||
|
||||
func (a *agent) handleSSHSession(session ssh.Session) error {
|
||||
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
||||
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -426,14 +439,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
|
|||
if err != nil {
|
||||
return xerrors.Errorf("start command: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := ptty.Close()
|
||||
if closeErr != nil {
|
||||
a.logger.Warn(context.Background(), "failed to close tty",
|
||||
slog.Error(closeErr))
|
||||
if retErr == nil {
|
||||
retErr = closeErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resize ptty: %w", err)
|
||||
}
|
||||
go func() {
|
||||
for win := range windowSize {
|
||||
err = ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||
if err != nil {
|
||||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||
if resizeErr != nil {
|
||||
a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
@ -444,9 +467,15 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
|
|||
go func() {
|
||||
_, _ = io.Copy(session, ptty.Output())
|
||||
}()
|
||||
_, _ = process.Wait()
|
||||
_ = ptty.Close()
|
||||
return nil
|
||||
err = process.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
|
||||
// and not something to be concerned about. But, if it's something else, we should log it.
|
||||
if err != nil && !xerrors.As(err, &exitErr) {
|
||||
a.logger.Warn(context.Background(), "wait error",
|
||||
slog.Error(err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Stdout = session
|
||||
|
@ -549,7 +578,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
|
|||
go func() {
|
||||
// If the process dies randomly, we should
|
||||
// close the pty.
|
||||
_, _ = process.Wait()
|
||||
_ = process.Wait()
|
||||
rpty.Close()
|
||||
}()
|
||||
go func() {
|
||||
|
|
|
@ -16,6 +16,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
scp "github.com/bramvdbogaerde/go-scp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/udp"
|
||||
|
@ -69,7 +71,7 @@ func TestAgent(t *testing.T) {
|
|||
require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --"))
|
||||
})
|
||||
|
||||
t.Run("SessionTTY", func(t *testing.T) {
|
||||
t.Run("SessionTTYShell", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
// This might be our implementation, or ConPTY itself.
|
||||
|
@ -103,6 +105,29 @@ func TestAgent(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SessionTTYExitCode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
session := setupSSHSession(t, agent.Metadata{})
|
||||
command := "areallynotrealcommand"
|
||||
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||||
require.NoError(t, err)
|
||||
ptty := ptytest.New(t)
|
||||
require.NoError(t, err)
|
||||
session.Stdout = ptty.Output()
|
||||
session.Stderr = ptty.Output()
|
||||
session.Stdin = ptty.Input()
|
||||
err = session.Start(command)
|
||||
require.NoError(t, err)
|
||||
err = session.Wait()
|
||||
exitErr := &ssh.ExitError{}
|
||||
require.True(t, xerrors.As(err, &exitErr))
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Equal(t, 1, exitErr.ExitStatus())
|
||||
} else {
|
||||
assert.Equal(t, 127, exitErr.ExitStatus())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LocalForwarding", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
random, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
|
|
10
pty/pty.go
10
pty/pty.go
|
@ -29,6 +29,16 @@ type PTY interface {
|
|||
Resize(height uint16, width uint16) error
|
||||
}
|
||||
|
||||
// Process represents a process running in a PTY
|
||||
type Process interface {
|
||||
|
||||
// Wait for the command to complete. Returned error is as for exec.Cmd.Wait()
|
||||
Wait() error
|
||||
|
||||
// Kill the command process. Returned error is as for os.Process.Kill()
|
||||
Kill() error
|
||||
}
|
||||
|
||||
// WithFlags represents a PTY whose flags can be inspected, in particular
|
||||
// to determine whether local echo is enabled.
|
||||
type WithFlags interface {
|
||||
|
|
|
@ -5,6 +5,8 @@ package pty
|
|||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/creack/pty"
|
||||
|
@ -27,6 +29,15 @@ type otherPty struct {
|
|||
pty, tty *os.File
|
||||
}
|
||||
|
||||
type otherProcess struct {
|
||||
pty *os.File
|
||||
cmd *exec.Cmd
|
||||
|
||||
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
|
||||
cmdDone chan any
|
||||
cmdErr error
|
||||
}
|
||||
|
||||
func (p *otherPty) Input() ReadWriter {
|
||||
return ReadWriter{
|
||||
Reader: p.tty,
|
||||
|
@ -66,3 +77,21 @@ func (p *otherPty) Close() error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *otherProcess) Wait() error {
|
||||
<-p.cmdDone
|
||||
return p.cmdErr
|
||||
}
|
||||
|
||||
func (p *otherProcess) Kill() error {
|
||||
return p.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
func (p *otherProcess) waitInternal() {
|
||||
// The GC can garbage collect the TTY FD before the command
|
||||
// has finished running. See:
|
||||
// https://github.com/creack/pty/issues/127#issuecomment-932764012
|
||||
p.cmdErr = p.cmd.Wait()
|
||||
runtime.KeepAlive(p.pty)
|
||||
close(p.cmdDone)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ package pty
|
|||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
|
@ -66,6 +67,13 @@ type ptyWindows struct {
|
|||
closed bool
|
||||
}
|
||||
|
||||
type windowsProcess struct {
|
||||
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
|
||||
cmdDone chan any
|
||||
cmdErr error
|
||||
proc *os.Process
|
||||
}
|
||||
|
||||
func (p *ptyWindows) Output() ReadWriter {
|
||||
return ReadWriter{
|
||||
Reader: p.outputRead,
|
||||
|
@ -111,3 +119,25 @@ func (p *ptyWindows) Close() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *windowsProcess) waitInternal() {
|
||||
defer close(p.cmdDone)
|
||||
state, err := p.proc.Wait()
|
||||
if err != nil {
|
||||
p.cmdErr = err
|
||||
return
|
||||
}
|
||||
if !state.Success() {
|
||||
p.cmdErr = &exec.ExitError{ProcessState: state}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *windowsProcess) Wait() error {
|
||||
<-p.cmdDone
|
||||
return p.cmdErr
|
||||
}
|
||||
|
||||
func (p *windowsProcess) Kill() error {
|
||||
return p.proc.Kill()
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
@ -27,7 +26,7 @@ func New(t *testing.T) *PTY {
|
|||
return create(t, ptty, "cmd")
|
||||
}
|
||||
|
||||
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
|
||||
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, pty.Process) {
|
||||
ptty, ps, err := pty.Start(cmd)
|
||||
require.NoError(t, err)
|
||||
return create(t, ptty, cmd.Args[0]), ps
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package pty
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func Start(cmd *exec.Cmd) (PTY, *os.Process, error) {
|
||||
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
|
||||
// instead rely on the returned Process to manage the command/process.
|
||||
func Start(cmd *exec.Cmd) (PTY, Process, error) {
|
||||
return startPty(cmd)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
package pty
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
@ -14,7 +13,7 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
|
||||
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
|
||||
ptty, tty, err := pty.Open()
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("open: %w", err)
|
||||
|
@ -37,16 +36,15 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
|
|||
}
|
||||
return nil, nil, xerrors.Errorf("start: %w", err)
|
||||
}
|
||||
go func() {
|
||||
// The GC can garbage collect the TTY FD before the command
|
||||
// has finished running. See:
|
||||
// https://github.com/creack/pty/issues/127#issuecomment-932764012
|
||||
_ = cmd.Wait()
|
||||
runtime.KeepAlive(ptty)
|
||||
}()
|
||||
oPty := &otherPty{
|
||||
pty: ptty,
|
||||
tty: tty,
|
||||
}
|
||||
return oPty, cmd.Process, nil
|
||||
oProcess := &otherProcess{
|
||||
pty: ptty,
|
||||
cmd: cmd,
|
||||
cmdDone: make(chan any),
|
||||
}
|
||||
go oProcess.waitInternal()
|
||||
return oPty, oProcess, nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,10 @@ import (
|
|||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
|
@ -20,7 +24,20 @@ func TestStart(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Echo", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pty, _ := ptytest.Start(t, exec.Command("echo", "test"))
|
||||
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
|
||||
pty.ExpectMatch("test")
|
||||
err := ps.Wait()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Kill", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ps := ptytest.Start(t, exec.Command("sleep", "30"))
|
||||
err := ps.Kill()
|
||||
assert.NoError(t, err)
|
||||
err = ps.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, xerrors.As(err, &exitErr))
|
||||
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
|
||||
// Allocates a PTY and starts the specified command attached to it.
|
||||
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
|
||||
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
|
||||
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
|
||||
fullPath, err := exec.LookPath(cmd.Path)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -83,7 +83,12 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
|
|||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err)
|
||||
}
|
||||
return pty, process, nil
|
||||
wp := &windowsProcess{
|
||||
cmdDone: make(chan any),
|
||||
proc: process,
|
||||
}
|
||||
go wp.waitInternal()
|
||||
return pty, wp, nil
|
||||
}
|
||||
|
||||
// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476
|
||||
|
|
|
@ -8,8 +8,10 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -20,8 +22,10 @@ func TestStart(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Echo", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pty, _ := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
|
||||
pty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
|
||||
pty.ExpectMatch("test")
|
||||
err := ps.Wait()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("Resize", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
@ -29,4 +33,14 @@ func TestStart(t *testing.T) {
|
|||
err := pty.Resize(100, 50)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("Kill", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ps := ptytest.Start(t, exec.Command("cmd.exe"))
|
||||
err := ps.Kill()
|
||||
assert.NoError(t, err)
|
||||
err = ps.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, xerrors.As(err, &exitErr))
|
||||
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue