Return proper exit code on ssh with TTY (#3192)

* Return proper exit code on ssh with TTY

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix revive lint

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix Windows exit code for missing command

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix close error handling on agent TTY

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis 2022-07-27 12:23:28 -07:00 committed by GitHub
parent a37e61a099
commit 36ffdce065
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 184 additions and 27 deletions

View File

@ -43,6 +43,11 @@ const (
ProtocolReconnectingPTY = "reconnecting-pty"
ProtocolSSH = "ssh"
ProtocolDial = "dial"
// MagicSessionErrorCode indicates that something went wrong with the session, rather than the
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
MagicSessionErrorCode = 229
)
type Options struct {
@ -273,9 +278,17 @@ func (a *agent) init(ctx context.Context) {
},
Handler: func(session ssh.Session) {
err := a.handleSSHSession(session)
var exitError *exec.ExitError
if xerrors.As(err, &exitError) {
a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
_ = session.Exit(exitError.ExitCode())
return
}
if err != nil {
a.logger.Warn(ctx, "ssh session failed", slog.Error(err))
_ = session.Exit(1)
// This exit code is designed to be unlikely to be confused for a legit exit code
// from the process.
_ = session.Exit(MagicSessionErrorCode)
return
}
},
@ -403,7 +416,7 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
return cmd, nil
}
func (a *agent) handleSSHSession(session ssh.Session) error {
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
if err != nil {
return err
@ -426,14 +439,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
defer func() {
closeErr := ptty.Close()
if closeErr != nil {
a.logger.Warn(context.Background(), "failed to close tty",
slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}
}()
err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width))
if err != nil {
return xerrors.Errorf("resize ptty: %w", err)
}
go func() {
for win := range windowSize {
err = ptty.Resize(uint16(win.Height), uint16(win.Width))
if err != nil {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
if resizeErr != nil {
a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(err))
}
}
@ -444,9 +467,15 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
go func() {
_, _ = io.Copy(session, ptty.Output())
}()
_, _ = process.Wait()
_ = ptty.Close()
return nil
err = process.Wait()
var exitErr *exec.ExitError
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
// and not something to be concerned about. But, if it's something else, we should log it.
if err != nil && !xerrors.As(err, &exitErr) {
a.logger.Warn(context.Background(), "wait error",
slog.Error(err))
}
return err
}
cmd.Stdout = session
@ -549,7 +578,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
go func() {
// If the process dies randomly, we should
// close the pty.
_, _ = process.Wait()
_ = process.Wait()
rpty.Close()
}()
go func() {

View File

@ -16,6 +16,8 @@ import (
"testing"
"time"
"golang.org/x/xerrors"
scp "github.com/bramvdbogaerde/go-scp"
"github.com/google/uuid"
"github.com/pion/udp"
@ -69,7 +71,7 @@ func TestAgent(t *testing.T) {
require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --"))
})
t.Run("SessionTTY", func(t *testing.T) {
t.Run("SessionTTYShell", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This might be our implementation, or ConPTY itself.
@ -103,6 +105,29 @@ func TestAgent(t *testing.T) {
require.NoError(t, err)
})
t.Run("SessionTTYExitCode", func(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agent.Metadata{})
command := "areallynotrealcommand"
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)
ptty := ptytest.New(t)
require.NoError(t, err)
session.Stdout = ptty.Output()
session.Stderr = ptty.Output()
session.Stdin = ptty.Input()
err = session.Start(command)
require.NoError(t, err)
err = session.Wait()
exitErr := &ssh.ExitError{}
require.True(t, xerrors.As(err, &exitErr))
if runtime.GOOS == "windows" {
assert.Equal(t, 1, exitErr.ExitStatus())
} else {
assert.Equal(t, 127, exitErr.ExitStatus())
}
})
t.Run("LocalForwarding", func(t *testing.T) {
t.Parallel()
random, err := net.Listen("tcp", "127.0.0.1:0")

View File

@ -29,6 +29,16 @@ type PTY interface {
Resize(height uint16, width uint16) error
}
// Process represents a process running in a PTY
type Process interface {
// Wait for the command to complete. Returned error is as for exec.Cmd.Wait()
Wait() error
// Kill the command process. Returned error is as for os.Process.Kill()
Kill() error
}
// WithFlags represents a PTY whose flags can be inspected, in particular
// to determine whether local echo is enabled.
type WithFlags interface {

View File

@ -5,6 +5,8 @@ package pty
import (
"os"
"os/exec"
"runtime"
"sync"
"github.com/creack/pty"
@ -27,6 +29,15 @@ type otherPty struct {
pty, tty *os.File
}
type otherProcess struct {
pty *os.File
cmd *exec.Cmd
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
cmdDone chan any
cmdErr error
}
func (p *otherPty) Input() ReadWriter {
return ReadWriter{
Reader: p.tty,
@ -66,3 +77,21 @@ func (p *otherPty) Close() error {
}
return nil
}
func (p *otherProcess) Wait() error {
<-p.cmdDone
return p.cmdErr
}
func (p *otherProcess) Kill() error {
return p.cmd.Process.Kill()
}
func (p *otherProcess) waitInternal() {
// The GC can garbage collect the TTY FD before the command
// has finished running. See:
// https://github.com/creack/pty/issues/127#issuecomment-932764012
p.cmdErr = p.cmd.Wait()
runtime.KeepAlive(p.pty)
close(p.cmdDone)
}

View File

@ -5,6 +5,7 @@ package pty
import (
"os"
"os/exec"
"sync"
"unsafe"
@ -66,6 +67,13 @@ type ptyWindows struct {
closed bool
}
type windowsProcess struct {
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
cmdDone chan any
cmdErr error
proc *os.Process
}
func (p *ptyWindows) Output() ReadWriter {
return ReadWriter{
Reader: p.outputRead,
@ -111,3 +119,25 @@ func (p *ptyWindows) Close() error {
return nil
}
func (p *windowsProcess) waitInternal() {
defer close(p.cmdDone)
state, err := p.proc.Wait()
if err != nil {
p.cmdErr = err
return
}
if !state.Success() {
p.cmdErr = &exec.ExitError{ProcessState: state}
return
}
}
func (p *windowsProcess) Wait() error {
<-p.cmdDone
return p.cmdErr
}
func (p *windowsProcess) Kill() error {
return p.proc.Kill()
}

View File

@ -5,7 +5,6 @@ import (
"bytes"
"context"
"io"
"os"
"os/exec"
"runtime"
"strings"
@ -27,7 +26,7 @@ func New(t *testing.T) *PTY {
return create(t, ptty, "cmd")
}
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, pty.Process) {
ptty, ps, err := pty.Start(cmd)
require.NoError(t, err)
return create(t, ptty, cmd.Args[0]), ps

View File

@ -1,10 +1,11 @@
package pty
import (
"os"
"os/exec"
)
func Start(cmd *exec.Cmd) (PTY, *os.Process, error) {
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
// instead rely on the returned Process to manage the command/process.
func Start(cmd *exec.Cmd) (PTY, Process, error) {
return startPty(cmd)
}

View File

@ -4,7 +4,6 @@
package pty
import (
"os"
"os/exec"
"runtime"
"strings"
@ -14,7 +13,7 @@ import (
"golang.org/x/xerrors"
)
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
ptty, tty, err := pty.Open()
if err != nil {
return nil, nil, xerrors.Errorf("open: %w", err)
@ -37,16 +36,15 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
}
return nil, nil, xerrors.Errorf("start: %w", err)
}
go func() {
// The GC can garbage collect the TTY FD before the command
// has finished running. See:
// https://github.com/creack/pty/issues/127#issuecomment-932764012
_ = cmd.Wait()
runtime.KeepAlive(ptty)
}()
oPty := &otherPty{
pty: ptty,
tty: tty,
}
return oPty, cmd.Process, nil
oProcess := &otherProcess{
pty: ptty,
cmd: cmd,
cmdDone: make(chan any),
}
go oProcess.waitInternal()
return oPty, oProcess, nil
}

View File

@ -7,6 +7,10 @@ import (
"os/exec"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"go.uber.org/goleak"
"github.com/coder/coder/pty/ptytest"
@ -20,7 +24,20 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty, _ := ptytest.Start(t, exec.Command("echo", "test"))
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
pty.ExpectMatch("test")
err := ps.Wait()
require.NoError(t, err)
})
t.Run("Kill", func(t *testing.T) {
t.Parallel()
_, ps := ptytest.Start(t, exec.Command("sleep", "30"))
err := ps.Kill()
assert.NoError(t, err)
err = ps.Wait()
var exitErr *exec.ExitError
require.True(t, xerrors.As(err, &exitErr))
assert.NotEqual(t, 0, exitErr.ExitCode())
})
}

View File

@ -16,7 +16,7 @@ import (
// Allocates a PTY and starts the specified command attached to it.
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
fullPath, err := exec.LookPath(cmd.Path)
if err != nil {
return nil, nil, err
@ -83,7 +83,12 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
if err != nil {
return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err)
}
return pty, process, nil
wp := &windowsProcess{
cmdDone: make(chan any),
proc: process,
}
go wp.waitInternal()
return pty, wp, nil
}
// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476

View File

@ -8,8 +8,10 @@ import (
"testing"
"github.com/coder/coder/pty/ptytest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/xerrors"
)
func TestMain(m *testing.M) {
@ -20,8 +22,10 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty, _ := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
pty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
pty.ExpectMatch("test")
err := ps.Wait()
require.NoError(t, err)
})
t.Run("Resize", func(t *testing.T) {
t.Parallel()
@ -29,4 +33,14 @@ func TestStart(t *testing.T) {
err := pty.Resize(100, 50)
require.NoError(t, err)
})
t.Run("Kill", func(t *testing.T) {
t.Parallel()
_, ps := ptytest.Start(t, exec.Command("cmd.exe"))
err := ps.Kill()
assert.NoError(t, err)
err = ps.Wait()
var exitErr *exec.ExitError
require.True(t, xerrors.As(err, &exitErr))
assert.NotEqual(t, 0, exitErr.ExitCode())
})
}