fix(agent): Close stdin and stdout separately to fix pty output loss (#6862)

Fixes #6656
Closes #6840
This commit is contained in:
Mathias Fredriksson 2023-03-29 21:58:38 +03:00 committed by GitHub
parent 349bfad2e9
commit 90d18dd2e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 61 additions and 29 deletions

View File

@ -1115,7 +1115,8 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
go func() {
for win := range windowSize {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
if resizeErr != nil {
// If the pty is closed, then command has exited, no need to log.
if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) {
a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
}
}
@ -1131,19 +1132,32 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
// output being lost. To avoid this, we wait for the output copy to
// start before waiting for the command to exit. This ensures that the
// output copy goroutine will be scheduled before calling close on the
// pty. There is still a risk of data loss if a command produces a lot
// of output, see TestAgent_Session_TTY_HugeOutputIsNotLost (skipped).
// pty. This shouldn't be needed because of `pty.Dup()` below, but it
// may not be supported on all platforms.
outputCopyStarted := make(chan struct{})
ptyOutput := func() io.Reader {
ptyOutput := func() io.ReadCloser {
defer close(outputCopyStarted)
return ptty.Output()
// Try to dup so we can separate stdin and stdout closure.
// Once the original pty is closed, the dup will return
// input/output error once the buffered data has been read.
stdout, err := ptty.Dup()
if err == nil {
return stdout
}
// If we can't dup, we shouldn't close
// the fd since it's tied to stdin.
return readNopCloser{ptty.Output()}
}
wg.Add(1)
go func() {
// Ensure data is flushed to session on command exit, if we
// close the session too soon, we might lose data.
defer wg.Done()
_, _ = io.Copy(session, ptyOutput())
stdout := ptyOutput()
defer stdout.Close()
_, _ = io.Copy(session, stdout)
}()
<-outputCopyStarted
@ -1176,6 +1190,11 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
return cmd.Wait()
}
type readNopCloser struct{ io.Reader }
// Close implements io.Closer.
func (readNopCloser) Close() error { return nil }
func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) {
defer conn.Close()

View File

@ -350,15 +350,8 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This might be our implementation, or ConPTY itself.
// It's difficult to find extensive tests for it, so
// it seems like it could be either.
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
// This test is here to prevent regressions where quickly executing
// commands (with TTY) don't flush their output to the SSH session.
// commands (with TTY) don't sync their output to the SSH session.
//
// See: https://github.com/coder/coder/issues/6656
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@ -404,20 +397,13 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This might be our implementation, or ConPTY itself.
// It's difficult to find extensive tests for it, so
// it seems like it could be either.
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
t.Skip("This test proves we have a bug where parts of large output on a PTY can be lost after the command exits, skipped to avoid test failures.")
// This test is here to prevent prove we have a bug where quickly executing
// commands (with TTY) don't flush their output to the SSH session. This is
// due to the pty being closed before all the output has been copied, but
// protecting against this requires a non-trivial rewrite of the output
// processing (or figuring out a way to put the pty in a mode where this
// does not happen).
// This test is here to prevent regressions where a command (with or
// without) a large amount of output would not be fully copied to the
// SSH session. On unix systems, this was fixed by duplicating the file
// descriptor of the PTY master and using it for copying the output.
//
// See: https://github.com/coder/coder/issues/6656
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled

View File

@ -6,8 +6,12 @@ import (
"os"
"github.com/gliderlabs/ssh"
"golang.org/x/xerrors"
)
// ErrClosed is returned when a PTY is used after it has been closed.
var ErrClosed = xerrors.New("pty: closed")
// PTY is a minimal interface for interacting with a TTY.
type PTY interface {
io.Closer
@ -31,6 +35,11 @@ type PTY interface {
// The same stream would be used to provide user input: pty.Input().Write(...)
Input() ReadWriter
// Dup returns a new file descriptor for the PTY.
//
// This is useful for closing stdin and stdout separately.
Dup() (*os.File, error)
// Resize sets the size of the PTY.
Resize(height uint16, width uint16) error
}

View File

@ -7,11 +7,11 @@ import (
"os/exec"
"runtime"
"sync"
"syscall"
"github.com/creack/pty"
"github.com/u-root/u-root/pkg/termios"
"golang.org/x/sys/unix"
"golang.org/x/xerrors"
)
func newPty(opt ...Option) (retPTY *otherPty, err error) {
@ -113,6 +113,20 @@ func (p *otherPty) Resize(height uint16, width uint16) error {
})
}
func (p *otherPty) Dup() (*os.File, error) {
var newfd int
err := p.control(p.pty, func(fd uintptr) error {
var err error
newfd, err = syscall.Dup(int(fd))
return err
})
if err != nil {
return nil, err
}
return os.NewFile(uintptr(newfd), p.pty.Name()), nil
}
func (p *otherPty) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
@ -131,7 +145,7 @@ func (p *otherPty) Close() error {
if err != nil {
p.err = err
} else {
p.err = xerrors.New("pty: closed")
p.err = ErrClosed
}
return err

View File

@ -123,6 +123,10 @@ func (p *ptyWindows) Resize(height uint16, width uint16) error {
return nil
}
func (p *ptyWindows) Dup() (*os.File, error) {
return nil, xerrors.Errorf("not implemented")
}
func (p *ptyWindows) Close() error {
p.closeMutex.Lock()
defer p.closeMutex.Unlock()