fix: pty.Start respects context on Windows too (#7373)

* fix: pty.Start respects context on Windows too

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix windows imports; rename ToExec -> AsExec

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix import in windows test

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis 2023-05-03 11:43:05 +04:00 committed by GitHub
parent e6931d6920
commit 9c030a8888
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 132 additions and 48 deletions

View File

@ -216,11 +216,12 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
// if it can guarantee the clocks are synchronized.
CollectedAt: time.Now(),
}
cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
if err != nil {
result.Error = fmt.Sprintf("create cmd: %+v", err)
return result
}
cmd := cmdPty.AsExec()
cmd.Stdout = &out
cmd.Stderr = &out
@ -842,10 +843,11 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
}()
}
cmd, err := a.sshServer.CreateCommand(ctx, script, nil)
cmdPty, err := a.sshServer.CreateCommand(ctx, script, nil)
if err != nil {
return xerrors.Errorf("create command: %w", err)
}
cmd := cmdPty.AsExec()
cmd.Stdout = writer
cmd.Stderr = writer
err = cmd.Run()
@ -1044,16 +1046,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
circularBuffer: circularBuffer,
}
a.reconnectingPTYs.Store(msg.ID, rpty)
go func() {
// CommandContext isn't respected for Windows PTYs right now,
// so we need to manually track the lifecycle.
// When the context has been completed either:
// 1. The timeout completed.
// 2. The parent context was canceled.
<-ctx.Done()
logger.Debug(ctx, "context done", slog.Error(ctx.Err()))
_ = process.Kill()
}()
// We don't need to separately monitor for the process exiting.
// When it exits, our ptty.OutputReader() will return EOF after
// reading all process output.

View File

@ -12,7 +12,6 @@ import (
"net/http/httptest"
"net/netip"
"os"
"os/exec"
"os/user"
"path"
"path/filepath"
@ -1697,7 +1696,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*pt
"host",
)
args = append(args, afterArgs...)
cmd := exec.Command("ssh", args...)
cmd := pty.Command("ssh", args...)
return ptytest.Start(t, cmd)
}

View File

@ -255,7 +255,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
if isPty {
return s.startPTYSession(session, cmd, sshPty, windowSize)
}
return startNonPTYSession(session, cmd)
return startNonPTYSession(session, cmd.AsExec())
}
func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
@ -287,7 +287,7 @@ type ptySession interface {
RawCommand() string
}
func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
func (s *Server) startPTYSession(session ptySession, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
ctx := session.Context()
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
// See https://github.com/coder/coder/issues/3371.
@ -413,7 +413,7 @@ func (s *Server) sftpHandler(session ssh.Session) {
// CreateCommand processes raw command input with OpenSSH-like behavior.
// If the script provided is empty, it will default to the users shell.
// This injects environment variables specified by the user at launch too.
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) {
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
currentUser, err := user.Current()
if err != nil {
return nil, xerrors.Errorf("get current user: %w", err)
@ -449,7 +449,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
}
}
cmd := exec.CommandContext(ctx, shell, args...)
cmd := pty.CommandContext(ctx, shell, args...)
cmd.Dir = manifest.Directory
// If the metadata directory doesn't exist, we run the command

View File

@ -7,7 +7,6 @@ import (
"context"
"io"
"net"
"os/exec"
"testing"
gliderssh "github.com/gliderlabs/ssh"
@ -15,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty"
"github.com/coder/coder/testutil"
"cdr.dev/slog/sloggers/slogtest"
@ -52,7 +52,7 @@ func Test_sessionStart_orphan(t *testing.T) {
close(windowSize)
// the command gets the session context so that Go will terminate it when
// the session expires.
cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript)
cmd := pty.CommandContext(sessionCtx, "sh", "-c", longScript)
done := make(chan struct{})
go func() {

View File

@ -540,7 +540,7 @@ Expire-Date: 0
require.NoError(t, err, "import ownertrust failed: %s", out)
// Start the GPG agent.
agentCmd := exec.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
agentCmd := pty.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
agentCmd.Env = append(agentCmd.Env, "GNUPGHOME="+gnupgHomeClient)
agentPTY, agentProc, err := pty.Start(agentCmd, pty.WithPTYOption(pty.WithGPGTTY()))
require.NoError(t, err, "launch agent failed")

View File

@ -3,6 +3,7 @@
package pty
import (
"context"
"io"
"os"
"os/exec"
@ -214,3 +215,13 @@ func (p *windowsProcess) Wait() error {
func (p *windowsProcess) Kill() error {
return p.proc.Kill()
}
// killOnContext waits for the context to be done and kills the process, unless it exits on its own first.
func (p *windowsProcess) killOnContext(ctx context.Context) {
select {
case <-p.cmdDone:
return
case <-ctx.Done():
p.Kill()
}
}

View File

@ -6,7 +6,6 @@ import (
"context"
"fmt"
"io"
"os/exec"
"runtime"
"strings"
"sync"
@ -44,7 +43,7 @@ func New(t *testing.T, opts ...pty.Option) *PTY {
// Start starts a new process asynchronously and returns a PTYCmd and Process.
// It kills the process and PTYCmd upon cleanup
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
func Start(t *testing.T, cmd *pty.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
t.Helper()
ptty, ps, err := pty.Start(cmd, opts...)

View File

@ -1,6 +1,7 @@
package pty
import (
"context"
"os/exec"
)
@ -18,8 +19,42 @@ func WithPTYOption(opts ...Option) StartOption {
}
}
// Cmd is a drop-in replacement for exec.Cmd with most of the same API, but
// it exposes the context.Context to our PTY code so that we can still kill the
// process when the Context expires. This is required because on Windows, we don't
// start the command using the `exec` library, so we have to manage the context
// ourselves.
type Cmd struct {
Context context.Context
Path string
Args []string
Env []string
Dir string
}
func CommandContext(ctx context.Context, name string, arg ...string) *Cmd {
return &Cmd{
Context: ctx,
Path: name,
Args: append([]string{name}, arg...),
Env: make([]string, 0),
}
}
func Command(name string, arg ...string) *Cmd {
return CommandContext(context.Background(), name, arg...)
}
func (c *Cmd) AsExec() *exec.Cmd {
//nolint: gosec
execCmd := exec.CommandContext(c.Context, c.Path, c.Args[1:]...)
execCmd.Dir = c.Dir
execCmd.Env = c.Env
return execCmd
}
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
// instead rely on the returned Process to manage the command/process.
func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) {
func Start(cmd *Cmd, opt ...StartOption) (PTYCmd, Process, error) {
return startPty(cmd, opt...)
}

View File

@ -3,8 +3,8 @@
package pty
import (
"context"
"fmt"
"os/exec"
"runtime"
"strings"
"syscall"
@ -12,7 +12,7 @@ import (
"golang.org/x/xerrors"
)
func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
func startPty(cmdPty *Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
var opts startOptions
for _, o := range opt {
o(&opts)
@ -23,30 +23,34 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
return nil, nil, xerrors.Errorf("newPty failed: %w", err)
}
origEnv := cmd.Env
origEnv := cmdPty.Env
if opty.opts.sshReq != nil {
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
}
if opty.opts.setGPGTTY {
cmd.Env = append(cmd.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
}
if cmdPty.Context == nil {
cmdPty.Context = context.Background()
}
cmdExec := cmdPty.AsExec()
cmd.SysProcAttr = &syscall.SysProcAttr{
cmdExec.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
Setctty: true,
}
cmd.Stdout = opty.tty
cmd.Stderr = opty.tty
cmd.Stdin = opty.tty
err = cmd.Start()
cmdExec.Stdout = opty.tty
cmdExec.Stderr = opty.tty
cmdExec.Stdin = opty.tty
err = cmdExec.Start()
if err != nil {
_ = opty.Close()
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "bad file descriptor") {
// macOS has an obscure issue where the PTY occasionally closes
// before it's used. It's unknown why this is, but creating a new
// TTY resolves it.
cmd.Env = origEnv
return startPty(cmd, opt...)
cmdPty.Env = origEnv
return startPty(cmdPty, opt...)
}
return nil, nil, xerrors.Errorf("start: %w", err)
}
@ -64,14 +68,14 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
// confirming this, but I did find a thread of someone else's
// observations: https://developer.apple.com/forums/thread/663632
if err := opty.tty.Close(); err != nil {
_ = cmd.Process.Kill()
_ = cmdExec.Process.Kill()
return nil, nil, xerrors.Errorf("close tty: %w", err)
}
opty.tty = nil // remove so we don't attempt to close it again.
}
oProcess := &otherProcess{
pty: opty.pty,
cmd: cmd,
cmd: cmdExec,
cmdDone: make(chan any),
}
go oProcess.waitInternal()

View File

@ -24,7 +24,7 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
pty, ps := ptytest.Start(t, pty.Command("echo", "test"))
pty.ExpectMatch("test")
err := ps.Wait()
@ -35,7 +35,7 @@ func TestStart(t *testing.T) {
t.Run("Kill", func(t *testing.T) {
t.Parallel()
pty, ps := ptytest.Start(t, exec.Command("sleep", "30"))
pty, ps := ptytest.Start(t, pty.Command("sleep", "30"))
err := ps.Kill()
assert.NoError(t, err)
err = ps.Wait()
@ -54,7 +54,7 @@ func TestStart(t *testing.T) {
Height: 24,
},
}))
pty, ps := ptytest.Start(t, exec.Command("env"), opts)
pty, ps := ptytest.Start(t, pty.Command("env"), opts)
pty.ExpectMatch("SSH_TTY=/dev/")
err := ps.Wait()
require.NoError(t, err)
@ -84,3 +84,9 @@ do
echo "$i"
done
`}
// these constants/vars are used by Test_Start_cancel_context
const cmdSleep = "sleep"
var argSleep = []string{"30"}

View File

@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"os/exec"
"strings"
"testing"
"time"
@ -26,7 +25,7 @@ func Test_Start_copy(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...))
pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdEcho, argEcho...))
require.NoError(t, err)
b := &bytes.Buffer{}
readDone := make(chan error, 1)
@ -64,7 +63,7 @@ func Test_Start_truncation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...))
pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdCount, argCount...))
require.NoError(t, err)
readDone := make(chan struct{})
@ -114,6 +113,35 @@ func Test_Start_truncation(t *testing.T) {
}
}
// Test_Start_cancel_context tests that we can cancel the command context and kill the process.
func Test_Start_cancel_context(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
cmdCtx, cmdCancel := context.WithCancel(ctx)
pc, cmd, err := pty.Start(pty.CommandContext(cmdCtx, cmdSleep, argSleep...))
require.NoError(t, err)
defer func() {
_ = pc.Close()
}()
cmdCancel()
cmdDone := make(chan struct{})
go func() {
defer close(cmdDone)
_ = cmd.Wait()
}()
select {
case <-cmdDone:
// OK!
case <-ctx.Done():
t.Error("cmd.Wait() timed out")
}
}
// readUntil reads one byte at a time until we either see the string we want, or the context expires
func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error {
// output can contain virtual terminal sequences, so we need to parse these

View File

@ -17,7 +17,7 @@ import (
// Allocates a PTY and starts the specified command attached to it.
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) {
func startPty(cmd *Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) {
var opts startOptions
for _, o := range opt {
o(&opts)
@ -129,6 +129,9 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr er
return nil, nil, errI
}
go wp.waitInternal()
if cmd.Context != nil {
go wp.killOnContext(cmd.Context)
}
return winPty, wp, nil
}

View File

@ -8,6 +8,7 @@ import (
"os/exec"
"testing"
"github.com/coder/coder/pty"
"github.com/coder/coder/pty/ptytest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -23,7 +24,7 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
ptty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
ptty, ps := ptytest.Start(t, pty.Command("cmd.exe", "/c", "echo", "test"))
ptty.ExpectMatch("test")
err := ps.Wait()
require.NoError(t, err)
@ -32,7 +33,7 @@ func TestStart(t *testing.T) {
})
t.Run("Resize", func(t *testing.T) {
t.Parallel()
ptty, _ := ptytest.Start(t, exec.Command("cmd.exe"))
ptty, _ := ptytest.Start(t, pty.Command("cmd.exe"))
err := ptty.Resize(100, 50)
require.NoError(t, err)
err = ptty.Close()
@ -40,7 +41,7 @@ func TestStart(t *testing.T) {
})
t.Run("Kill", func(t *testing.T) {
t.Parallel()
ptty, ps := ptytest.Start(t, exec.Command("cmd.exe"))
ptty, ps := ptytest.Start(t, pty.Command("cmd.exe"))
err := ps.Kill()
assert.NoError(t, err)
err = ps.Wait()
@ -66,3 +67,9 @@ const (
)
var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)}
// these constants/vars are used by Test_Start_cancel_context
const cmdSleep = "cmd.exe"
var argSleep = []string{"/c", "timeout", "/t", "30"}