mirror of https://github.com/coder/coder.git
refactor: PTY & SSH (#7100)
* Add ssh tests for longoutput, orphan Signed-off-by: Spike Curtis <spike@coder.com> * PTY/SSH tests & improvements Signed-off-by: Spike Curtis <spike@coder.com> * Fix some tests Signed-off-by: Spike Curtis <spike@coder.com> * Fix linting Signed-off-by: Spike Curtis <spike@coder.com> * fmt Signed-off-by: Spike Curtis <spike@coder.com> * Fix windows test Signed-off-by: Spike Curtis <spike@coder.com> * Windows copy test Signed-off-by: Spike Curtis <spike@coder.com> * WIP Windows pty handling Signed-off-by: Spike Curtis <spike@coder.com> * Fix truncation tests Signed-off-by: Spike Curtis <spike@coder.com> * Appease linter/fmt Signed-off-by: Spike Curtis <spike@coder.com> * Fix typo Signed-off-by: Spike Curtis <spike@coder.com> * Rework truncation test to not assume OS buffers Signed-off-by: Spike Curtis <spike@coder.com> * Disable orphan test on Windows --- uses sh Signed-off-by: Spike Curtis <spike@coder.com> * agent_test running SSH in pty use ptytest.Start Signed-off-by: Spike Curtis <spike@coder.com> * More detail about closing pseudoconsole on windows Signed-off-by: Spike Curtis <spike@coder.com> * Code review fixes Signed-off-by: Spike Curtis <spike@coder.com> * Rearrange ptytest method order Signed-off-by: Spike Curtis <spike@coder.com> * Protect pty.Resize on windows from races Signed-off-by: Spike Curtis <spike@coder.com> * Fix windows bugs Signed-off-by: Spike Curtis <spike@coder.com> * PTY doesn't extend PTYCmd Signed-off-by: Spike Curtis <spike@coder.com> * Fix windows types Signed-off-by: Spike Curtis <spike@coder.com> --------- Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent
c000f2ec28
commit
daee91c6dc
|
@ -1045,7 +1045,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
|||
if err = a.trackConnGoroutine(func() {
|
||||
buffer := make([]byte, 1024)
|
||||
for {
|
||||
read, err := rpty.ptty.Output().Read(buffer)
|
||||
read, err := rpty.ptty.OutputReader().Read(buffer)
|
||||
if err != nil {
|
||||
// When the PTY is closed, this is triggered.
|
||||
break
|
||||
|
@ -1138,7 +1138,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
|||
logger.Warn(ctx, "read conn", slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
_, err = rpty.ptty.Input().Write([]byte(req.Data))
|
||||
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "write to pty", slog.Error(err))
|
||||
return nil
|
||||
|
@ -1358,7 +1358,7 @@ type reconnectingPTY struct {
|
|||
circularBuffer *circbuf.Buffer
|
||||
circularBufferMutex sync.RWMutex
|
||||
timeout *time.Timer
|
||||
ptty pty.PTY
|
||||
ptty pty.PTYCmd
|
||||
}
|
||||
|
||||
// Close ends all connections to the reconnecting
|
||||
|
|
|
@ -45,6 +45,7 @@ import (
|
|||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/pty"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/tailnet/tailnettest"
|
||||
|
@ -481,17 +482,10 @@ func TestAgent_TCPLocalForwarding(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
pty := ptytest.New(t)
|
||||
|
||||
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
|
||||
cmd.Stdin = pty.Input()
|
||||
cmd.Stdout = pty.Output()
|
||||
cmd.Stderr = pty.Output()
|
||||
err = cmd.Start()
|
||||
require.NoError(t, err)
|
||||
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
|
||||
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
err := proc.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
|
@ -523,7 +517,7 @@ func TestAgent_TCPLocalForwarding(t *testing.T) {
|
|||
|
||||
<-done
|
||||
|
||||
_ = cmd.Process.Kill()
|
||||
_ = proc.Kill()
|
||||
}
|
||||
|
||||
//nolint:paralleltest // This test reserves a port.
|
||||
|
@ -562,17 +556,10 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
pty := ptytest.New(t)
|
||||
|
||||
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
|
||||
cmd.Stdin = pty.Input()
|
||||
cmd.Stdout = pty.Output()
|
||||
cmd.Stderr = pty.Output()
|
||||
err = cmd.Start()
|
||||
require.NoError(t, err)
|
||||
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
|
||||
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
err := proc.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
|
@ -604,7 +591,7 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
|||
|
||||
<-done
|
||||
|
||||
_ = cmd.Process.Kill()
|
||||
_ = proc.Kill()
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
|
@ -641,17 +628,10 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
pty := ptytest.New(t)
|
||||
|
||||
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
|
||||
cmd.Stdin = pty.Input()
|
||||
cmd.Stdout = pty.Output()
|
||||
cmd.Stderr = pty.Output()
|
||||
err = cmd.Start()
|
||||
require.NoError(t, err)
|
||||
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
|
||||
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
err := proc.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
|
@ -676,7 +656,7 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
|
|||
_ = conn.Close()
|
||||
<-done
|
||||
|
||||
_ = cmd.Process.Kill()
|
||||
_ = proc.Kill()
|
||||
}
|
||||
|
||||
func TestAgent_UnixRemoteForwarding(t *testing.T) {
|
||||
|
@ -713,17 +693,10 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
pty := ptytest.New(t)
|
||||
|
||||
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
|
||||
cmd.Stdin = pty.Input()
|
||||
cmd.Stdout = pty.Output()
|
||||
cmd.Stderr = pty.Output()
|
||||
err = cmd.Start()
|
||||
require.NoError(t, err)
|
||||
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
|
||||
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
err := proc.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
|
@ -753,7 +726,7 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
|
|||
|
||||
<-done
|
||||
|
||||
_ = cmd.Process.Kill()
|
||||
_ = proc.Kill()
|
||||
}
|
||||
|
||||
func TestAgent_SFTP(t *testing.T) {
|
||||
|
@ -1648,7 +1621,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
|||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
|
||||
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
|
@ -1690,7 +1663,8 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
|
|||
"host",
|
||||
)
|
||||
args = append(args, afterArgs...)
|
||||
return exec.Command("ssh", args...)
|
||||
cmd := exec.Command("ssh", args...)
|
||||
return ptytest.Start(t, cmd)
|
||||
}
|
||||
|
||||
func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session {
|
||||
|
|
|
@ -253,102 +253,12 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
|
|||
|
||||
sshPty, windowSize, isPty := session.Pty()
|
||||
if isPty {
|
||||
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
|
||||
// See https://github.com/coder/coder/issues/3371.
|
||||
session.DisablePTYEmulation()
|
||||
|
||||
if !isQuietLogin(session.RawCommand()) {
|
||||
manifest := s.Manifest.Load()
|
||||
if manifest != nil {
|
||||
err = showMOTD(session, manifest.MOTDFile)
|
||||
if err != nil {
|
||||
s.logger.Error(ctx, "show MOTD", slog.Error(err))
|
||||
}
|
||||
} else {
|
||||
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
|
||||
|
||||
// The pty package sets `SSH_TTY` on supported platforms.
|
||||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
|
||||
pty.WithSSHRequest(sshPty),
|
||||
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
|
||||
))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("start command: %w", err)
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
defer func() {
|
||||
defer wg.Wait()
|
||||
closeErr := ptty.Close()
|
||||
if closeErr != nil {
|
||||
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
|
||||
if retErr == nil {
|
||||
retErr = closeErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for win := range windowSize {
|
||||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||
// If the pty is closed, then command has exited, no need to log.
|
||||
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
|
||||
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
|
||||
}
|
||||
}
|
||||
}()
|
||||
// We don't add input copy to wait group because
|
||||
// it won't return until the session is closed.
|
||||
go func() {
|
||||
_, _ = io.Copy(ptty.Input(), session)
|
||||
}()
|
||||
|
||||
// In low parallelism scenarios, the command may exit and we may close
|
||||
// the pty before the output copy has started. This can result in the
|
||||
// output being lost. To avoid this, we wait for the output copy to
|
||||
// start before waiting for the command to exit. This ensures that the
|
||||
// output copy goroutine will be scheduled before calling close on the
|
||||
// pty. This shouldn't be needed because of `pty.Dup()` below, but it
|
||||
// may not be supported on all platforms.
|
||||
outputCopyStarted := make(chan struct{})
|
||||
ptyOutput := func() io.ReadCloser {
|
||||
defer close(outputCopyStarted)
|
||||
// Try to dup so we can separate stdin and stdout closure.
|
||||
// Once the original pty is closed, the dup will return
|
||||
// input/output error once the buffered data has been read.
|
||||
stdout, err := ptty.Dup()
|
||||
if err == nil {
|
||||
return stdout
|
||||
}
|
||||
// If we can't dup, we shouldn't close
|
||||
// the fd since it's tied to stdin.
|
||||
return readNopCloser{ptty.Output()}
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
// Ensure data is flushed to session on command exit, if we
|
||||
// close the session too soon, we might lose data.
|
||||
defer wg.Done()
|
||||
|
||||
stdout := ptyOutput()
|
||||
defer stdout.Close()
|
||||
|
||||
_, _ = io.Copy(session, stdout)
|
||||
}()
|
||||
<-outputCopyStarted
|
||||
|
||||
err = process.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
|
||||
// and not something to be concerned about. But, if it's something else, we should log it.
|
||||
if err != nil && !xerrors.As(err, &exitErr) {
|
||||
s.logger.Warn(ctx, "wait error", slog.Error(err))
|
||||
}
|
||||
return err
|
||||
return s.startPTYSession(session, cmd, sshPty, windowSize)
|
||||
}
|
||||
return startNonPTYSession(session, cmd)
|
||||
}
|
||||
|
||||
func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
|
||||
cmd.Stdout = session
|
||||
cmd.Stderr = session.Stderr()
|
||||
// This blocks forever until stdin is received if we don't
|
||||
|
@ -368,10 +278,94 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
|
|||
return cmd.Wait()
|
||||
}
|
||||
|
||||
type readNopCloser struct{ io.Reader }
|
||||
// ptySession is the interface to the ssh.Session that startPTYSession uses
|
||||
// we use an interface here so that we can fake it in tests.
|
||||
type ptySession interface {
|
||||
io.ReadWriter
|
||||
Context() ssh.Context
|
||||
DisablePTYEmulation()
|
||||
RawCommand() string
|
||||
}
|
||||
|
||||
// Close implements io.Closer.
|
||||
func (readNopCloser) Close() error { return nil }
|
||||
func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
|
||||
ctx := session.Context()
|
||||
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
|
||||
// See https://github.com/coder/coder/issues/3371.
|
||||
session.DisablePTYEmulation()
|
||||
|
||||
if !isQuietLogin(session.RawCommand()) {
|
||||
manifest := s.Manifest.Load()
|
||||
if manifest != nil {
|
||||
err := showMOTD(session, manifest.MOTDFile)
|
||||
if err != nil {
|
||||
s.logger.Error(ctx, "show MOTD", slog.Error(err))
|
||||
}
|
||||
} else {
|
||||
s.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
|
||||
|
||||
// The pty package sets `SSH_TTY` on supported platforms.
|
||||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
|
||||
pty.WithSSHRequest(sshPty),
|
||||
pty.WithLogger(slog.Stdlib(ctx, s.logger, slog.LevelInfo)),
|
||||
))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("start command: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := ptty.Close()
|
||||
if closeErr != nil {
|
||||
s.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
|
||||
if retErr == nil {
|
||||
retErr = closeErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for win := range windowSize {
|
||||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||
// If the pty is closed, then command has exited, no need to log.
|
||||
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
|
||||
s.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(ptty.InputWriter(), session)
|
||||
}()
|
||||
|
||||
// We need to wait for the command output to finish copying. It's safe to
|
||||
// just do this copy on the main handler goroutine because one of two things
|
||||
// will happen:
|
||||
//
|
||||
// 1. The command completes & closes the TTY, which then triggers an error
|
||||
// after we've Read() all the buffered data from the PTY.
|
||||
// 2. The client hangs up, which cancels the command's Context, and go will
|
||||
// kill the command's process. This then has the same effect as (1).
|
||||
n, err := io.Copy(session, ptty.OutputReader())
|
||||
s.logger.Debug(ctx, "copy output done", slog.F("bytes", n), slog.Error(err))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("copy error: %w", err)
|
||||
}
|
||||
// We've gotten all the output, but we need to wait for the process to
|
||||
// complete so that we can get the exit code. This returns
|
||||
// immediately if the TTY was closed as part of the command exiting.
|
||||
err = process.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
|
||||
// and not something to be concerned about. But, if it's something else, we should log it.
|
||||
if err != nil && !xerrors.As(err, &exitErr) {
|
||||
s.logger.Warn(ctx, "wait error", slog.Error(err))
|
||||
}
|
||||
if err != nil {
|
||||
return xerrors.Errorf("process wait: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sftpHandler(session ssh.Session) {
|
||||
ctx := session.Context()
|
||||
|
|
|
@ -0,0 +1,190 @@
|
|||
//go:build !windows
|
||||
|
||||
package agentssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
gliderssh "github.com/gliderlabs/ssh"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/testutil"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
)
|
||||
|
||||
const longScript = `
|
||||
echo "started"
|
||||
sleep 30
|
||||
echo "done"
|
||||
`
|
||||
|
||||
// Test_sessionStart_orphan tests running a command that takes a long time to
|
||||
// exit normally, and terminate the SSH session context early to verify that we
|
||||
// return quickly and don't leave the command running as an "orphan" with no
|
||||
// active SSH session.
|
||||
func Test_sessionStart_orphan(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
s, err := NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Here we're going to call the handler directly with a faked SSH session
|
||||
// that just uses io.Pipes instead of a network socket. There is a large
|
||||
// variation in the time between closing the socket from the client side and
|
||||
// the SSH server canceling the session Context, which would lead to a flaky
|
||||
// test if we did it that way. So instead, we directly cancel the context
|
||||
// in this test.
|
||||
sessionCtx, sessionCancel := context.WithCancel(ctx)
|
||||
toClient, fromClient, sess := newTestSession(sessionCtx)
|
||||
ptyInfo := gliderssh.Pty{}
|
||||
windowSize := make(chan gliderssh.Window)
|
||||
close(windowSize)
|
||||
// the command gets the session context so that Go will terminate it when
|
||||
// the session expires.
|
||||
cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
// we don't really care what the error is here. In the larger scenario,
|
||||
// the client has disconnected, so we can't return any error information
|
||||
// to them.
|
||||
_ = s.startPTYSession(sess, cmd, ptyInfo, windowSize)
|
||||
}()
|
||||
|
||||
readDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(readDone)
|
||||
s := bufio.NewScanner(toClient)
|
||||
assert.True(t, s.Scan())
|
||||
txt := s.Text()
|
||||
assert.Equal(t, "started", txt, "output corrupted")
|
||||
}()
|
||||
|
||||
waitForChan(ctx, t, readDone, "read timeout")
|
||||
// process is started, and should be sleeping for ~30 seconds
|
||||
|
||||
sessionCancel()
|
||||
|
||||
// now, we wait for the handler to complete. If it does so before the
|
||||
// main test timeout, we consider this a pass. If not, it indicates
|
||||
// that the server isn't properly shutting down sessions when they are
|
||||
// disconnected client side, which could lead to processes hanging around
|
||||
// indefinitely.
|
||||
waitForChan(ctx, t, done, "handler timeout")
|
||||
|
||||
err = fromClient.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func waitForChan(ctx context.Context, t *testing.T, c <-chan struct{}, msg string) {
|
||||
t.Helper()
|
||||
select {
|
||||
case <-c:
|
||||
// OK!
|
||||
case <-ctx.Done():
|
||||
t.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
type testSession struct {
|
||||
ctx testSSHContext
|
||||
|
||||
// c2p is the client -> pty buffer
|
||||
toPty *io.PipeReader
|
||||
// p2c is the pty -> client buffer
|
||||
fromPty *io.PipeWriter
|
||||
}
|
||||
|
||||
type testSSHContext struct {
|
||||
context.Context
|
||||
}
|
||||
|
||||
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
|
||||
toClient, fromPty := io.Pipe()
|
||||
toPty, fromClient := io.Pipe()
|
||||
|
||||
return toClient, fromClient, &testSession{
|
||||
ctx: testSSHContext{ctx},
|
||||
toPty: toPty,
|
||||
fromPty: fromPty,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testSession) Context() gliderssh.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (*testSession) DisablePTYEmulation() {}
|
||||
|
||||
// RawCommand returns "quiet logon" so that the PTY handler doesn't attempt to
|
||||
// write the message of the day, which will interfere with our tests. It writes
|
||||
// the message of the day if it's a shell login (zero length RawCommand()).
|
||||
func (*testSession) RawCommand() string { return "quiet logon" }
|
||||
|
||||
func (s *testSession) Read(p []byte) (n int, err error) {
|
||||
return s.toPty.Read(p)
|
||||
}
|
||||
|
||||
func (s *testSession) Write(p []byte) (n int, err error) {
|
||||
return s.fromPty.Write(p)
|
||||
}
|
||||
|
||||
func (testSSHContext) Lock() {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (testSSHContext) Unlock() {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// User returns the username used when establishing the SSH connection.
|
||||
func (testSSHContext) User() string {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// SessionID returns the session hash.
|
||||
func (testSSHContext) SessionID() string {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// ClientVersion returns the version reported by the client.
|
||||
func (testSSHContext) ClientVersion() string {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// ServerVersion returns the version reported by the server.
|
||||
func (testSSHContext) ServerVersion() string {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address for this connection.
|
||||
func (testSSHContext) RemoteAddr() net.Addr {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address for this connection.
|
||||
func (testSSHContext) LocalAddr() net.Addr {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// Permissions returns the Permissions object used for this connection.
|
||||
func (testSSHContext) Permissions() *gliderssh.Permissions {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// SetValue allows you to easily write new values into the underlying context.
|
||||
func (testSSHContext) SetValue(_, _ interface{}) {
|
||||
panic("not implemented")
|
||||
}
|
1
go.mod
1
go.mod
|
@ -108,6 +108,7 @@ require (
|
|||
github.com/hashicorp/terraform-config-inspect v0.0.0-20211115214459-90acf1ca460f
|
||||
github.com/hashicorp/terraform-json v0.14.0
|
||||
github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce
|
||||
github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02
|
||||
github.com/imulab/go-scim/pkg/v2 v2.2.0
|
||||
github.com/jedib0t/go-pretty/v6 v6.4.0
|
||||
github.com/jmoiron/sqlx v1.3.5
|
||||
|
|
3
go.sum
3
go.sum
|
@ -972,8 +972,9 @@ github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce h1:7FO+LmZwiG/eDsB
|
|||
github.com/hashicorp/yamux v0.0.0-20220718163420-dd80a7ee44ce/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
|
||||
github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 h1:aSVUgRRRtOrZOC1fYmY9gV0e9z/Iu+xNVSASWjsuyGU=
|
||||
github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3/go.mod h1:5PC6ZNPde8bBqU/ewGZig35+UIZtw9Ytxez8/q5ZyFE=
|
||||
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
|
||||
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
|
||||
github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 h1:AgcIVYPa6XJnU3phs104wLj8l5GEththEw6+F79YsIY=
|
||||
github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/huandu/xstrings v1.3.1/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
github.com/huandu/xstrings v1.3.2/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||
|
|
38
pty/pty.go
38
pty/pty.go
|
@ -3,7 +3,6 @@ package pty
|
|||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -12,10 +11,33 @@ import (
|
|||
// ErrClosed is returned when a PTY is used after it has been closed.
|
||||
var ErrClosed = xerrors.New("pty: closed")
|
||||
|
||||
// PTY is a minimal interface for interacting with a TTY.
|
||||
// PTYCmd is an interface for interacting with a pseudo-TTY where we control
|
||||
// only one end, and the other end has been passed to a running os.Process.
|
||||
// nolint:revive
|
||||
type PTYCmd interface {
|
||||
io.Closer
|
||||
|
||||
// Resize sets the size of the PTY.
|
||||
Resize(height uint16, width uint16) error
|
||||
|
||||
// OutputReader returns an io.Reader for reading the output from the process
|
||||
// controlled by the pseudo-TTY
|
||||
OutputReader() io.Reader
|
||||
|
||||
// InputWriter returns an io.Writer for writing into to the process
|
||||
// controlled by the pseudo-TTY
|
||||
InputWriter() io.Writer
|
||||
}
|
||||
|
||||
// PTY is a minimal interface for interacting with pseudo-TTY where this
|
||||
// process retains access to _both_ ends of the pseudo-TTY (i.e. `ptm` & `pts`
|
||||
// on Linux).
|
||||
type PTY interface {
|
||||
io.Closer
|
||||
|
||||
// Resize sets the size of the PTY.
|
||||
Resize(height uint16, width uint16) error
|
||||
|
||||
// Name of the TTY. Example on Linux would be "/dev/pts/1".
|
||||
Name() string
|
||||
|
||||
|
@ -34,14 +56,6 @@ type PTY interface {
|
|||
//
|
||||
// The same stream would be used to provide user input: pty.Input().Write(...)
|
||||
Input() ReadWriter
|
||||
|
||||
// Dup returns a new file descriptor for the PTY.
|
||||
//
|
||||
// This is useful for closing stdin and stdout separately.
|
||||
Dup() (*os.File, error)
|
||||
|
||||
// Resize sets the size of the PTY.
|
||||
Resize(height uint16, width uint16) error
|
||||
}
|
||||
|
||||
// Process represents a process running in a PTY. We need to trigger special processing on the PTY
|
||||
|
@ -108,8 +122,8 @@ func New(opts ...Option) (PTY, error) {
|
|||
// underlying file descriptors, one for reading and one for writing, and allows
|
||||
// them to be accessed separately.
|
||||
type ReadWriter struct {
|
||||
Reader *os.File
|
||||
Writer *os.File
|
||||
Reader io.Reader
|
||||
Writer io.Writer
|
||||
}
|
||||
|
||||
func (rw ReadWriter) Read(p []byte) (int, error) {
|
||||
|
|
|
@ -3,15 +3,17 @@
|
|||
package pty
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/u-root/u-root/pkg/termios"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func newPty(opt ...Option) (retPTY *otherPty, err error) {
|
||||
|
@ -28,6 +30,7 @@ func newPty(opt ...Option) (retPTY *otherPty, err error) {
|
|||
pty: ptyFile,
|
||||
tty: ttyFile,
|
||||
opts: opts,
|
||||
name: ttyFile.Name(),
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
|
@ -53,6 +56,7 @@ type otherPty struct {
|
|||
err error
|
||||
pty, tty *os.File
|
||||
opts ptyOptions
|
||||
name string
|
||||
}
|
||||
|
||||
func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) {
|
||||
|
@ -85,7 +89,7 @@ func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error)
|
|||
}
|
||||
|
||||
func (p *otherPty) Name() string {
|
||||
return p.tty.Name()
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *otherPty) Input() ReadWriter {
|
||||
|
@ -95,13 +99,21 @@ func (p *otherPty) Input() ReadWriter {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *otherPty) InputWriter() io.Writer {
|
||||
return p.pty
|
||||
}
|
||||
|
||||
func (p *otherPty) Output() ReadWriter {
|
||||
return ReadWriter{
|
||||
Reader: p.pty,
|
||||
Reader: &ptmReader{p.pty},
|
||||
Writer: p.tty,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *otherPty) OutputReader() io.Reader {
|
||||
return &ptmReader{p.pty}
|
||||
}
|
||||
|
||||
func (p *otherPty) Resize(height uint16, width uint16) error {
|
||||
return p.control(p.pty, func(fd uintptr) error {
|
||||
return termios.SetWinSize(fd, &termios.Winsize{
|
||||
|
@ -113,20 +125,6 @@ func (p *otherPty) Resize(height uint16, width uint16) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (p *otherPty) Dup() (*os.File, error) {
|
||||
var newfd int
|
||||
err := p.control(p.pty, func(fd uintptr) error {
|
||||
var err error
|
||||
newfd, err = syscall.Dup(int(fd))
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(newfd), p.pty.Name()), nil
|
||||
}
|
||||
|
||||
func (p *otherPty) Close() error {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
|
@ -137,9 +135,12 @@ func (p *otherPty) Close() error {
|
|||
p.closed = true
|
||||
|
||||
err := p.pty.Close()
|
||||
err2 := p.tty.Close()
|
||||
if err == nil {
|
||||
err = err2
|
||||
// tty is closed & unset if we Start() a new process
|
||||
if p.tty != nil {
|
||||
err2 := p.tty.Close()
|
||||
if err == nil {
|
||||
err = err2
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -177,3 +178,21 @@ func (p *otherProcess) waitInternal() {
|
|||
runtime.KeepAlive(p.pty)
|
||||
close(p.cmdDone)
|
||||
}
|
||||
|
||||
// ptmReader wraps a reference to the ptm side of a pseudo-TTY for portability
|
||||
type ptmReader struct {
|
||||
ptm io.Reader
|
||||
}
|
||||
|
||||
func (r *ptmReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.ptm.Read(p)
|
||||
// output from the ptm will hit a PathErr when the process hangs up the
|
||||
// other side (typically when the process exits, but could be earlier). For
|
||||
// portability, and to fit with our use of io.Copy() to copy from the PTY,
|
||||
// we want to translate this error into io.EOF
|
||||
pathErr := &fs.PathError{}
|
||||
if xerrors.As(err, &pathErr) {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package pty
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
@ -21,7 +22,7 @@ var (
|
|||
)
|
||||
|
||||
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session
|
||||
func newPty(opt ...Option) (PTY, error) {
|
||||
func newPty(opt ...Option) (*ptyWindows, error) {
|
||||
var opts ptyOptions
|
||||
for _, o := range opt {
|
||||
o(&opts)
|
||||
|
@ -88,6 +89,7 @@ type windowsProcess struct {
|
|||
cmdDone chan any
|
||||
cmdErr error
|
||||
proc *os.Process
|
||||
pw *ptyWindows
|
||||
}
|
||||
|
||||
// Name returns the TTY name on Windows.
|
||||
|
@ -104,6 +106,10 @@ func (p *ptyWindows) Output() ReadWriter {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *ptyWindows) OutputReader() io.Reader {
|
||||
return p.outputRead
|
||||
}
|
||||
|
||||
func (p *ptyWindows) Input() ReadWriter {
|
||||
return ReadWriter{
|
||||
Reader: p.inputRead,
|
||||
|
@ -111,7 +117,17 @@ func (p *ptyWindows) Input() ReadWriter {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *ptyWindows) InputWriter() io.Writer {
|
||||
return p.inputWrite
|
||||
}
|
||||
|
||||
func (p *ptyWindows) Resize(height uint16, width uint16) error {
|
||||
// hold the lock, so we don't race with anyone trying to close the console
|
||||
p.closeMutex.Lock()
|
||||
defer p.closeMutex.Unlock()
|
||||
if p.closed || p.console == windows.InvalidHandle {
|
||||
return ErrClosed
|
||||
}
|
||||
// Taken from: https://github.com/microsoft/hcsshim/blob/54a5ad86808d761e3e396aff3e2022840f39f9a8/internal/winapi/zsyscall_windows.go#L144
|
||||
ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(*((*uint32)(unsafe.Pointer(&windows.Coord{
|
||||
Y: int16(height),
|
||||
|
@ -123,10 +139,6 @@ func (p *ptyWindows) Resize(height uint16, width uint16) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *ptyWindows) Dup() (*os.File, error) {
|
||||
return nil, xerrors.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (p *ptyWindows) Close() error {
|
||||
p.closeMutex.Lock()
|
||||
defer p.closeMutex.Unlock()
|
||||
|
@ -135,20 +147,54 @@ func (p *ptyWindows) Close() error {
|
|||
}
|
||||
p.closed = true
|
||||
|
||||
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
|
||||
if ret < 0 {
|
||||
return xerrors.Errorf("close pseudo console: %w", err)
|
||||
// if we are running a command in the PTY, the corresponding *windowsProcess
|
||||
// may have already closed the PseudoConsole when the command exited, so that
|
||||
// output reads can get to EOF. In that case, we don't need to close it
|
||||
// again here.
|
||||
if p.console != windows.InvalidHandle {
|
||||
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
|
||||
if ret < 0 {
|
||||
return xerrors.Errorf("close pseudo console: %w", err)
|
||||
}
|
||||
p.console = windows.InvalidHandle
|
||||
}
|
||||
|
||||
_ = p.outputWrite.Close()
|
||||
// We always have these files
|
||||
_ = p.outputRead.Close()
|
||||
_ = p.inputWrite.Close()
|
||||
_ = p.inputRead.Close()
|
||||
// These get closed & unset if we Start() a new process.
|
||||
if p.outputWrite != nil {
|
||||
_ = p.outputWrite.Close()
|
||||
}
|
||||
if p.inputRead != nil {
|
||||
_ = p.inputRead.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *windowsProcess) waitInternal() {
|
||||
// put this on the bottom of the defer stack since the next defer can write to p.cmdErr
|
||||
defer close(p.cmdDone)
|
||||
defer func() {
|
||||
// close the pseudoconsole handle when the process exits, if it hasn't already been closed.
|
||||
// this is important because the PseudoConsole (conhost.exe) holds the write-end
|
||||
// of the output pipe. If it is not closed, reads on that pipe will block, even though
|
||||
// the command has exited.
|
||||
// c.f. https://devblogs.microsoft.com/commandline/windows-command-line-introducing-the-windows-pseudo-console-conpty/
|
||||
p.pw.closeMutex.Lock()
|
||||
defer p.pw.closeMutex.Unlock()
|
||||
if p.pw.console != windows.InvalidHandle {
|
||||
ret, _, err := procClosePseudoConsole.Call(uintptr(p.pw.console))
|
||||
if ret < 0 && p.cmdErr == nil {
|
||||
// if we already have an error from the command, prefer that error
|
||||
// but if the command succeeded and closing the PseudoConsole fails
|
||||
// then record that error so that we have a chance to see it
|
||||
p.cmdErr = err
|
||||
}
|
||||
p.pw.console = windows.InvalidHandle
|
||||
}
|
||||
}()
|
||||
|
||||
state, err := p.proc.Wait()
|
||||
if err != nil {
|
||||
p.cmdErr = err
|
||||
|
|
|
@ -30,12 +30,21 @@ func New(t *testing.T, opts ...pty.Option) *PTY {
|
|||
ptty, err := pty.New(opts...)
|
||||
require.NoError(t, err)
|
||||
|
||||
return create(t, ptty, "cmd")
|
||||
e := newExpecter(t, ptty.Output(), "cmd")
|
||||
r := &PTY{
|
||||
outExpecter: e,
|
||||
PTY: ptty,
|
||||
}
|
||||
// Ensure pty is cleaned up at the end of test.
|
||||
t.Cleanup(func() {
|
||||
_ = r.Close()
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
// Start starts a new process asynchronously and returns a PTY and Process.
|
||||
// It kills the process upon cleanup.
|
||||
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Process) {
|
||||
// Start starts a new process asynchronously and returns a PTYCmd and Process.
|
||||
// It kills the process and PTYCmd upon cleanup
|
||||
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
|
||||
t.Helper()
|
||||
|
||||
ptty, ps, err := pty.Start(cmd, opts...)
|
||||
|
@ -44,10 +53,19 @@ func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Proc
|
|||
_ = ps.Kill()
|
||||
_ = ps.Wait()
|
||||
})
|
||||
return create(t, ptty, cmd.Args[0]), ps
|
||||
ex := newExpecter(t, ptty.OutputReader(), cmd.Args[0])
|
||||
|
||||
r := &PTYCmd{
|
||||
outExpecter: ex,
|
||||
PTYCmd: ptty,
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = r.Close()
|
||||
})
|
||||
return r, ps
|
||||
}
|
||||
|
||||
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
|
||||
func newExpecter(t *testing.T, r io.Reader, name string) outExpecter {
|
||||
// Use pipe for logging.
|
||||
logDone := make(chan struct{})
|
||||
logr, logw := io.Pipe()
|
||||
|
@ -57,37 +75,30 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY {
|
|||
out := newStdbuf()
|
||||
w := io.MultiWriter(logw, out)
|
||||
|
||||
tpty := &PTY{
|
||||
ex := outExpecter{
|
||||
t: t,
|
||||
PTY: ptty,
|
||||
out: out,
|
||||
name: name,
|
||||
|
||||
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
|
||||
}
|
||||
// Ensure pty is cleaned up at the end of test.
|
||||
t.Cleanup(func() {
|
||||
_ = tpty.Close()
|
||||
})
|
||||
|
||||
logClose := func(name string, c io.Closer) {
|
||||
tpty.logf("closing %s", name)
|
||||
ex.logf("closing %s", name)
|
||||
err := c.Close()
|
||||
tpty.logf("closed %s: %v", name, err)
|
||||
ex.logf("closed %s: %v", name, err)
|
||||
}
|
||||
// Set the actual close function for the tpty.
|
||||
tpty.close = func(reason string) error {
|
||||
// Set the actual close function for the outExpecter.
|
||||
ex.close = func(reason string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
tpty.logf("closing tpty: %s", reason)
|
||||
ex.logf("closing expecter: %s", reason)
|
||||
|
||||
// Close pty only so that the copy goroutine can consume the
|
||||
// remainder of it's buffer and then exit.
|
||||
logClose("pty", ptty)
|
||||
// Caller needs to have closed the PTY so that copying can complete
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
tpty.fatalf("close", "copy did not close in time")
|
||||
ex.fatalf("close", "copy did not close in time")
|
||||
case <-copyDone:
|
||||
}
|
||||
|
||||
|
@ -95,22 +106,22 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY {
|
|||
logClose("logr", logr)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
tpty.fatalf("close", "log pipe did not close in time")
|
||||
ex.fatalf("close", "log pipe did not close in time")
|
||||
case <-logDone:
|
||||
}
|
||||
|
||||
tpty.logf("closed tpty")
|
||||
ex.logf("closed expecter")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(copyDone)
|
||||
_, err := io.Copy(w, ptty.Output())
|
||||
tpty.logf("copy done: %v", err)
|
||||
tpty.logf("closing out")
|
||||
_, err := io.Copy(w, r)
|
||||
ex.logf("copy done: %v", err)
|
||||
ex.logf("closing out")
|
||||
err = out.closeErr(err)
|
||||
tpty.logf("closed out: %v", err)
|
||||
ex.logf("closed out: %v", err)
|
||||
}()
|
||||
|
||||
// Log all output as part of test for easier debugging on errors.
|
||||
|
@ -118,15 +129,14 @@ func create(t *testing.T, ptty pty.PTY, name string) *PTY {
|
|||
defer close(logDone)
|
||||
s := bufio.NewScanner(logr)
|
||||
for s.Scan() {
|
||||
tpty.logf("%q", stripansi.Strip(s.Text()))
|
||||
ex.logf("%q", stripansi.Strip(s.Text()))
|
||||
}
|
||||
}()
|
||||
|
||||
return tpty
|
||||
return ex
|
||||
}
|
||||
|
||||
type PTY struct {
|
||||
pty.PTY
|
||||
type outExpecter struct {
|
||||
t *testing.T
|
||||
close func(reason string) error
|
||||
out *stdbuf
|
||||
|
@ -135,38 +145,23 @@ type PTY struct {
|
|||
runeReader *bufio.Reader
|
||||
}
|
||||
|
||||
func (p *PTY) Close() error {
|
||||
p.t.Helper()
|
||||
|
||||
return p.close("close")
|
||||
}
|
||||
|
||||
func (p *PTY) Attach(inv *clibase.Invocation) *PTY {
|
||||
p.t.Helper()
|
||||
|
||||
inv.Stdout = p.Output()
|
||||
inv.Stderr = p.Output()
|
||||
inv.Stdin = p.Input()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PTY) ExpectMatch(str string) string {
|
||||
p.t.Helper()
|
||||
func (e *outExpecter) ExpectMatch(str string) string {
|
||||
e.t.Helper()
|
||||
|
||||
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
return p.ExpectMatchContext(timeout, str)
|
||||
return e.ExpectMatchContext(timeout, str)
|
||||
}
|
||||
|
||||
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
|
||||
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
|
||||
p.t.Helper()
|
||||
func (e *outExpecter) ExpectMatchContext(ctx context.Context, str string) string {
|
||||
e.t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
|
||||
err := e.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
|
||||
for {
|
||||
r, _, err := p.runeReader.ReadRune()
|
||||
r, _, err := e.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -180,54 +175,54 @@ func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
|
|||
}
|
||||
})
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
|
||||
e.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
|
||||
return ""
|
||||
}
|
||||
p.logf("matched %q = %q", str, stripansi.Strip(buffer.String()))
|
||||
e.logf("matched %q = %q", str, stripansi.Strip(buffer.String()))
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (p *PTY) Peek(ctx context.Context, n int) []byte {
|
||||
p.t.Helper()
|
||||
func (e *outExpecter) Peek(ctx context.Context, n int) []byte {
|
||||
e.t.Helper()
|
||||
|
||||
var out []byte
|
||||
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
|
||||
err := e.doMatchWithDeadline(ctx, "Peek", func() error {
|
||||
var err error
|
||||
out, err = p.runeReader.Peek(n)
|
||||
out, err = e.runeReader.Peek(n)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
|
||||
e.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
|
||||
return nil
|
||||
}
|
||||
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
|
||||
e.logf("peeked %d/%d bytes = %q", len(out), n, out)
|
||||
return slices.Clone(out)
|
||||
}
|
||||
|
||||
func (p *PTY) ReadRune(ctx context.Context) rune {
|
||||
p.t.Helper()
|
||||
func (e *outExpecter) ReadRune(ctx context.Context) rune {
|
||||
e.t.Helper()
|
||||
|
||||
var r rune
|
||||
err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
|
||||
err := e.doMatchWithDeadline(ctx, "ReadRune", func() error {
|
||||
var err error
|
||||
r, _, err = p.runeReader.ReadRune()
|
||||
r, _, err = e.runeReader.ReadRune()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
|
||||
e.fatalf("read error", "%v (wanted rune; got %q)", err, r)
|
||||
return 0
|
||||
}
|
||||
p.logf("matched rune = %q", r)
|
||||
e.logf("matched rune = %q", r)
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *PTY) ReadLine(ctx context.Context) string {
|
||||
p.t.Helper()
|
||||
func (e *outExpecter) ReadLine(ctx context.Context) string {
|
||||
e.t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
|
||||
err := e.doMatchWithDeadline(ctx, "ReadLine", func() error {
|
||||
for {
|
||||
r, _, err := p.runeReader.ReadRune()
|
||||
r, _, err := e.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -240,14 +235,14 @@ func (p *PTY) ReadLine(ctx context.Context) string {
|
|||
|
||||
// Unicode code points can be up to 4 bytes, but the
|
||||
// ones we're looking for are only 1 byte.
|
||||
b, _ := p.runeReader.Peek(1)
|
||||
b, _ := e.runeReader.Peek(1)
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r, _ = utf8.DecodeRune(b)
|
||||
if r == '\n' {
|
||||
_, _, err = p.runeReader.ReadRune()
|
||||
_, _, err = e.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -263,21 +258,21 @@ func (p *PTY) ReadLine(ctx context.Context) string {
|
|||
}
|
||||
})
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
|
||||
e.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
|
||||
return ""
|
||||
}
|
||||
p.logf("matched newline = %q", buffer.String())
|
||||
e.logf("matched newline = %q", buffer.String())
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
|
||||
p.t.Helper()
|
||||
func (e *outExpecter) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
|
||||
e.t.Helper()
|
||||
|
||||
// A timeout is mandatory, caller can decide by passing a context
|
||||
// that times out.
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
timeout := testutil.WaitMedium
|
||||
p.logf("%s ctx has no deadline, using %s", name, timeout)
|
||||
e.logf("%s ctx has no deadline, using %s", name, timeout)
|
||||
var cancel context.CancelFunc
|
||||
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
|
@ -294,13 +289,55 @@ func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() er
|
|||
return err
|
||||
case <-ctx.Done():
|
||||
// Ensure goroutine is cleaned up before test exit.
|
||||
_ = p.close("match deadline exceeded")
|
||||
_ = e.close("match deadline exceeded")
|
||||
<-match
|
||||
|
||||
return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (e *outExpecter) logf(format string, args ...interface{}) {
|
||||
e.t.Helper()
|
||||
|
||||
// Match regular logger timestamp format, we seem to be logging in
|
||||
// UTC in other places as well, so match here.
|
||||
e.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), e.name, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (e *outExpecter) fatalf(reason string, format string, args ...interface{}) {
|
||||
e.t.Helper()
|
||||
|
||||
// Ensure the message is part of the normal log stream before
|
||||
// failing the test.
|
||||
e.logf("%s: %s", reason, fmt.Sprintf(format, args...))
|
||||
|
||||
require.FailNowf(e.t, reason, format, args...)
|
||||
}
|
||||
|
||||
type PTY struct {
|
||||
outExpecter
|
||||
pty.PTY
|
||||
}
|
||||
|
||||
func (p *PTY) Close() error {
|
||||
p.t.Helper()
|
||||
pErr := p.PTY.Close()
|
||||
eErr := p.outExpecter.close("close")
|
||||
if pErr != nil {
|
||||
return pErr
|
||||
}
|
||||
return eErr
|
||||
}
|
||||
|
||||
func (p *PTY) Attach(inv *clibase.Invocation) *PTY {
|
||||
p.t.Helper()
|
||||
|
||||
inv.Stdout = p.Output()
|
||||
inv.Stderr = p.Output()
|
||||
inv.Stdin = p.Input()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PTY) Write(r rune) {
|
||||
p.t.Helper()
|
||||
|
||||
|
@ -321,22 +358,19 @@ func (p *PTY) WriteLine(str string) {
|
|||
require.NoError(p.t, err, "write line failed")
|
||||
}
|
||||
|
||||
func (p *PTY) logf(format string, args ...interface{}) {
|
||||
p.t.Helper()
|
||||
|
||||
// Match regular logger timestamp format, we seem to be logging in
|
||||
// UTC in other places as well, so match here.
|
||||
p.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), p.name, fmt.Sprintf(format, args...))
|
||||
type PTYCmd struct {
|
||||
outExpecter
|
||||
pty.PTYCmd
|
||||
}
|
||||
|
||||
func (p *PTY) fatalf(reason string, format string, args ...interface{}) {
|
||||
func (p *PTYCmd) Close() error {
|
||||
p.t.Helper()
|
||||
|
||||
// Ensure the message is part of the normal log stream before
|
||||
// failing the test.
|
||||
p.logf("%s: %s", reason, fmt.Sprintf(format, args...))
|
||||
|
||||
require.FailNowf(p.t, reason, format, args...)
|
||||
pErr := p.PTYCmd.Close()
|
||||
eErr := p.outExpecter.close("close")
|
||||
if pErr != nil {
|
||||
return pErr
|
||||
}
|
||||
return eErr
|
||||
}
|
||||
|
||||
// stdbuf is like a buffered stdout, it buffers writes until read.
|
||||
|
|
|
@ -20,6 +20,6 @@ func WithPTYOption(opts ...Option) StartOption {
|
|||
|
||||
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
|
||||
// instead rely on the returned Process to manage the command/process.
|
||||
func Start(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
|
||||
func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) {
|
||||
return startPty(cmd, opt...)
|
||||
}
|
||||
|
|
|
@ -50,6 +50,17 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
|
|||
}
|
||||
return nil, nil, xerrors.Errorf("start: %w", err)
|
||||
}
|
||||
// Now that we've started the command, and passed the TTY to it, close our
|
||||
// file so that the other process has the only open file to the TTY. Once
|
||||
// the process closes the TTY (usually on exit), there will be no open
|
||||
// references and the OS kernel returns an error when trying to read or
|
||||
// write to our PTY end. Without this, reading from the process output
|
||||
// will block until we close our TTY.
|
||||
if err := opty.tty.Close(); err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
return nil, nil, xerrors.Errorf("close tty: %w", err)
|
||||
}
|
||||
opty.tty = nil // remove so we don't attempt to close it again.
|
||||
oProcess := &otherProcess{
|
||||
pty: opty.pty,
|
||||
cmd: cmd,
|
||||
|
|
|
@ -25,20 +25,25 @@ func TestStart(t *testing.T) {
|
|||
t.Run("Echo", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
|
||||
|
||||
pty.ExpectMatch("test")
|
||||
err := ps.Wait()
|
||||
require.NoError(t, err)
|
||||
err = pty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Kill", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ps := ptytest.Start(t, exec.Command("sleep", "30"))
|
||||
pty, ps := ptytest.Start(t, exec.Command("sleep", "30"))
|
||||
err := ps.Kill()
|
||||
assert.NoError(t, err)
|
||||
err = ps.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, xerrors.As(err, &exitErr))
|
||||
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||
err = pty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SSH_TTY", func(t *testing.T) {
|
||||
|
@ -53,5 +58,29 @@ func TestStart(t *testing.T) {
|
|||
pty.ExpectMatch("SSH_TTY=/dev/")
|
||||
err := ps.Wait()
|
||||
require.NoError(t, err)
|
||||
err = pty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// these constants/vars are used by Test_Start_copy
|
||||
|
||||
const cmdEcho = "echo"
|
||||
|
||||
var argEcho = []string{"test"}
|
||||
|
||||
// these constants/vars are used by Test_Start_truncate
|
||||
|
||||
const (
|
||||
countEnd = 1000
|
||||
cmdCount = "sh"
|
||||
)
|
||||
|
||||
var argCount = []string{"-c", `
|
||||
i=0
|
||||
while [ $i -ne 1000 ]
|
||||
do
|
||||
i=$(($i+1))
|
||||
echo "$i"
|
||||
done
|
||||
`}
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
package pty_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hinshun/vt10x"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/pty"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
// Test_Start_copy tests that we can use io.Copy() on command output
|
||||
// without deadlocking.
|
||||
func Test_Start_copy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...))
|
||||
require.NoError(t, err)
|
||||
b := &bytes.Buffer{}
|
||||
readDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := io.Copy(b, pc.OutputReader())
|
||||
readDone <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-readDone:
|
||||
require.NoError(t, err)
|
||||
case <-ctx.Done():
|
||||
t.Error("read timed out")
|
||||
}
|
||||
assert.Contains(t, b.String(), "test")
|
||||
|
||||
cmdDone := make(chan error, 1)
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-cmdDone:
|
||||
require.NoError(t, err)
|
||||
case <-ctx.Done():
|
||||
t.Error("cmd.Wait() timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// Test_Start_truncation tests that we can read command output without truncation
|
||||
// even after the command has exited.
|
||||
func Test_Start_truncation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
|
||||
pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...))
|
||||
|
||||
require.NoError(t, err)
|
||||
readDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(readDone)
|
||||
// avoid buffered IO so that we can precisely control how many bytes to read.
|
||||
n := 1
|
||||
for n <= countEnd {
|
||||
want := fmt.Sprintf("%d", n)
|
||||
err := readUntil(ctx, t, want, pc.OutputReader())
|
||||
assert.NoError(t, err, "want: %s", want)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n++
|
||||
if (countEnd - n) < 100 {
|
||||
// If the OS buffers the output, the process can exit even if
|
||||
// we're not done reading. We want to slow our reads so that
|
||||
// if there is a race between reading the data and it being
|
||||
// truncated, we will lose and fail the test.
|
||||
time.Sleep(testutil.IntervalFast)
|
||||
}
|
||||
}
|
||||
// ensure we still get to EOF
|
||||
endB := &bytes.Buffer{}
|
||||
_, err := io.Copy(endB, pc.OutputReader())
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
cmdDone := make(chan error, 1)
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-cmdDone:
|
||||
require.NoError(t, err)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("cmd.Wait() timed out")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-readDone:
|
||||
// OK!
|
||||
case <-ctx.Done():
|
||||
t.Fatal("read timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// readUntil reads one byte at a time until we either see the string we want, or the context expires
|
||||
func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error {
|
||||
// output can contain virtual terminal sequences, so we need to parse these
|
||||
// to correctly interpret getting what we want.
|
||||
term := vt10x.New(vt10x.WithSize(80, 80))
|
||||
readErrs := make(chan error, 1)
|
||||
for {
|
||||
b := make([]byte, 1)
|
||||
go func() {
|
||||
_, err := r.Read(b)
|
||||
readErrs <- err
|
||||
}()
|
||||
select {
|
||||
case err := <-readErrs:
|
||||
if err != nil {
|
||||
t.Logf("err: %v\ngot: %v", err, term)
|
||||
return err
|
||||
}
|
||||
term.Write(b)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
got := term.String()
|
||||
lines := strings.Split(got, "\n")
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == want {
|
||||
t.Logf("want: %v\n got:%v", want, line)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
// Allocates a PTY and starts the specified command attached to it.
|
||||
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
|
||||
func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
|
||||
func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) {
|
||||
var opts startOptions
|
||||
for _, o := range opt {
|
||||
o(&opts)
|
||||
|
@ -45,11 +45,18 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pty, err := newPty(opts.ptyOpts...)
|
||||
|
||||
winPty, err := newPty(opts.ptyOpts...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
winPty := pty.(*ptyWindows)
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
// we hit some error finishing setup; close pty, so
|
||||
// we don't leak the kernel resources associated with it
|
||||
_ = winPty.Close()
|
||||
}
|
||||
}()
|
||||
if winPty.opts.sshReq != nil {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", winPty.Name()))
|
||||
}
|
||||
|
@ -95,9 +102,34 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
|
|||
wp := &windowsProcess{
|
||||
cmdDone: make(chan any),
|
||||
proc: process,
|
||||
pw: winPty,
|
||||
}
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
// if we later error out, kill the process since
|
||||
// the caller will have no way to interact with it
|
||||
_ = process.Kill()
|
||||
}
|
||||
}()
|
||||
|
||||
// Now that we've started the command, and passed the pseudoconsole to it,
|
||||
// close the output write and input read files, so that the other process
|
||||
// has the only handles to them. Once the process closes the console, there
|
||||
// will be no open references and the OS kernel returns an error when trying
|
||||
// to read or write to our end. Without this, reading from the process
|
||||
// output will block until they are closed.
|
||||
errO := winPty.outputWrite.Close()
|
||||
winPty.outputWrite = nil
|
||||
errI := winPty.inputRead.Close()
|
||||
winPty.inputRead = nil
|
||||
if errO != nil {
|
||||
return nil, nil, errO
|
||||
}
|
||||
if errI != nil {
|
||||
return nil, nil, errI
|
||||
}
|
||||
go wp.waitInternal()
|
||||
return pty, wp, nil
|
||||
return winPty, wp, nil
|
||||
}
|
||||
|
||||
// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
package pty_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
|
@ -22,25 +23,46 @@ func TestStart(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Echo", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
|
||||
pty.ExpectMatch("test")
|
||||
ptty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
|
||||
ptty.ExpectMatch("test")
|
||||
err := ps.Wait()
|
||||
require.NoError(t, err)
|
||||
err = ptty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("Resize", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pty, _ := ptytest.Start(t, exec.Command("cmd.exe"))
|
||||
err := pty.Resize(100, 50)
|
||||
ptty, _ := ptytest.Start(t, exec.Command("cmd.exe"))
|
||||
err := ptty.Resize(100, 50)
|
||||
require.NoError(t, err)
|
||||
err = ptty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("Kill", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ps := ptytest.Start(t, exec.Command("cmd.exe"))
|
||||
ptty, ps := ptytest.Start(t, exec.Command("cmd.exe"))
|
||||
err := ps.Kill()
|
||||
assert.NoError(t, err)
|
||||
err = ps.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, xerrors.As(err, &exitErr))
|
||||
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||
err = ptty.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// these constants/vars are used by Test_Start_copy
|
||||
|
||||
const cmdEcho = "cmd.exe"
|
||||
|
||||
var argEcho = []string{"/c", "echo", "test"}
|
||||
|
||||
// these constants/vars are used by Test_Start_truncate
|
||||
|
||||
const (
|
||||
countEnd = 1000
|
||||
cmdCount = "cmd.exe"
|
||||
)
|
||||
|
||||
var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)}
|
||||
|
|
Loading…
Reference in New Issue