diff --git a/agent/agent.go b/agent/agent.go index 34af2b59f3..ea22069fd0 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -216,11 +216,12 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM // if it can guarantee the clocks are synchronized. CollectedAt: time.Now(), } - cmd, err := a.sshServer.CreateCommand(ctx, md.Script, nil) + cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil) if err != nil { result.Error = fmt.Sprintf("create cmd: %+v", err) return result } + cmd := cmdPty.AsExec() cmd.Stdout = &out cmd.Stderr = &out @@ -842,10 +843,11 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error { }() } - cmd, err := a.sshServer.CreateCommand(ctx, script, nil) + cmdPty, err := a.sshServer.CreateCommand(ctx, script, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } + cmd := cmdPty.AsExec() cmd.Stdout = writer cmd.Stderr = writer err = cmd.Run() @@ -1044,16 +1046,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m circularBuffer: circularBuffer, } a.reconnectingPTYs.Store(msg.ID, rpty) - go func() { - // CommandContext isn't respected for Windows PTYs right now, - // so we need to manually track the lifecycle. - // When the context has been completed either: - // 1. The timeout completed. - // 2. The parent context was canceled. - <-ctx.Done() - logger.Debug(ctx, "context done", slog.Error(ctx.Err())) - _ = process.Kill() - }() // We don't need to separately monitor for the process exiting. // When it exits, our ptty.OutputReader() will return EOF after // reading all process output. diff --git a/agent/agent_test.go b/agent/agent_test.go index ee135a05ce..8914a5524f 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -12,7 +12,6 @@ import ( "net/http/httptest" "net/netip" "os" - "os/exec" "os/user" "path" "path/filepath" @@ -1697,7 +1696,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*pt "host", ) args = append(args, afterArgs...) - cmd := exec.Command("ssh", args...) + cmd := pty.Command("ssh", args...) return ptytest.Start(t, cmd) } diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index c9bd17362b..6221751ae8 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -255,7 +255,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er if isPty { return s.startPTYSession(session, cmd, sshPty, windowSize) } - return startNonPTYSession(session, cmd) + return startNonPTYSession(session, cmd.AsExec()) } func startNonPTYSession(session ssh.Session, cmd *exec.Cmd) error { @@ -287,7 +287,7 @@ type ptySession interface { RawCommand() string } -func (s *Server) startPTYSession(session ptySession, cmd *exec.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) { +func (s *Server) startPTYSession(session ptySession, cmd *pty.Cmd, sshPty ssh.Pty, windowSize <-chan ssh.Window) (retErr error) { ctx := session.Context() // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). // See https://github.com/coder/coder/issues/3371. @@ -413,7 +413,7 @@ func (s *Server) sftpHandler(session ssh.Session) { // CreateCommand processes raw command input with OpenSSH-like behavior. // If the script provided is empty, it will default to the users shell. // This injects environment variables specified by the user at launch too. -func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { +func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) { currentUser, err := user.Current() if err != nil { return nil, xerrors.Errorf("get current user: %w", err) @@ -449,7 +449,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string) } } - cmd := exec.CommandContext(ctx, shell, args...) + cmd := pty.CommandContext(ctx, shell, args...) cmd.Dir = manifest.Directory // If the metadata directory doesn't exist, we run the command diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go index 33f41dd15a..ed05e53a04 100644 --- a/agent/agentssh/agentssh_internal_test.go +++ b/agent/agentssh/agentssh_internal_test.go @@ -7,7 +7,6 @@ import ( "context" "io" "net" - "os/exec" "testing" gliderssh "github.com/gliderlabs/ssh" @@ -15,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/coder/coder/pty" "github.com/coder/coder/testutil" "cdr.dev/slog/sloggers/slogtest" @@ -52,7 +52,7 @@ func Test_sessionStart_orphan(t *testing.T) { close(windowSize) // the command gets the session context so that Go will terminate it when // the session expires. - cmd := exec.CommandContext(sessionCtx, "sh", "-c", longScript) + cmd := pty.CommandContext(sessionCtx, "sh", "-c", longScript) done := make(chan struct{}) go func() { diff --git a/cli/ssh_test.go b/cli/ssh_test.go index ee544a328e..01d81107ab 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -540,7 +540,7 @@ Expire-Date: 0 require.NoError(t, err, "import ownertrust failed: %s", out) // Start the GPG agent. - agentCmd := exec.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath) + agentCmd := pty.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath) agentCmd.Env = append(agentCmd.Env, "GNUPGHOME="+gnupgHomeClient) agentPTY, agentProc, err := pty.Start(agentCmd, pty.WithPTYOption(pty.WithGPGTTY())) require.NoError(t, err, "launch agent failed") diff --git a/pty/pty_windows.go b/pty/pty_windows.go index 80f6b74f43..c7ddf046c2 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -3,6 +3,7 @@ package pty import ( + "context" "io" "os" "os/exec" @@ -214,3 +215,13 @@ func (p *windowsProcess) Wait() error { func (p *windowsProcess) Kill() error { return p.proc.Kill() } + +// killOnContext waits for the context to be done and kills the process, unless it exits on its own first. +func (p *windowsProcess) killOnContext(ctx context.Context) { + select { + case <-p.cmdDone: + return + case <-ctx.Done(): + p.Kill() + } +} diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 69eb81026e..5036668d82 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "io" - "os/exec" "runtime" "strings" "sync" @@ -44,7 +43,7 @@ func New(t *testing.T, opts ...pty.Option) *PTY { // Start starts a new process asynchronously and returns a PTYCmd and Process. // It kills the process and PTYCmd upon cleanup -func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) { +func Start(t *testing.T, cmd *pty.Cmd, opts ...pty.StartOption) (*PTYCmd, pty.Process) { t.Helper() ptty, ps, err := pty.Start(cmd, opts...) diff --git a/pty/start.go b/pty/start.go index 565edaca43..1105140ec3 100644 --- a/pty/start.go +++ b/pty/start.go @@ -1,6 +1,7 @@ package pty import ( + "context" "os/exec" ) @@ -18,8 +19,42 @@ func WithPTYOption(opts ...Option) StartOption { } } +// Cmd is a drop-in replacement for exec.Cmd with most of the same API, but +// it exposes the context.Context to our PTY code so that we can still kill the +// process when the Context expires. This is required because on Windows, we don't +// start the command using the `exec` library, so we have to manage the context +// ourselves. +type Cmd struct { + Context context.Context + Path string + Args []string + Env []string + Dir string +} + +func CommandContext(ctx context.Context, name string, arg ...string) *Cmd { + return &Cmd{ + Context: ctx, + Path: name, + Args: append([]string{name}, arg...), + Env: make([]string, 0), + } +} + +func Command(name string, arg ...string) *Cmd { + return CommandContext(context.Background(), name, arg...) +} + +func (c *Cmd) AsExec() *exec.Cmd { + //nolint: gosec + execCmd := exec.CommandContext(c.Context, c.Path, c.Args[1:]...) + execCmd.Dir = c.Dir + execCmd.Env = c.Env + return execCmd +} + // Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and // instead rely on the returned Process to manage the command/process. -func Start(cmd *exec.Cmd, opt ...StartOption) (PTYCmd, Process, error) { +func Start(cmd *Cmd, opt ...StartOption) (PTYCmd, Process, error) { return startPty(cmd, opt...) } diff --git a/pty/start_other.go b/pty/start_other.go index 33e3191100..2802e027ef 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -3,8 +3,8 @@ package pty import ( + "context" "fmt" - "os/exec" "runtime" "strings" "syscall" @@ -12,7 +12,7 @@ import ( "golang.org/x/xerrors" ) -func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) { +func startPty(cmdPty *Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) { var opts startOptions for _, o := range opt { o(&opts) @@ -23,30 +23,34 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process return nil, nil, xerrors.Errorf("newPty failed: %w", err) } - origEnv := cmd.Env + origEnv := cmdPty.Env if opty.opts.sshReq != nil { - cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name())) + cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name())) } if opty.opts.setGPGTTY { - cmd.Env = append(cmd.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name())) + cmdPty.Env = append(cmdPty.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name())) } + if cmdPty.Context == nil { + cmdPty.Context = context.Background() + } + cmdExec := cmdPty.AsExec() - cmd.SysProcAttr = &syscall.SysProcAttr{ + cmdExec.SysProcAttr = &syscall.SysProcAttr{ Setsid: true, Setctty: true, } - cmd.Stdout = opty.tty - cmd.Stderr = opty.tty - cmd.Stdin = opty.tty - err = cmd.Start() + cmdExec.Stdout = opty.tty + cmdExec.Stderr = opty.tty + cmdExec.Stdin = opty.tty + err = cmdExec.Start() if err != nil { _ = opty.Close() if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "bad file descriptor") { // macOS has an obscure issue where the PTY occasionally closes // before it's used. It's unknown why this is, but creating a new // TTY resolves it. - cmd.Env = origEnv - return startPty(cmd, opt...) + cmdPty.Env = origEnv + return startPty(cmdPty, opt...) } return nil, nil, xerrors.Errorf("start: %w", err) } @@ -64,14 +68,14 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process // confirming this, but I did find a thread of someone else's // observations: https://developer.apple.com/forums/thread/663632 if err := opty.tty.Close(); err != nil { - _ = cmd.Process.Kill() + _ = cmdExec.Process.Kill() return nil, nil, xerrors.Errorf("close tty: %w", err) } opty.tty = nil // remove so we don't attempt to close it again. } oProcess := &otherProcess{ pty: opty.pty, - cmd: cmd, + cmd: cmdExec, cmdDone: make(chan any), } go oProcess.waitInternal() diff --git a/pty/start_other_test.go b/pty/start_other_test.go index 264f7912a8..e7a2a3d69e 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -24,7 +24,7 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty, ps := ptytest.Start(t, exec.Command("echo", "test")) + pty, ps := ptytest.Start(t, pty.Command("echo", "test")) pty.ExpectMatch("test") err := ps.Wait() @@ -35,7 +35,7 @@ func TestStart(t *testing.T) { t.Run("Kill", func(t *testing.T) { t.Parallel() - pty, ps := ptytest.Start(t, exec.Command("sleep", "30")) + pty, ps := ptytest.Start(t, pty.Command("sleep", "30")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() @@ -54,7 +54,7 @@ func TestStart(t *testing.T) { Height: 24, }, })) - pty, ps := ptytest.Start(t, exec.Command("env"), opts) + pty, ps := ptytest.Start(t, pty.Command("env"), opts) pty.ExpectMatch("SSH_TTY=/dev/") err := ps.Wait() require.NoError(t, err) @@ -84,3 +84,9 @@ do echo "$i" done `} + +// these constants/vars are used by Test_Start_cancel_context + +const cmdSleep = "sleep" + +var argSleep = []string{"30"} diff --git a/pty/start_test.go b/pty/start_test.go index d8711cb99c..5f273428d2 100644 --- a/pty/start_test.go +++ b/pty/start_test.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "os/exec" "strings" "testing" "time" @@ -26,7 +25,7 @@ func Test_Start_copy(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdEcho, argEcho...)) + pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdEcho, argEcho...)) require.NoError(t, err) b := &bytes.Buffer{} readDone := make(chan error, 1) @@ -64,7 +63,7 @@ func Test_Start_truncation(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancel() - pc, cmd, err := pty.Start(exec.CommandContext(ctx, cmdCount, argCount...)) + pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdCount, argCount...)) require.NoError(t, err) readDone := make(chan struct{}) @@ -114,6 +113,35 @@ func Test_Start_truncation(t *testing.T) { } } +// Test_Start_cancel_context tests that we can cancel the command context and kill the process. +func Test_Start_cancel_context(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + cmdCtx, cmdCancel := context.WithCancel(ctx) + + pc, cmd, err := pty.Start(pty.CommandContext(cmdCtx, cmdSleep, argSleep...)) + require.NoError(t, err) + defer func() { + _ = pc.Close() + }() + cmdCancel() + + cmdDone := make(chan struct{}) + go func() { + defer close(cmdDone) + _ = cmd.Wait() + }() + + select { + case <-cmdDone: + // OK! + case <-ctx.Done(): + t.Error("cmd.Wait() timed out") + } +} + // readUntil reads one byte at a time until we either see the string we want, or the context expires func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error { // output can contain virtual terminal sequences, so we need to parse these diff --git a/pty/start_windows.go b/pty/start_windows.go index 2811900ffc..4e9a755e95 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -17,7 +17,7 @@ import ( // Allocates a PTY and starts the specified command attached to it. // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) { +func startPty(cmd *Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr error) { var opts startOptions for _, o := range opt { o(&opts) @@ -129,6 +129,9 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (_ PTYCmd, _ Process, retErr er return nil, nil, errI } go wp.waitInternal() + if cmd.Context != nil { + go wp.killOnContext(cmd.Context) + } return winPty, wp, nil } diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index a8e287e1ed..b0f862ea15 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -8,6 +8,7 @@ import ( "os/exec" "testing" + "github.com/coder/coder/pty" "github.com/coder/coder/pty/ptytest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -23,7 +24,7 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - ptty, ps := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + ptty, ps := ptytest.Start(t, pty.Command("cmd.exe", "/c", "echo", "test")) ptty.ExpectMatch("test") err := ps.Wait() require.NoError(t, err) @@ -32,7 +33,7 @@ func TestStart(t *testing.T) { }) t.Run("Resize", func(t *testing.T) { t.Parallel() - ptty, _ := ptytest.Start(t, exec.Command("cmd.exe")) + ptty, _ := ptytest.Start(t, pty.Command("cmd.exe")) err := ptty.Resize(100, 50) require.NoError(t, err) err = ptty.Close() @@ -40,7 +41,7 @@ func TestStart(t *testing.T) { }) t.Run("Kill", func(t *testing.T) { t.Parallel() - ptty, ps := ptytest.Start(t, exec.Command("cmd.exe")) + ptty, ps := ptytest.Start(t, pty.Command("cmd.exe")) err := ps.Kill() assert.NoError(t, err) err = ps.Wait() @@ -66,3 +67,9 @@ const ( ) var argCount = []string{"/c", fmt.Sprintf("for /L %%n in (1,1,%d) do @echo %%n", countEnd)} + +// these constants/vars are used by Test_Start_cancel_context + +const cmdSleep = "cmd.exe" + +var argSleep = []string{"/c", "timeout", "/t", "30"}