mirror of https://github.com/coder/coder.git
fix: pty.Start respects context on Windows too (#7373)
* fix: pty.Start respects context on Windows too Signed-off-by: Spike Curtis <spike@coder.com> * Fix windows imports; rename ToExec -> AsExec Signed-off-by: Spike Curtis <spike@coder.com> * Fix import in windows test Signed-off-by: Spike Curtis <spike@coder.com> --------- Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent
e6931d6920
commit
9c030a8888
|
@ -216,11 +216,12 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
|
|||
// if it can guarantee the clocks are synchronized.
|
||||
CollectedAt: time.Now(),
|
||||
}
|
||||
cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
|
||||
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
|
||||
if err != nil {
|
||||
result.Error = fmt.Sprintf("create cmd: %+v", err)
|
||||
return result
|
||||
}
|
||||
cmd := cmdPty.AsExec()
|
||||
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &out
|
||||
|
@ -842,10 +843,11 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
|
|||
}()
|
||||
}
|
||||
|
||||
cmd, err := a.sshServer.CreateCommand(ctx, script, nil)
|
||||
cmdPty, err := a.sshServer.CreateCommand(ctx, script, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create command: %w", err)
|
||||
}
|
||||
cmd := cmdPty.AsExec()
|
||||
cmd.Stdout = writer
|
||||
cmd.Stderr = writer
|
||||
err = cmd.Run()
|
||||
|
@ -1044,16 +1046,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
|||
circularBuffer: circularBuffer,
|
||||
}
|
||||
a.reconnectingPTYs.Store(msg.ID, rpty)
|
||||
go func() {
|
||||
// CommandContext isn't respected for Windows PTYs right now,
|
||||
// so we need to manually track the lifecycle.
|
||||
// When the context has been completed either:
|
||||
// 1. The timeout completed.
|
||||
// 2. The parent context was canceled.
|
||||
<-ctx.Done()
|
||||
logger.Debug(ctx, "context done", slog.Error(ctx.Err()))
|
||||
_ = process.Kill()
|
||||
}()
|
||||
// We don't need to separately monitor for the process exiting.
|
||||
// When it exits, our ptty.OutputReader() will return EOF after
|
||||
// reading all process output.
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
@ -1697,7 +1696,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*pt
|
|||
"host",
|
||||
)
|
||||
args = append(args, afterArgs...)
|
||||
cmd := exec.Command("ssh", args...)
|
||||
cmd := pty.Command("ssh", args...)
|
||||
return ptytest.Start(t, cmd)
|
||||
}
|
||||
|
||||
|
|
|
@ -255,7 +255,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
|
|||
if isPty {
|
||||
return s.startPTYSession(session, cmd, sshPty, windowSize)
|
||||
}
|
||||
return startNonPTYSession(session, cmd)
|
||||
return startNonPTYSession(session, cmd.AsExec())
|
||||
}
|
||||
|
||||
func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error {
|
||||
|
@ -287,7 +287,7 @@ type ptySession interface {
|
|||
RawCommand() string
|
||||
}
|
||||
|
||||
func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
|
||||
func (s *Server) startPTYSession(session ptySession, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) {
|
||||
ctx := session.Context()
|
||||
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
|
||||
// See https://github.com/coder/coder/issues/3371.
|
||||
|
@ -413,7 +413,7 @@ func (s *Server) sftpHandler(session ssh.Session) {
|
|||
// CreateCommand processes raw command input with OpenSSH-like behavior.
|
||||
// If the script provided is empty, it will default to the users shell.
|
||||
// This injects environment variables specified by the user at launch too.
|
||||
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) {
|
||||
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get current user: %w", err)
|
||||
|
@ -449,7 +449,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
|
|||
}
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, shell, args...)
|
||||
cmd := pty.CommandContext(ctx, shell, args...)
|
||||
cmd.Dir = manifest.Directory
|
||||
|
||||
// If the metadata directory doesn't exist, we run the command
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
gliderssh "github.com/gliderlabs/ssh"
|
||||
|
@ -15,6 +14,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/pty"
|
||||
"github.com/coder/coder/testutil"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
@ -52,7 +52,7 @@ func Test_sessionStart_orphan(t *testing.T) {
|
|||
close(windowSize)
|
||||
// the command gets the session context so that Go will terminate it when
|
||||
// the session expires.
|
||||
cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript)
|
||||
cmd := pty.CommandContext(sessionCtx, "sh", "-c", longScript)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
|
|
@ -540,7 +540,7 @@ Expire-Date: 0
|
|||
require.NoError(t, err, "import ownertrust failed: %s", out)
|
||||
|
||||
// Start the GPG agent.
|
||||
agentCmd := exec.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
|
||||
agentCmd := pty.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
|
||||
agentCmd.Env = append(agentCmd.Env, "GNUPGHOME="+gnupgHomeClient)
|
||||
agentPTY, agentProc, err := pty.Start(agentCmd, pty.WithPTYOption(pty.WithGPGTTY()))
|
||||
require.NoError(t, err, "launch agent failed")
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package pty
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
@ -214,3 +215,13 @@ func (p *windowsProcess) Wait() error {
|
|||
func (p *windowsProcess) Kill() error {
|
||||
return p.proc.Kill()
|
||||
}
|
||||
|
||||
// killOnContext waits for the context to be done and kills the process, unless it exits on its own first.
|
||||
func (p *windowsProcess) killOnContext(ctx context.Context) {
|
||||
select {
|
||||
case <-p.cmdDone:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
p.Kill()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -44,7 +43,7 @@ func New(t *testing.T, opts ...pty.Option) *PTY {
|
|||
|
||||
// Start starts a new process asynchronously and returns a PTYCmd and Process.
|
||||
// It kills the process and PTYCmd upon cleanup
|
||||
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
|
||||
func Start(t *testing.T, cmd *pty.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) {
|
||||
t.Helper()
|
||||
|
||||
ptty, ps, err := pty.Start(cmd, opts...)
|
||||
|
|
37
pty/start.go
37
pty/start.go
|
@ -1,6 +1,7 @@
|
|||
package pty
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
|
@ -18,8 +19,42 @@ func WithPTYOption(opts ...Option) StartOption {
|
|||
}
|
||||
}
|
||||
|
||||
// Cmd is a drop-in replacement for exec.Cmd with most of the same API, but
|
||||
// it exposes the context.Context to our PTY code so that we can still kill the
|
||||
// process when the Context expires. This is required because on Windows, we don't
|
||||
// start the command using the `exec` library, so we have to manage the context
|
||||
// ourselves.
|
||||
type Cmd struct {
|
||||
Context context.Context
|
||||
Path string
|
||||
Args []string
|
||||
Env []string
|
||||
Dir string
|
||||
}
|
||||
|
||||
func CommandContext(ctx context.Context, name string, arg ...string) *Cmd {
|
||||
return &Cmd{
|
||||
Context: ctx,
|
||||
Path: name,
|
||||
Args: append([]string{name}, arg...),
|
||||
Env: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func Command(name string, arg ...string) *Cmd {
|
||||
return CommandContext(context.Background(), name, arg...)
|
||||
}
|
||||
|
||||
func (c *Cmd) AsExec() *exec.Cmd {
|
||||
//nolint: gosec
|
||||
execCmd := exec.CommandContext(c.Context, c.Path, c.Args[1:]...)
|
||||
execCmd.Dir = c.Dir
|
||||
execCmd.Env = c.Env
|
||||
return execCmd
|
||||
}
|
||||
|
||||
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
|
||||
// instead rely on the returned Process to manage the command/process.
|
||||
func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) {
|
||||
func Start(cmd *Cmd, opt ...StartOption) (PTYCmd, Process, error) {
|
||||
return startPty(cmd, opt...)
|
||||
}
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
package pty
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
@ -12,7 +12,7 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
|
||||
func startPty(cmdPty *Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
|
||||
var opts startOptions
|
||||
for _, o := range opt {
|
||||
o(&opts)
|
||||
|
@ -23,30 +23,34 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
|
|||
return nil, nil, xerrors.Errorf("newPty failed: %w", err)
|
||||
}
|
||||
|
||||
origEnv := cmd.Env
|
||||
origEnv := cmdPty.Env
|
||||
if opty.opts.sshReq != nil {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
|
||||
cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
|
||||
}
|
||||
if opty.opts.setGPGTTY {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
|
||||
cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
|
||||
}
|
||||
if cmdPty.Context == nil {
|
||||
cmdPty.Context = context.Background()
|
||||
}
|
||||
cmdExec := cmdPty.AsExec()
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
cmdExec.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setsid: true,
|
||||
Setctty: true,
|
||||
}
|
||||
cmd.Stdout = opty.tty
|
||||
cmd.Stderr = opty.tty
|
||||
cmd.Stdin = opty.tty
|
||||
err = cmd.Start()
|
||||
cmdExec.Stdout = opty.tty
|
||||
cmdExec.Stderr = opty.tty
|
||||
cmdExec.Stdin = opty.tty
|
||||
err = cmdExec.Start()
|
||||
if err != nil {
|
||||
_ = opty.Close()
|
||||
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "bad file descriptor") {
|
||||
// macOS has an obscure issue where the PTY occasionally closes
|
||||
// before it's used. It's unknown why this is, but creating a new
|
||||
// TTY resolves it.
|
||||
cmd.Env = origEnv
|
||||
return startPty(cmd, opt...)
|
||||
cmdPty.Env = origEnv
|
||||
return startPty(cmdPty, opt...)
|
||||
}
|
||||
return nil, nil, xerrors.Errorf("start: %w", err)
|
||||
}
|
||||
|
@ -64,14 +68,14 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
|
|||
// confirming this, but I did find a thread of someone else's
|
||||
// observations: https://developer.apple.com/forums/thread/663632
|
||||
if err := opty.tty.Close(); err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmdExec.Process.Kill()
|
||||
return nil, nil, xerrors.Errorf("close tty: %w", err)
|
||||
}
|
||||
opty.tty = nil // remove so we don't attempt to close it again.
|
||||
}
|
||||
oProcess := &otherProcess{
|
||||
pty: opty.pty,
|
||||
cmd: cmd,
|
||||
cmd: cmdExec,
|
||||
cmdDone: make(chan any),
|
||||
}
|
||||
go oProcess.waitInternal()
|
||||
|
|
|
@ -24,7 +24,7 @@ func TestStart(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Echo", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pty, ps := ptytest.Start(t, exec.Command("echo", "test"))
|
||||
pty, ps := ptytest.Start(t, pty.Command("echo", "test"))
|
||||
|
||||
pty.ExpectMatch("test")
|
||||
err := ps.Wait()
|
||||
|
@ -35,7 +35,7 @@ func TestStart(t *testing.T) {
|
|||
|
||||
t.Run("Kill", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pty, ps := ptytest.Start(t, exec.Command("sleep", "30"))
|
||||
pty, ps := ptytest.Start(t, pty.Command("sleep", "30"))
|
||||
err := ps.Kill()
|
||||
assert.NoError(t, err)
|
||||
err = ps.Wait()
|
||||
|
@ -54,7 +54,7 @@ func TestStart(t *testing.T) {
|
|||
Height: 24,
|
||||
},
|
||||
}))
|
||||
pty, ps := ptytest.Start(t, exec.Command("env"), opts)
|
||||
pty, ps := ptytest.Start(t, pty.Command("env"), opts)
|
||||
pty.ExpectMatch("SSH_TTY=/dev/")
|
||||
err := ps.Wait()
|
||||
require.NoError(t, err)
|
||||
|
@ -84,3 +84,9 @@ do
|
|||
echo "$i"
|
||||
done
|
||||
`}
|
||||
|
||||
// these constants/vars are used by Test_Start_cancel_context
|
||||
|
||||
const cmdSleep = "sleep"
|
||||
|
||||
var argSleep = []string{"30"}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -26,7 +25,7 @@ func Test_Start_copy(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...))
|
||||
pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdEcho, argEcho...))
|
||||
require.NoError(t, err)
|
||||
b := &bytes.Buffer{}
|
||||
readDone := make(chan error, 1)
|
||||
|
@ -64,7 +63,7 @@ func Test_Start_truncation(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
|
||||
pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...))
|
||||
pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdCount, argCount...))
|
||||
|
||||
require.NoError(t, err)
|
||||
readDone := make(chan struct{})
|
||||
|
@ -114,6 +113,35 @@ func Test_Start_truncation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test_Start_cancel_context tests that we can cancel the command context and kill the process.
|
||||
func Test_Start_cancel_context(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
cmdCtx, cmdCancel := context.WithCancel(ctx)
|
||||
|
||||
pc, cmd, err := pty.Start(pty.CommandContext(cmdCtx, cmdSleep, argSleep...))
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = pc.Close()
|
||||
}()
|
||||
cmdCancel()
|
||||
|
||||
cmdDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(cmdDone)
|
||||
_ = cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
// OK!
|
||||
case <-ctx.Done():
|
||||
t.Error("cmd.Wait() timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// readUntil reads one byte at a time until we either see the string we want, or the context expires
|
||||
func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error {
|
||||
// output can contain virtual terminal sequences, so we need to parse these
|
||||
|
|
|
@ -17,7 +17,7 @@ import (
|
|||
|
||||
// Allocates a PTY and starts the specified command attached to it.
|
||||
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
|
||||
func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) {
|
||||
func startPty(cmd *Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) {
|
||||
var opts startOptions
|
||||
for _, o := range opt {
|
||||
o(&opts)
|
||||
|
@ -129,6 +129,9 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr er
|
|||
return nil, nil, errI
|
||||
}
|
||||
go wp.waitInternal()
|
||||
if cmd.Context != nil {
|
||||
go wp.killOnContext(cmd.Context)
|
||||
}
|
||||
return winPty, wp, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/pty"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -23,7 +24,7 @@ func TestStart(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Echo", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ptty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
|
||||
ptty, ps := ptytest.Start(t, pty.Command("cmd.exe", "/c", "echo", "test"))
|
||||
ptty.ExpectMatch("test")
|
||||
err := ps.Wait()
|
||||
require.NoError(t, err)
|
||||
|
@ -32,7 +33,7 @@ func TestStart(t *testing.T) {
|
|||
})
|
||||
t.Run("Resize", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ptty, _ := ptytest.Start(t, exec.Command("cmd.exe"))
|
||||
ptty, _ := ptytest.Start(t, pty.Command("cmd.exe"))
|
||||
err := ptty.Resize(100, 50)
|
||||
require.NoError(t, err)
|
||||
err = ptty.Close()
|
||||
|
@ -40,7 +41,7 @@ func TestStart(t *testing.T) {
|
|||
})
|
||||
t.Run("Kill", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ptty, ps := ptytest.Start(t, exec.Command("cmd.exe"))
|
||||
ptty, ps := ptytest.Start(t, pty.Command("cmd.exe"))
|
||||
err := ps.Kill()
|
||||
assert.NoError(t, err)
|
||||
err = ps.Wait()
|
||||
|
@ -66,3 +67,9 @@ const (
|
|||
)
|
||||
|
||||
var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)}
|
||||
|
||||
// these constants/vars are used by Test_Start_cancel_context
|
||||
|
||||
const cmdSleep = "cmd.exe"
|
||||
|
||||
var argSleep = []string{"/c", "timeout", "/t", "30"}
|
||||
|
|
Loading…
Reference in New Issue