refactor: PTY & SSH (#7100)

* Add ssh tests for longoutput, orphan

Signed-off-by: Spike Curtis <spike@coder.com>

* PTY/SSH tests & improvements

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix some tests

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix linting

Signed-off-by: Spike Curtis <spike@coder.com>

* fmt

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix windows test

Signed-off-by: Spike Curtis <spike@coder.com>

* Windows copy test

Signed-off-by: Spike Curtis <spike@coder.com>

* WIP Windows pty handling

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix truncation tests

Signed-off-by: Spike Curtis <spike@coder.com>

* Appease linter/fmt

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix typo

Signed-off-by: Spike Curtis <spike@coder.com>

* Rework truncation test to not assume OS buffers

Signed-off-by: Spike Curtis <spike@coder.com>

* Disable orphan test on Windows --- uses sh

Signed-off-by: Spike Curtis <spike@coder.com>

* agent_test running SSH in pty use ptytest.Start

Signed-off-by: Spike Curtis <spike@coder.com>

* More detail about closing pseudoconsole on windows

Signed-off-by: Spike Curtis <spike@coder.com>

* Code review fixes

Signed-off-by: Spike Curtis <spike@coder.com>

* Rearrange ptytest method order

Signed-off-by: Spike Curtis <spike@coder.com>

* Protect pty.Resize on windows from races

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix windows bugs

Signed-off-by: Spike Curtis <spike@coder.com>

* PTY doesn't extend PTYCmd

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix windows types

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis 2023-04-24 14:53:57 +04:00 committed by GitHub
parent c000f2ec28
commit daee91c6dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 803 additions and 288 deletions

View File

@ -1045,7 +1045,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
if err = a.trackConnGoroutine(func() {
buffer := make([]byte, 1024)
for {
read, err := rpty.ptty.Output().Read(buffer)
read, err := rpty.ptty.OutputReader().Read(buffer)
if err != nil {
// When the PTY is closed, this is triggered.
break
@ -1138,7 +1138,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
logger.Warn(ctx, "read conn", slog.Error(err))
return nil
}
_, err = rpty.ptty.Input().Write([]byte(req.Data))
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
if err != nil {
logger.Warn(ctx, "write to pty", slog.Error(err))
return nil
@ -1358,7 +1358,7 @@ type reconnectingPTY struct {
circularBuffer *circbuf.Buffer
circularBufferMutex sync.RWMutex
timeout *time.Timer
ptty pty.PTY
ptty pty.PTYCmd
}
// Close ends all connections to the reconnecting

View File

@ -45,6 +45,7 @@ import (
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/pty"
"github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
@ -481,17 +482,10 @@ func TestAgent_TCPLocalForwarding(t *testing.T) {
}
}()
pty := ptytest.New(t)
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
cmd.Stdin = pty.Input()
cmd.Stdout = pty.Output()
cmd.Stderr = pty.Output()
err = cmd.Start()
require.NoError(t, err)
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
go func() {
err := cmd.Wait()
err := proc.Wait()
select {
case <-done:
default:
@ -523,7 +517,7 @@ func TestAgent_TCPLocalForwarding(t *testing.T) {
<-done
_ = cmd.Process.Kill()
_ = proc.Kill()
}
//nolint:paralleltest // This test reserves a port.
@ -562,17 +556,10 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
}
}()
pty := ptytest.New(t)
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
cmd.Stdin = pty.Input()
cmd.Stdout = pty.Output()
cmd.Stderr = pty.Output()
err = cmd.Start()
require.NoError(t, err)
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
go func() {
err := cmd.Wait()
err := proc.Wait()
select {
case <-done:
default:
@ -604,7 +591,7 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
<-done
_ = cmd.Process.Kill()
_ = proc.Kill()
}
func TestAgent_UnixLocalForwarding(t *testing.T) {
@ -641,17 +628,10 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
}
}()
pty := ptytest.New(t)
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
cmd.Stdin = pty.Input()
cmd.Stdout = pty.Output()
cmd.Stderr = pty.Output()
err = cmd.Start()
require.NoError(t, err)
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
go func() {
err := cmd.Wait()
err := proc.Wait()
select {
case <-done:
default:
@ -676,7 +656,7 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
_ = conn.Close()
<-done
_ = cmd.Process.Kill()
_ = proc.Kill()
}
func TestAgent_UnixRemoteForwarding(t *testing.T) {
@ -713,17 +693,10 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
}
}()
pty := ptytest.New(t)
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
cmd.Stdin = pty.Input()
cmd.Stdout = pty.Output()
cmd.Stderr = pty.Output()
err = cmd.Start()
require.NoError(t, err)
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
go func() {
err := cmd.Wait()
err := proc.Wait()
select {
case <-done:
default:
@ -753,7 +726,7 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
<-done
_ = cmd.Process.Kill()
_ = proc.Kill()
}
func TestAgent_SFTP(t *testing.T) {
@ -1648,7 +1621,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
}, testutil.WaitShort, testutil.IntervalFast)
}
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
listener, err := net.Listen("tcp", "127.0.0.1:0")
@ -1690,7 +1663,8 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
"host",
)
args = append(args, afterArgs...)
return exec.Command("ssh", args...)
cmd := exec.Command("ssh", args...)
return ptytest.Start(t, cmd)
}
func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session {

View File

@ -253,102 +253,12 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
sshPty, windowSize, isPty := session.Pty()
if isPty {
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
// See https://github.com/coder/coder/issues/3371.
session.DisablePTYEmulation()
if !isQuietLogin(session.RawCommand()) {
manifest := s.Manifest.Load()
if manifest != nil {
err = showMOTD(session, manifest.MOTDFile)
if err != nil {
s.logger.Error(ctx, "show MOTD", slog.Error(err))
}
} else {
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
}
}
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
// The pty package sets `SSH_TTY` on supported platforms.
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
pty.WithSSHRequest(sshPty),
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
))
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
var wg sync.WaitGroup
defer func() {
defer wg.Wait()
closeErr := ptty.Close()
if closeErr != nil {
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}
}()
go func() {
for win := range windowSize {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
// If the pty is closed, then command has exited, no need to log.
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
}
}
}()
// We don't add input copy to wait group because
// it won't return until the session is closed.
go func() {
_, _ = io.Copy(ptty.Input(), session)
}()
// In low parallelism scenarios, the command may exit and we may close
// the pty before the output copy has started. This can result in the
// output being lost. To avoid this, we wait for the output copy to
// start before waiting for the command to exit. This ensures that the
// output copy goroutine will be scheduled before calling close on the
// pty. This shouldn't be needed because of `pty.Dup()` below, but it
// may not be supported on all platforms.
outputCopyStarted := make(chan struct{})
ptyOutput := func() io.ReadCloser {
defer close(outputCopyStarted)
// Try to dup so we can separate stdin and stdout closure.
// Once the original pty is closed, the dup will return
// input/output error once the buffered data has been read.
stdout, err := ptty.Dup()
if err == nil {
return stdout
}
// If we can't dup, we shouldn't close
// the fd since it's tied to stdin.
return readNopCloser{ptty.Output()}
}
wg.Add(1)
go func() {
// Ensure data is flushed to session on command exit, if we
// close the session too soon, we might lose data.
defer wg.Done()
stdout := ptyOutput()
defer stdout.Close()
_, _ = io.Copy(session, stdout)
}()
<-outputCopyStarted
err = process.Wait()
var exitErr *exec.ExitError
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
// and not something to be concerned about. But, if it's something else, we should log it.
if err != nil && !xerrors.As(err, &exitErr) {
s.logger.Warn(ctx, "wait error", slog.Error(err))
}
return err
return s.startPTYSession(session, cmd, sshPty, windowSize)
}
return startNonPTYSession(session, cmd)
}
func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
cmd.Stdout = session
cmd.Stderr = session.Stderr()
// This blocks forever until stdin is received if we don't
@ -368,10 +278,94 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
return cmd.Wait()
}
type readNopCloser struct{ io.Reader }
// ptySession is the interface to the ssh.Session that startPTYSession uses
// we use an interface here so that we can fake it in tests.
type ptySession interface {
io.ReadWriter
Context() ssh.Context
DisablePTYEmulation()
RawCommand() string
}
// Close implements io.Closer.
func (readNopCloser) Close() error { return nil }
func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
ctx := session.Context()
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
// See https://github.com/coder/coder/issues/3371.
session.DisablePTYEmulation()
if !isQuietLogin(session.RawCommand()) {
manifest := s.Manifest.Load()
if manifest != nil {
err := showMOTD(session, manifest.MOTDFile)
if err != nil {
s.logger.Error(ctx, "show MOTD", slog.Error(err))
}
} else {
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
}
}
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
// The pty package sets `SSH_TTY` on supported platforms.
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
pty.WithSSHRequest(sshPty),
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
))
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
defer func() {
closeErr := ptty.Close()
if closeErr != nil {
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}
}()
go func() {
for win := range windowSize {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
// If the pty is closed, then command has exited, no need to log.
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
}
}
}()
go func() {
_, _ = io.Copy(ptty.InputWriter(), session)
}()
// We need to wait for the command output to finish copying. It's safe to
// just do this copy on the main handler goroutine because one of two things
// will happen:
//
// 1. The command completes & closes the TTY, which then triggers an error
// after we've Read() all the buffered data from the PTY.
// 2. The client hangs up, which cancels the command's Context, and go will
// kill the command's process. This then has the same effect as (1).
n, err := io.Copy(session, ptty.OutputReader())
s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err))
if err != nil {
return xerrors.Errorf("copy error: %w", err)
}
// We've gotten all the output, but we need to wait for the process to
// complete so that we can get the exit code. This returns
// immediately if the TTY was closed as part of the command exiting.
err = process.Wait()
var exitErr *exec.ExitError
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
// and not something to be concerned about. But, if it's something else, we should log it.
if err != nil && !xerrors.As(err, &exitErr) {
s.logger.Warn(ctx, "wait error", slog.Error(err))
}
if err != nil {
return xerrors.Errorf("process wait: %w", err)
}
return nil
}
func (s *Server) sftpHandler(session ssh.Session) {
ctx := session.Context()

View File

@ -0,0 +1,190 @@
//go:build !windows
package agentssh
import (
"bufio"
"context"
"io"
"net"
"os/exec"
"testing"
gliderssh "github.com/gliderlabs/ssh"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/testutil"
"cdr.dev/slog/sloggers/slogtest"
)
const longScript = `
echo "started"
sleep 30
echo "done"
`
// Test_sessionStart_orphan tests running a command that takes a long time to
// exit normally, and terminate the SSH session context early to verify that we
// return quickly and don't leave the command running as an "orphan" with no
// active SSH session.
func Test_sessionStart_orphan(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil)
s, err := NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
// Here we're going to call the handler directly with a faked SSH session
// that just uses io.Pipes instead of a network socket. There is a large
// variation in the time between closing the socket from the client side and
// the SSH server canceling the session Context, which would lead to a flaky
// test if we did it that way. So instead, we directly cancel the context
// in this test.
sessionCtx, sessionCancel := context.WithCancel(ctx)
toClient, fromClient, sess := newTestSession(sessionCtx)
ptyInfo := gliderssh.Pty{}
windowSize := make(chan gliderssh.Window)
close(windowSize)
// the command gets the session context so that Go will terminate it when
// the session expires.
cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript)
done := make(chan struct{})
go func() {
defer close(done)
// we don't really care what the error is here. In the larger scenario,
// the client has disconnected, so we can't return any error information
// to them.
_ = s.startPTYSession(sess, cmd, ptyInfo, windowSize)
}()
readDone := make(chan struct{})
go func() {
defer close(readDone)
s := bufio.NewScanner(toClient)
assert.True(t, s.Scan())
txt := s.Text()
assert.Equal(t, "started", txt, "output corrupted")
}()
waitForChan(ctx, t, readDone, "read timeout")
// process is started, and should be sleeping for ~30 seconds
sessionCancel()
// now, we wait for the handler to complete. If it does so before the
// main test timeout, we consider this a pass. If not, it indicates
// that the server isn't properly shutting down sessions when they are
// disconnected client side, which could lead to processes hanging around
// indefinitely.
waitForChan(ctx, t, done, "handler timeout")
err = fromClient.Close()
require.NoError(t, err)
}
func waitForChan(ctx context.Context, t *testing.T, c <-chan struct{}, msg string) {
t.Helper()
select {
case <-c:
// OK!
case <-ctx.Done():
t.Fatal(msg)
}
}
type testSession struct {
ctx testSSHContext
// c2p is the client -> pty buffer
toPty *io.PipeReader
// p2c is the pty -> client buffer
fromPty *io.PipeWriter
}
type testSSHContext struct {
context.Context
}
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
toClient, fromPty := io.Pipe()
toPty, fromClient := io.Pipe()
return toClient, fromClient, &testSession{
ctx: testSSHContext{ctx},
toPty: toPty,
fromPty: fromPty,
}
}
func (s *testSession) Context() gliderssh.Context {
return s.ctx
}
func (*testSession) DisablePTYEmulation() {}
// RawCommand returns "quiet logon" so that the PTY handler doesn't attempt to
// write the message of the day, which will interfere with our tests. It writes
// the message of the day if it's a shell login (zero length RawCommand()).
func (*testSession) RawCommand() string { return "quiet logon" }
func (s *testSession) Read(p []byte) (n int, err error) {
return s.toPty.Read(p)
}
func (s *testSession) Write(p []byte) (n int, err error) {
return s.fromPty.Write(p)
}
func (testSSHContext) Lock() {
panic("not implemented")
}
func (testSSHContext) Unlock() {
panic("not implemented")
}
// User returns the username used when establishing the SSH connection.
func (testSSHContext) User() string {
panic("not implemented")
}
// SessionID returns the session hash.
func (testSSHContext) SessionID() string {
panic("not implemented")
}
// ClientVersion returns the version reported by the client.
func (testSSHContext) ClientVersion() string {
panic("not implemented")
}
// ServerVersion returns the version reported by the server.
func (testSSHContext) ServerVersion() string {
panic("not implemented")
}
// RemoteAddr returns the remote address for this connection.
func (testSSHContext) RemoteAddr() net.Addr {
panic("not implemented")
}
// LocalAddr returns the local address for this connection.
func (testSSHContext) LocalAddr() net.Addr {
panic("not implemented")
}
// Permissions returns the Permissions object used for this connection.
func (testSSHContext) Permissions() *gliderssh.Permissions {
panic("not implemented")
}
// SetValue allows you to easily write new values into the underlying context.
func (testSSHContext) SetValue(_, _ interface{}) {
panic("not implemented")
}

1
go.mod
View File

@ -108,6 +108,7 @@ require (
github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f
github.com/hashicorp/terraform-json v0.14.0
github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce
github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02
github.com/imulab/go-scim/pkg/v2 v2.2.0
github.com/jedib0t/go-pretty/v6 v6.4.0
github.com/jmoiron/sqlx v1.3.5

3
go.sum
View File

@ -972,8 +972,9 @@ github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce h1:7FO+LmZwiG/eDsB
github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 h1:aSVUgRRRtOrZOC1fYmY9gV0e9z/Iu+xNVSASWjsuyGU=
github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3/go.mod h1:5PC6ZNPde8bBqU/ewGZig35+UIZtw9Ytxez8/q5ZyFE=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 h1:AgcIVYPa6XJnU3phs104wLj8l5GEththEw6+F79YsIY=
github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.3.1/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/huandu/xstrings v1.3.2/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=

View File

@ -3,7 +3,6 @@ package pty
import (
"io"
"log"
"os"
"github.com/gliderlabs/ssh"
"golang.org/x/xerrors"
@ -12,10 +11,33 @@ import (
// ErrClosed is returned when a PTY is used after it has been closed.
var ErrClosed = xerrors.New("pty: closed")
// PTY is a minimal interface for interacting with a TTY.
// PTYCmd is an interface for interacting with a pseudo-TTY where we control
// only one end, and the other end has been passed to a running os.Process.
// nolint:revive
type PTYCmd interface {
io.Closer
// Resize sets the size of the PTY.
Resize(height uint16, width uint16) error
// OutputReader returns an io.Reader for reading the output from the process
// controlled by the pseudo-TTY
OutputReader() io.Reader
// InputWriter returns an io.Writer for writing into to the process
// controlled by the pseudo-TTY
InputWriter() io.Writer
}
// PTY is a minimal interface for interacting with pseudo-TTY where this
// process retains access to _both_ ends of the pseudo-TTY (i.e. `ptm` & `pts`
// on Linux).
type PTY interface {
io.Closer
// Resize sets the size of the PTY.
Resize(height uint16, width uint16) error
// Name of the TTY. Example on Linux would be "/dev/pts/1".
Name() string
@ -34,14 +56,6 @@ type PTY interface {
//
// The same stream would be used to provide user input: pty.Input().Write(...)
Input() ReadWriter
// Dup returns a new file descriptor for the PTY.
//
// This is useful for closing stdin and stdout separately.
Dup() (*os.File, error)
// Resize sets the size of the PTY.
Resize(height uint16, width uint16) error
}
// Process represents a process running in a PTY. We need to trigger special processing on the PTY
@ -108,8 +122,8 @@ func New(opts ...Option) (PTY, error) {
// underlying file descriptors, one for reading and one for writing, and allows
// them to be accessed separately.
type ReadWriter struct {
Reader *os.File
Writer *os.File
Reader io.Reader
Writer io.Writer
}
func (rw ReadWriter) Read(p []byte) (int, error) {

View File

@ -3,15 +3,17 @@
package pty
import (
"io"
"io/fs"
"os"
"os/exec"
"runtime"
"sync"
"syscall"
"github.com/creack/pty"
"github.com/u-root/u-root/pkg/termios"
"golang.org/x/sys/unix"
"golang.org/x/xerrors"
)
func newPty(opt ...Option) (retPTY *otherPty, err error) {
@ -28,6 +30,7 @@ func newPty(opt ...Option) (retPTY *otherPty, err error) {
pty: ptyFile,
tty: ttyFile,
opts: opts,
name: ttyFile.Name(),
}
defer func() {
if err != nil {
@ -53,6 +56,7 @@ type otherPty struct {
err error
pty, tty *os.File
opts ptyOptions
name string
}
func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) {
@ -85,7 +89,7 @@ func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error)
}
func (p *otherPty) Name() string {
return p.tty.Name()
return p.name
}
func (p *otherPty) Input() ReadWriter {
@ -95,13 +99,21 @@ func (p *otherPty) Input() ReadWriter {
}
}
func (p *otherPty) InputWriter() io.Writer {
return p.pty
}
func (p *otherPty) Output() ReadWriter {
return ReadWriter{
Reader: p.pty,
Reader: &ptmReader{p.pty},
Writer: p.tty,
}
}
func (p *otherPty) OutputReader() io.Reader {
return &ptmReader{p.pty}
}
func (p *otherPty) Resize(height uint16, width uint16) error {
return p.control(p.pty, func(fd uintptr) error {
return termios.SetWinSize(fd, &termios.Winsize{
@ -113,20 +125,6 @@ func (p *otherPty) Resize(height uint16, width uint16) error {
})
}
func (p *otherPty) Dup() (*os.File, error) {
var newfd int
err := p.control(p.pty, func(fd uintptr) error {
var err error
newfd, err = syscall.Dup(int(fd))
return err
})
if err != nil {
return nil, err
}
return os.NewFile(uintptr(newfd), p.pty.Name()), nil
}
func (p *otherPty) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
@ -137,9 +135,12 @@ func (p *otherPty) Close() error {
p.closed = true
err := p.pty.Close()
err2 := p.tty.Close()
if err == nil {
err = err2
// tty is closed & unset if we Start() a new process
if p.tty != nil {
err2 := p.tty.Close()
if err == nil {
err = err2
}
}
if err != nil {
@ -177,3 +178,21 @@ func (p *otherProcess) waitInternal() {
runtime.KeepAlive(p.pty)
close(p.cmdDone)
}
// ptmReader wraps a reference to the ptm side of a pseudo-TTY for portability
type ptmReader struct {
ptm io.Reader
}
func (r *ptmReader) Read(p []byte) (n int, err error) {
n, err = r.ptm.Read(p)
// output from the ptm will hit a PathErr when the process hangs up the
// other side (typically when the process exits, but could be earlier). For
// portability, and to fit with our use of io.Copy() to copy from the PTY,
// we want to translate this error into io.EOF
pathErr := &fs.PathError{}
if xerrors.As(err, &pathErr) {
return n, io.EOF
}
return n, err
}

View File

@ -3,6 +3,7 @@
package pty
import (
"io"
"os"
"os/exec"
"sync"
@ -21,7 +22,7 @@ var (
)
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session
func newPty(opt ...Option) (PTY, error) {
func newPty(opt ...Option) (*ptyWindows, error) {
var opts ptyOptions
for _, o := range opt {
o(&opts)
@ -88,6 +89,7 @@ type windowsProcess struct {
cmdDone chan any
cmdErr error
proc *os.Process
pw *ptyWindows
}
// Name returns the TTY name on Windows.
@ -104,6 +106,10 @@ func (p *ptyWindows) Output() ReadWriter {
}
}
func (p *ptyWindows) OutputReader() io.Reader {
return p.outputRead
}
func (p *ptyWindows) Input() ReadWriter {
return ReadWriter{
Reader: p.inputRead,
@ -111,7 +117,17 @@ func (p *ptyWindows) Input() ReadWriter {
}
}
func (p *ptyWindows) InputWriter() io.Writer {
return p.inputWrite
}
func (p *ptyWindows) Resize(height uint16, width uint16) error {
// hold the lock, so we don't race with anyone trying to close the console
p.closeMutex.Lock()
defer p.closeMutex.Unlock()
if p.closed || p.console == windows.InvalidHandle {
return ErrClosed
}
// Taken from: https://github.com/microsoft/hcsshim/blob/54a5ad86808d761e3e396aff3e2022840f39f9a8/internal/winapi/zsyscall_windows.go#L144
ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(*((*uint32)(unsafe.Pointer(&windows.Coord{
Y: int16(height),
@ -123,10 +139,6 @@ func (p *ptyWindows) Resize(height uint16, width uint16) error {
return nil
}
func (p *ptyWindows) Dup() (*os.File, error) {
return nil, xerrors.Errorf("not implemented")
}
func (p *ptyWindows) Close() error {
p.closeMutex.Lock()
defer p.closeMutex.Unlock()
@ -135,20 +147,54 @@ func (p *ptyWindows) Close() error {
}
p.closed = true
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
if ret < 0 {
return xerrors.Errorf("close pseudo console: %w", err)
// if we are running a command in the PTY, the corresponding *windowsProcess
// may have already closed the PseudoConsole when the command exited, so that
// output reads can get to EOF. In that case, we don't need to close it
// again here.
if p.console != windows.InvalidHandle {
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
if ret < 0 {
return xerrors.Errorf("close pseudo console: %w", err)
}
p.console = windows.InvalidHandle
}
_ = p.outputWrite.Close()
// We always have these files
_ = p.outputRead.Close()
_ = p.inputWrite.Close()
_ = p.inputRead.Close()
// These get closed & unset if we Start() a new process.
if p.outputWrite != nil {
_ = p.outputWrite.Close()
}
if p.inputRead != nil {
_ = p.inputRead.Close()
}
return nil
}
func (p *windowsProcess) waitInternal() {
// put this on the bottom of the defer stack since the next defer can write to p.cmdErr
defer close(p.cmdDone)
defer func() {
// close the pseudoconsole handle when the process exits, if it hasn't already been closed.
// this is important because the PseudoConsole (conhost.exe) holds the write-end
// of the output pipe. If it is not closed, reads on that pipe will block, even though
// the command has exited.
// c.f. https://devblogs.microsoft.com/commandline/windows-command-line-introducing-the-windows-pseudo-console-conpty/
p.pw.closeMutex.Lock()
defer p.pw.closeMutex.Unlock()
if p.pw.console != windows.InvalidHandle {
ret, _, err := procClosePseudoConsole.Call(uintptr(p.pw.console))
if ret < 0 && p.cmdErr == nil {
// if we already have an error from the command, prefer that error
// but if the command succeeded and closing the PseudoConsole fails
// then record that error so that we have a chance to see it
p.cmdErr = err
}
p.pw.console = windows.InvalidHandle
}
}()
state, err := p.proc.Wait()
if err != nil {
p.cmdErr = err

View File

@ -30,12 +30,21 @@ func New(t *testing.T, opts ...pty.Option) *PTY {
ptty, err := pty.New(opts...)
require.NoError(t, err)
return create(t, ptty, "cmd")
e := newExpecter(t, ptty.Output(), "cmd")
r := &PTY{
outExpecter: e,
PTY: ptty,
}
// Ensure pty is cleaned up at the end of test.
t.Cleanup(func() {
_ = r.Close()
})
return r
}
// Start starts a new process asynchronously and returns a PTY and Process.
// It kills the process upon cleanup.
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Process) {
// Start starts a new process asynchronously and returns a PTYCmd and Process.
// It kills the process and PTYCmd upon cleanup
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
t.Helper()
ptty, ps, err := pty.Start(cmd, opts...)
@ -44,10 +53,19 @@ func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Proc
_ = ps.Kill()
_ = ps.Wait()
})
return create(t, ptty, cmd.Args[0]), ps
ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0])
r := &PTYCmd{
outExpecter: ex,
PTYCmd: ptty,
}
t.Cleanup(func() {
_ = r.Close()
})
return r, ps
}
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
func newExpecter(t *testing.T, r io.Reader, name string) outExpecter {
// Use pipe for logging.
logDone := make(chan struct{})
logr, logw := io.Pipe()
@ -57,37 +75,30 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY {
out := newStdbuf()
w := io.MultiWriter(logw, out)
tpty := &PTY{
ex := outExpecter{
t: t,
PTY: ptty,
out: out,
name: name,
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
}
// Ensure pty is cleaned up at the end of test.
t.Cleanup(func() {
_ = tpty.Close()
})
logClose := func(name string, c io.Closer) {
tpty.logf("closing %s", name)
ex.logf("closing %s", name)
err := c.Close()
tpty.logf("closed %s: %v", name, err)
ex.logf("closed %s: %v", name, err)
}
// Set the actual close function for the tpty.
tpty.close = func(reason string) error {
// Set the actual close function for the outExpecter.
ex.close = func(reason string) error {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
tpty.logf("closing tpty: %s", reason)
ex.logf("closing expecter: %s", reason)
// Close pty only so that the copy goroutine can consume the
// remainder of it's buffer and then exit.
logClose("pty", ptty)
// Caller needs to have closed the PTY so that copying can complete
select {
case <-ctx.Done():
tpty.fatalf("close", "copy did not close in time")
ex.fatalf("close", "copy did not close in time")
case <-copyDone:
}
@ -95,22 +106,22 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY {
logClose("logr", logr)
select {
case <-ctx.Done():
tpty.fatalf("close", "log pipe did not close in time")
ex.fatalf("close", "log pipe did not close in time")
case <-logDone:
}
tpty.logf("closed tpty")
ex.logf("closed expecter")
return nil
}
go func() {
defer close(copyDone)
_, err := io.Copy(w, ptty.Output())
tpty.logf("copy done: %v", err)
tpty.logf("closing out")
_, err := io.Copy(w, r)
ex.logf("copy done: %v", err)
ex.logf("closing out")
err = out.closeErr(err)
tpty.logf("closed out: %v", err)
ex.logf("closed out: %v", err)
}()
// Log all output as part of test for easier debugging on errors.
@ -118,15 +129,14 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY {
defer close(logDone)
s := bufio.NewScanner(logr)
for s.Scan() {
tpty.logf("%q", stripansi.Strip(s.Text()))
ex.logf("%q", stripansi.Strip(s.Text()))
}
}()
return tpty
return ex
}
type PTY struct {
pty.PTY
type outExpecter struct {
t *testing.T
close func(reason string) error
out *stdbuf
@ -135,38 +145,23 @@ type PTY struct {
runeReader *bufio.Reader
}
func (p *PTY) Close() error {
p.t.Helper()
return p.close("close")
}
func (p *PTY) Attach(inv *clibase.Invocation) *PTY {
p.t.Helper()
inv.Stdout = p.Output()
inv.Stderr = p.Output()
inv.Stdin = p.Input()
return p
}
func (p *PTY) ExpectMatch(str string) string {
p.t.Helper()
func (e *outExpecter) ExpectMatch(str string) string {
e.t.Helper()
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
return p.ExpectMatchContext(timeout, str)
return e.ExpectMatchContext(timeout, str)
}
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
p.t.Helper()
func (e *outExpecter) ExpectMatchContext(ctx context.Context, str string) string {
e.t.Helper()
var buffer bytes.Buffer
err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
err := e.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
for {
r, _, err := p.runeReader.ReadRune()
r, _, err := e.runeReader.ReadRune()
if err != nil {
return err
}
@ -180,54 +175,54 @@ func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
}
})
if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
e.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return ""
}
p.logf("matched %q = %q", str, stripansi.Strip(buffer.String()))
e.logf("matched %q = %q", str, stripansi.Strip(buffer.String()))
return buffer.String()
}
func (p *PTY) Peek(ctx context.Context, n int) []byte {
p.t.Helper()
func (e *outExpecter) Peek(ctx context.Context, n int) []byte {
e.t.Helper()
var out []byte
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
err := e.doMatchWithDeadline(ctx, "Peek", func() error {
var err error
out, err = p.runeReader.Peek(n)
out, err = e.runeReader.Peek(n)
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
e.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
return nil
}
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
e.logf("peeked %d/%d bytes = %q", len(out), n, out)
return slices.Clone(out)
}
func (p *PTY) ReadRune(ctx context.Context) rune {
p.t.Helper()
func (e *outExpecter) ReadRune(ctx context.Context) rune {
e.t.Helper()
var r rune
err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
err := e.doMatchWithDeadline(ctx, "ReadRune", func() error {
var err error
r, _, err = p.runeReader.ReadRune()
r, _, err = e.runeReader.ReadRune()
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
e.fatalf("read error", "%v (wanted rune; got %q)", err, r)
return 0
}
p.logf("matched rune = %q", r)
e.logf("matched rune = %q", r)
return r
}
func (p *PTY) ReadLine(ctx context.Context) string {
p.t.Helper()
func (e *outExpecter) ReadLine(ctx context.Context) string {
e.t.Helper()
var buffer bytes.Buffer
err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
err := e.doMatchWithDeadline(ctx, "ReadLine", func() error {
for {
r, _, err := p.runeReader.ReadRune()
r, _, err := e.runeReader.ReadRune()
if err != nil {
return err
}
@ -240,14 +235,14 @@ func (p *PTY) ReadLine(ctx context.Context) string {
// Unicode code points can be up to 4 bytes, but the
// ones we're looking for are only 1 byte.
b, _ := p.runeReader.Peek(1)
b, _ := e.runeReader.Peek(1)
if len(b) == 0 {
return nil
}
r, _ = utf8.DecodeRune(b)
if r == '\n' {
_, _, err = p.runeReader.ReadRune()
_, _, err = e.runeReader.ReadRune()
if err != nil {
return err
}
@ -263,21 +258,21 @@ func (p *PTY) ReadLine(ctx context.Context) string {
}
})
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
e.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return ""
}
p.logf("matched newline = %q", buffer.String())
e.logf("matched newline = %q", buffer.String())
return buffer.String()
}
func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
p.t.Helper()
func (e *outExpecter) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
e.t.Helper()
// A timeout is mandatory, caller can decide by passing a context
// that times out.
if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium
p.logf("%s ctx has no deadline, using %s", name, timeout)
e.logf("%s ctx has no deadline, using %s", name, timeout)
var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout)
@ -294,13 +289,55 @@ func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() er
return err
case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("match deadline exceeded")
_ = e.close("match deadline exceeded")
<-match
return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
}
}
func (e *outExpecter) logf(format string, args ...interface{}) {
e.t.Helper()
// Match regular logger timestamp format, we seem to be logging in
// UTC in other places as well, so match here.
e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name, fmt.Sprintf(format, args...))
}
func (e *outExpecter) fatalf(reason string, format string, args ...interface{}) {
e.t.Helper()
// Ensure the message is part of the normal log stream before
// failing the test.
e.logf("%s: %s", reason, fmt.Sprintf(format, args...))
require.FailNowf(e.t, reason, format, args...)
}
type PTY struct {
outExpecter
pty.PTY
}
func (p *PTY) Close() error {
p.t.Helper()
pErr := p.PTY.Close()
eErr := p.outExpecter.close("close")
if pErr != nil {
return pErr
}
return eErr
}
func (p *PTY) Attach(inv *clibase.Invocation) *PTY {
p.t.Helper()
inv.Stdout = p.Output()
inv.Stderr = p.Output()
inv.Stdin = p.Input()
return p
}
func (p *PTY) Write(r rune) {
p.t.Helper()
@ -321,22 +358,19 @@ func (p *PTY) WriteLine(str string) {
require.NoError(p.t, err, "write line failed")
}
func (p *PTY) logf(format string, args ...interface{}) {
p.t.Helper()
// Match regular logger timestamp format, we seem to be logging in
// UTC in other places as well, so match here.
p.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), p.name, fmt.Sprintf(format, args...))
type PTYCmd struct {
outExpecter
pty.PTYCmd
}
func (p *PTY) fatalf(reason string, format string, args ...interface{}) {
func (p *PTYCmd) Close() error {
p.t.Helper()
// Ensure the message is part of the normal log stream before
// failing the test.
p.logf("%s: %s", reason, fmt.Sprintf(format, args...))
require.FailNowf(p.t, reason, format, args...)
pErr := p.PTYCmd.Close()
eErr := p.outExpecter.close("close")
if pErr != nil {
return pErr
}
return eErr
}
// stdbuf is like a buffered stdout, it buffers writes until read.

View File

@ -20,6 +20,6 @@ func WithPTYOption(opts ...Option) StartOption {
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
// instead rely on the returned Process to manage the command/process.
func Start(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) {
return startPty(cmd, opt...)
}

View File

@ -50,6 +50,17 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
}
return nil, nil, xerrors.Errorf("start: %w", err)
}
// Now that we've started the command, and passed the TTY to it, close our
// file so that the other process has the only open file to the TTY. Once
// the process closes the TTY (usually on exit), there will be no open
// references and the OS kernel returns an error when trying to read or
// write to our PTY end. Without this, reading from the process output
// will block until we close our TTY.
if err := opty.tty.Close(); err != nil {
_ = cmd.Process.Kill()
return nil, nil, xerrors.Errorf("close tty: %w", err)
}
opty.tty = nil // remove so we don't attempt to close it again.
oProcess := &otherProcess{
pty: opty.pty,
cmd: cmd,

View File

@ -25,20 +25,25 @@ func TestStart(t *testing.T) {
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
pty.ExpectMatch("test")
err := ps.Wait()
require.NoError(t, err)
err = pty.Close()
require.NoError(t, err)
})
t.Run("Kill", func(t *testing.T) {
t.Parallel()
_, ps := ptytest.Start(t, exec.Command("sleep", "30"))
pty, ps := ptytest.Start(t, exec.Command("sleep", "30"))
err := ps.Kill()
assert.NoError(t, err)
err = ps.Wait()
var exitErr *exec.ExitError
require.True(t, xerrors.As(err, &exitErr))
assert.NotEqual(t, 0, exitErr.ExitCode())
err = pty.Close()
require.NoError(t, err)
})
t.Run("SSH_TTY", func(t *testing.T) {
@ -53,5 +58,29 @@ func TestStart(t *testing.T) {
pty.ExpectMatch("SSH_TTY=/dev/")
err := ps.Wait()
require.NoError(t, err)
err = pty.Close()
require.NoError(t, err)
})
}
// these constants/vars are used by Test_Start_copy
const cmdEcho = "echo"
var argEcho = []string{"test"}
// these constants/vars are used by Test_Start_truncate
const (
countEnd = 1000
cmdCount = "sh"
)
var argCount = []string{"-c", `
i=0
while [ $i -ne 1000 ]
do
i=$(($i+1))
echo "$i"
done
`}

148
pty/start_test.go Normal file
View File

@ -0,0 +1,148 @@
package pty_test
import (
"bytes"
"context"
"fmt"
"io"
"os/exec"
"strings"
"testing"
"time"
"github.com/hinshun/vt10x"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty"
"github.com/coder/coder/testutil"
)
// Test_Start_copy tests that we can use io.Copy() on command output
// without deadlocking.
func Test_Start_copy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...))
require.NoError(t, err)
b := &bytes.Buffer{}
readDone := make(chan error, 1)
go func() {
_, err := io.Copy(b, pc.OutputReader())
readDone <- err
}()
select {
case err := <-readDone:
require.NoError(t, err)
case <-ctx.Done():
t.Error("read timed out")
}
assert.Contains(t, b.String(), "test")
cmdDone := make(chan error, 1)
go func() {
cmdDone <- cmd.Wait()
}()
select {
case err := <-cmdDone:
require.NoError(t, err)
case <-ctx.Done():
t.Error("cmd.Wait() timed out")
}
}
// Test_Start_truncation tests that we can read command output without truncation
// even after the command has exited.
func Test_Start_truncation(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...))
require.NoError(t, err)
readDone := make(chan struct{})
go func() {
defer close(readDone)
// avoid buffered IO so that we can precisely control how many bytes to read.
n := 1
for n <= countEnd {
want := fmt.Sprintf("%d", n)
err := readUntil(ctx, t, want, pc.OutputReader())
assert.NoError(t, err, "want: %s", want)
if err != nil {
return
}
n++
if (countEnd - n) < 100 {
// If the OS buffers the output, the process can exit even if
// we're not done reading. We want to slow our reads so that
// if there is a race between reading the data and it being
// truncated, we will lose and fail the test.
time.Sleep(testutil.IntervalFast)
}
}
// ensure we still get to EOF
endB := &bytes.Buffer{}
_, err := io.Copy(endB, pc.OutputReader())
assert.NoError(t, err)
}()
cmdDone := make(chan error, 1)
go func() {
cmdDone <- cmd.Wait()
}()
select {
case err := <-cmdDone:
require.NoError(t, err)
case <-ctx.Done():
t.Fatal("cmd.Wait() timed out")
}
select {
case <-readDone:
// OK!
case <-ctx.Done():
t.Fatal("read timed out")
}
}
// readUntil reads one byte at a time until we either see the string we want, or the context expires
func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error {
// output can contain virtual terminal sequences, so we need to parse these
// to correctly interpret getting what we want.
term := vt10x.New(vt10x.WithSize(80, 80))
readErrs := make(chan error, 1)
for {
b := make([]byte, 1)
go func() {
_, err := r.Read(b)
readErrs <- err
}()
select {
case err := <-readErrs:
if err != nil {
t.Logf("err: %v\ngot: %v", err, term)
return err
}
term.Write(b)
case <-ctx.Done():
return ctx.Err()
}
got := term.String()
lines := strings.Split(got, "\n")
for _, line := range lines {
if strings.TrimSpace(line) == want {
t.Logf("want: %v\n got:%v", want, line)
return nil
}
}
}
}

View File

@ -17,7 +17,7 @@ import (
// Allocates a PTY and starts the specified command attached to it.
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) {
var opts startOptions
for _, o := range opt {
o(&opts)
@ -45,11 +45,18 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
if err != nil {
return nil, nil, err
}
pty, err := newPty(opts.ptyOpts...)
winPty, err := newPty(opts.ptyOpts...)
if err != nil {
return nil, nil, err
}
winPty := pty.(*ptyWindows)
defer func() {
if retErr != nil {
// we hit some error finishing setup; close pty, so
// we don't leak the kernel resources associated with it
_ = winPty.Close()
}
}()
if winPty.opts.sshReq != nil {
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", winPty.Name()))
}
@ -95,9 +102,34 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
wp := &windowsProcess{
cmdDone: make(chan any),
proc: process,
pw: winPty,
}
defer func() {
if retErr != nil {
// if we later error out, kill the process since
// the caller will have no way to interact with it
_ = process.Kill()
}
}()
// Now that we've started the command, and passed the pseudoconsole to it,
// close the output write and input read files, so that the other process
// has the only handles to them. Once the process closes the console, there
// will be no open references and the OS kernel returns an error when trying
// to read or write to our end. Without this, reading from the process
// output will block until they are closed.
errO := winPty.outputWrite.Close()
winPty.outputWrite = nil
errI := winPty.inputRead.Close()
winPty.inputRead = nil
if errO != nil {
return nil, nil, errO
}
if errI != nil {
return nil, nil, errI
}
go wp.waitInternal()
return pty, wp, nil
return winPty, wp, nil
}
// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476

View File

@ -4,6 +4,7 @@
package pty_test
import (
"fmt"
"os/exec"
"testing"
@ -22,25 +23,46 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
pty.ExpectMatch("test")
ptty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
ptty.ExpectMatch("test")
err := ps.Wait()
require.NoError(t, err)
err = ptty.Close()
require.NoError(t, err)
})
t.Run("Resize", func(t *testing.T) {
t.Parallel()
pty, _ := ptytest.Start(t, exec.Command("cmd.exe"))
err := pty.Resize(100, 50)
ptty, _ := ptytest.Start(t, exec.Command("cmd.exe"))
err := ptty.Resize(100, 50)
require.NoError(t, err)
err = ptty.Close()
require.NoError(t, err)
})
t.Run("Kill", func(t *testing.T) {
t.Parallel()
_, ps := ptytest.Start(t, exec.Command("cmd.exe"))
ptty, ps := ptytest.Start(t, exec.Command("cmd.exe"))
err := ps.Kill()
assert.NoError(t, err)
err = ps.Wait()
var exitErr *exec.ExitError
require.True(t, xerrors.As(err, &exitErr))
assert.NotEqual(t, 0, exitErr.ExitCode())
err = ptty.Close()
require.NoError(t, err)
})
}
// these constants/vars are used by Test_Start_copy
const cmdEcho = "cmd.exe"
var argEcho = []string{"/c", "echo", "test"}
// these constants/vars are used by Test_Start_truncate
const (
countEnd = 1000
cmdCount = "cmd.exe"
)
var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)}