diff --git a/.vscode/settings.json b/.vscode/settings.json index 3733ff8f7a..c328f9a746 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -31,6 +31,7 @@ "gonet", "gossh", "gsyslog", + "GTTY", "hashicorp", "hclsyntax", "httpapi", @@ -67,6 +68,7 @@ "ntqry", "OIDC", "oneof", + "opty", "paralleltest", "parameterscopeid", "pqtype", @@ -76,6 +78,7 @@ "provisionerd", "provisionersdk", "ptty", + "ptys", "ptytest", "reconfig", "retrier", @@ -87,6 +90,7 @@ "sourcemapped", "Srcs", "stretchr", + "STTY", "stuntest", "tailbroker", "tailcfg", @@ -105,6 +109,7 @@ "tfjson", "tfplan", "tfstate", + "tios", "tparallel", "trimprefix", "tsdial", diff --git a/agent/agent.go b/agent/agent.go index c477122e7b..6018062b37 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -374,7 +374,7 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error { return nil } - writer, err := os.OpenFile(filepath.Join(os.TempDir(), "coder-startup-script.log"), os.O_CREATE|os.O_RDWR, 0600) + writer, err := os.OpenFile(filepath.Join(os.TempDir(), "coder-startup-script.log"), os.O_CREATE|os.O_RDWR, 0o600) if err != nil { return xerrors.Errorf("open startup script log file: %w", err) } @@ -537,6 +537,8 @@ func (a *agent) init(ctx context.Context) { }, SubsystemHandlers: map[string]ssh.SubsystemHandler{ "sftp": func(session ssh.Session) { + session.DisablePTYEmulation() + server, err := sftp.NewServer(session) if err != nil { a.logger.Debug(session.Context(), "initialize sftp server", slog.Error(err)) @@ -661,7 +663,8 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri } func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { - cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ()) + ctx := session.Context() + cmd, err := a.createCommand(ctx, session.RawCommand(), session.Environ()) if err != nil { return err } @@ -678,32 +681,34 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { sshPty, windowSize, isPty := session.Pty() if isPty { + // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). + // See https://github.com/coder/coder/issues/3371. + session.DisablePTYEmulation() + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) // The pty package sets `SSH_TTY` on supported platforms. - ptty, process, err := pty.Start(cmd) + ptty, process, err := pty.Start(cmd, pty.WithPTYOption( + pty.WithSSHRequest(sshPty), + pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), + )) if err != nil { return xerrors.Errorf("start command: %w", err) } defer func() { closeErr := ptty.Close() if closeErr != nil { - a.logger.Warn(context.Background(), "failed to close tty", - slog.Error(closeErr)) + a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) if retErr == nil { retErr = closeErr } } }() - err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width)) - if err != nil { - return xerrors.Errorf("resize ptty: %w", err) - } go func() { for win := range windowSize { resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) if resizeErr != nil { - a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(resizeErr)) + a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) } } }() @@ -718,8 +723,7 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { // ExitErrors just mean the command we run returned a non-zero exit code, which is normal // and not something to be concerned about. But, if it's something else, we should log it. if err != nil && !xerrors.As(err, &exitErr) { - a.logger.Warn(context.Background(), "wait error", - slog.Error(err)) + a.logger.Warn(ctx, "wait error", slog.Error(err)) } return err } diff --git a/cli/ssh.go b/cli/ssh.go index 7f23cce706..796c1849a3 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -195,6 +195,13 @@ func ssh() *cobra.Command { // shutdown of services. defer cancel() + if validOut { + // Set initial window size. + width, height, err := term.GetSize(int(stdoutFile.Fd())) + if err == nil { + _ = sshSession.WindowChange(height, width) + } + } err = sshSession.Wait() if err != nil { // If the connection drops unexpectedly, we get an ExitMissingError but no other diff --git a/go.mod b/go.mod index 988f487591..c30bce4831 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,15 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0 // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220907193453-fb5ba5ab658d +// Switch to our fork that imports fixes from http://github.com/tailscale/ssh. +// See: https://github.com/coder/coder/issues/3371 +// +// Note that http://github.com/tailscale/ssh has been merged into the Tailscale +// repo as tailscale.com/tempfork/gliderlabs/ssh, however, we can't replace the +// subpath and it includes changes to golang.org/x/crypto/ssh as well which +// makes importing it directly a bit messy. +replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 + require ( cdr.dev/slog v1.4.2-0.20220525200111-18dce5c2cd5f cloud.google.com/go/compute v1.7.0 @@ -126,6 +135,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.0 github.com/tabbed/pqtype v0.1.1 + github.com/u-root/u-root v0.9.0 github.com/unrolled/secure v1.12.0 go.mozilla.org/pkcs7 v0.0.0-20200128120323-432b2356ecb1 go.opentelemetry.io/otel v1.8.0 diff --git a/go.sum b/go.sum index 0987a4ea34..1ed3d06587 100644 --- a/go.sum +++ b/go.sum @@ -352,6 +352,8 @@ github.com/coder/glog v1.0.1-0.20220322161911-7365fe7f2cd1 h1:UqBrPWSYvRI2s5RtOu github.com/coder/glog v1.0.1-0.20220322161911-7365fe7f2cd1/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= +github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko= +github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= github.com/coder/tailscale v1.1.1-0.20220907193453-fb5ba5ab658d h1:IQ8wJn8MfDS+sesYPpn3EDAyvoGMxFvyyE9uWtcfU6w= github.com/coder/tailscale v1.1.1-0.20220907193453-fb5ba5ab658d/go.mod h1:MO+tWkQp2YIF3KBnnej/mQvgYccRS5Xk/IrEpZ4Z3BU= github.com/coder/wireguard-go/tun/netstack v0.0.0-20220823170024-a78136eb0cab h1:9yEvRWXXfyKzXu8AqywCi+tFZAoqCy4wVcsXwuvZNMc= @@ -638,9 +640,6 @@ github.com/gin-gonic/gin v1.7.0 h1:jGB9xAJQ12AIGNB4HguylppmDK1Am9ppF7XnGXXJuoU= github.com/gin-gonic/gin v1.7.0/go.mod h1:jD2toBW3GZUr5UMcdrwQA10I7RuaFOl/SGeDjXkfUtY= github.com/github/fakeca v0.1.0 h1:Km/MVOFvclqxPM9dZBC4+QE564nU4gz4iZ0D9pMw28I= github.com/github/fakeca v0.1.0/go.mod h1:+bormgoGMMuamOscx7N91aOuUST7wdaJ2rNjeohylyo= -github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= -github.com/gliderlabs/ssh v0.3.4 h1:+AXBtim7MTKaLVPgvE+3mhewYRawNLTd+jEEz/wExZw= -github.com/gliderlabs/ssh v0.3.4/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs= github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= @@ -1810,6 +1809,8 @@ github.com/tomarrell/wrapcheck/v2 v2.4.0/go.mod h1:68bQ/eJg55BROaRTbMjC7vuhL2Ogf github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4= github.com/tommy-muehle/go-mnd/v2 v2.4.0/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw= github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM= +github.com/u-root/u-root v0.9.0 h1:1dpUzrE0FyKrNEjxpKFOkyveuV1f3T0Ko5CQg4gTkCg= +github.com/u-root/u-root v0.9.0/go.mod h1:ewc9w6JF1ayZCVC9Y5wsrUiCBw3nMmPC3QItvrEwmew= github.com/u-root/uio v0.0.0-20210528114334-82958018845c/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA= github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4 h1:hl6sK6aFgTLISijk6xIzeqnPzQcsLqqvL6vEfTPinME= github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA= diff --git a/pty/pty.go b/pty/pty.go index 7db4a88884..b37369c9c1 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -2,13 +2,19 @@ package pty import ( "io" + "log" "os" + + "github.com/gliderlabs/ssh" ) // PTY is a minimal interface for interacting with a TTY. type PTY interface { io.Closer + // Name of the TTY. Example on Linux would be "/dev/pts/1". + Name() string + // Output handles TTY output. // // cmd.SetOutput(pty.Output()) would be used to specify a command @@ -35,7 +41,6 @@ type PTY interface { // to Wait() on a process, this abstraction provides a goroutine-safe interface for interacting with // the process. type Process interface { - // Wait for the command to complete. Returned error is as for exec.Cmd.Wait() Wait() error @@ -52,9 +57,33 @@ type WithFlags interface { EchoEnabled() (bool, error) } +// Options represents a an option for a PTY. +type Option func(*ptyOptions) + +type ptyOptions struct { + logger *log.Logger + sshReq *ssh.Pty +} + +// WithSSHRequest applies the ssh.Pty request to the PTY. +// +// Only partially supported on Windows (e.g. window size). +func WithSSHRequest(req ssh.Pty) Option { + return func(opts *ptyOptions) { + opts.sshReq = &req + } +} + +// WithLogger sets a logger for logging errors. +func WithLogger(logger *log.Logger) Option { + return func(opts *ptyOptions) { + opts.logger = logger + } +} + // New constructs a new Pty. -func New() (PTY, error) { - return newPty() +func New(opts ...Option) (PTY, error) { + return newPty(opts...) } // ReadWriter is an implementation of io.ReadWriter that wraps two separate diff --git a/pty/pty_linux.go b/pty/pty_linux.go index b18d801c22..c0a5d31f63 100644 --- a/pty/pty_linux.go +++ b/pty/pty_linux.go @@ -2,12 +2,23 @@ package pty -import "golang.org/x/sys/unix" +import ( + "github.com/u-root/u-root/pkg/termios" + "golang.org/x/sys/unix" +) -func (p *otherPty) EchoEnabled() (bool, error) { - termios, err := unix.IoctlGetTermios(int(p.pty.Fd()), unix.TCGETS) +func (p *otherPty) EchoEnabled() (echo bool, err error) { + err = p.control(p.pty, func(fd uintptr) error { + t, err := termios.GetTermios(fd) + if err != nil { + return err + } + + echo = (t.Lflag & unix.ECHO) != 0 + return nil + }) if err != nil { return false, err } - return (termios.Lflag & unix.ECHO) != 0, nil + return echo, nil } diff --git a/pty/pty_other.go b/pty/pty_other.go index 869d77fe0b..cfe9ccd47e 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package pty @@ -10,19 +9,42 @@ import ( "sync" "github.com/creack/pty" + "github.com/u-root/u-root/pkg/termios" + "golang.org/x/sys/unix" "golang.org/x/xerrors" ) -func newPty() (PTY, error) { +func newPty(opt ...Option) (retPTY *otherPty, err error) { + var opts ptyOptions + for _, o := range opt { + o(&opts) + } + ptyFile, ttyFile, err := pty.Open() if err != nil { return nil, err } + opty := &otherPty{ + pty: ptyFile, + tty: ttyFile, + opts: opts, + } + defer func() { + if err != nil { + _ = opty.Close() + } + }() - return &otherPty{ - pty: ptyFile, - tty: ttyFile, - }, nil + if opts.sshReq != nil { + err = opty.control(opty.tty, func(fd uintptr) error { + return applyTerminalModesToFd(opts.logger, fd, *opts.sshReq) + }) + if err != nil { + return nil, err + } + } + + return opty, nil } type otherPty struct { @@ -30,15 +52,40 @@ type otherPty struct { closed bool err error pty, tty *os.File + opts ptyOptions } -type otherProcess struct { - pty *os.File - cmd *exec.Cmd +func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) { + defer func() { + // Always echo the close error for closed ptys. + p.mutex.Lock() + defer p.mutex.Unlock() + if p.closed { + err = p.err + } + }() - // cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first. - cmdDone chan any - cmdErr error + rawConn, err := tty.SyscallConn() + if err != nil { + return err + } + + var ctlErr error + err = rawConn.Control(func(fd uintptr) { + ctlErr = fn(fd) + }) + switch { + case err != nil: + return err + case ctlErr != nil: + return ctlErr + default: + return nil + } +} + +func (p *otherPty) Name() string { + return p.tty.Name() } func (p *otherPty) Input() ReadWriter { @@ -56,14 +103,13 @@ func (p *otherPty) Output() ReadWriter { } func (p *otherPty) Resize(height uint16, width uint16) error { - p.mutex.Lock() - defer p.mutex.Unlock() - if p.closed { - return p.err - } - return pty.Setsize(p.pty, &pty.Winsize{ - Rows: height, - Cols: width, + return p.control(p.pty, func(fd uintptr) error { + return termios.SetWinSize(fd, &termios.Winsize{ + Winsize: unix.Winsize{ + Row: height, + Col: width, + }, + }) }) } @@ -91,6 +137,15 @@ func (p *otherPty) Close() error { return err } +type otherProcess struct { + pty *os.File + cmd *exec.Cmd + + // cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first. + cmdDone chan any + cmdErr error +} + func (p *otherProcess) Wait() error { <-p.cmdDone return p.cmdErr diff --git a/pty/pty_windows.go b/pty/pty_windows.go index f206921c42..7cdbf0316d 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -1,5 +1,4 @@ //go:build windows -// +build windows package pty @@ -22,7 +21,12 @@ var ( ) // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session -func newPty() (PTY, error) { +func newPty(opt ...Option) (PTY, error) { + var opts ptyOptions + for _, o := range opt { + o(&opts) + } + // We use the CreatePseudoConsole API which was introduced in build 17763 vsn := windows.RtlGetVersion() if vsn.MajorVersion < 10 || @@ -32,30 +36,42 @@ func newPty() (PTY, error) { return nil, xerrors.Errorf("pty not supported") } - ptyWindows := &ptyWindows{} + pty := &ptyWindows{ + opts: opts, + } var err error - ptyWindows.inputRead, ptyWindows.inputWrite, err = os.Pipe() + pty.inputRead, pty.inputWrite, err = os.Pipe() if err != nil { return nil, err } - ptyWindows.outputRead, ptyWindows.outputWrite, err = os.Pipe() + pty.outputRead, pty.outputWrite, err = os.Pipe() + if err != nil { + _ = pty.inputRead.Close() + _ = pty.inputWrite.Close() + return nil, err + } consoleSize := uintptr(80) + (uintptr(80) << 16) + if opts.sshReq != nil { + consoleSize = uintptr(opts.sshReq.Window.Width) + (uintptr(opts.sshReq.Window.Height) << 16) + } ret, _, err := procCreatePseudoConsole.Call( consoleSize, - uintptr(ptyWindows.inputRead.Fd()), - uintptr(ptyWindows.outputWrite.Fd()), + uintptr(pty.inputRead.Fd()), + uintptr(pty.outputWrite.Fd()), 0, - uintptr(unsafe.Pointer(&ptyWindows.console)), + uintptr(unsafe.Pointer(&pty.console)), ) if int32(ret) < 0 { + _ = pty.Close() return nil, xerrors.Errorf("create pseudo console (%d): %w", int32(ret), err) } - return ptyWindows, nil + return pty, nil } type ptyWindows struct { + opts ptyOptions console windows.Handle outputWrite *os.File @@ -74,6 +90,13 @@ type windowsProcess struct { proc *os.Process } +// Name returns the TTY name on Windows. +// +// Not implemented. +func (p *ptyWindows) Name() string { + return "" +} + func (p *ptyWindows) Output() ReadWriter { return ReadWriter{ Reader: p.outputRead, diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index a6d6d8ab46..378c1a1fc1 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -21,19 +21,19 @@ import ( "github.com/coder/coder/testutil" ) -func New(t *testing.T) *PTY { +func New(t *testing.T, opts ...pty.Option) *PTY { t.Helper() - ptty, err := pty.New() + ptty, err := pty.New(opts...) require.NoError(t, err) return create(t, ptty, "cmd") } -func Start(t *testing.T, cmd *exec.Cmd) (*PTY, pty.Process) { +func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Process) { t.Helper() - ptty, ps, err := pty.Start(cmd) + ptty, ps, err := pty.Start(cmd, opts...) require.NoError(t, err) t.Cleanup(func() { _ = ps.Kill() diff --git a/pty/ssh_other.go b/pty/ssh_other.go new file mode 100644 index 0000000000..fabe869870 --- /dev/null +++ b/pty/ssh_other.go @@ -0,0 +1,127 @@ +//go:build !windows + +package pty + +import ( + "log" + + "github.com/gliderlabs/ssh" + "github.com/u-root/u-root/pkg/termios" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" +) + +// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic +// names used by the termios package. +var terminalModeFlagNames = map[uint8]string{ + gossh.VINTR: "intr", + gossh.VQUIT: "quit", + gossh.VERASE: "erase", + gossh.VKILL: "kill", + gossh.VEOF: "eof", + gossh.VEOL: "eol", + gossh.VEOL2: "eol2", + gossh.VSTART: "start", + gossh.VSTOP: "stop", + gossh.VSUSP: "susp", + gossh.VDSUSP: "dsusp", + gossh.VREPRINT: "rprnt", + gossh.VWERASE: "werase", + gossh.VLNEXT: "lnext", + gossh.VFLUSH: "flush", + gossh.VSWTCH: "swtch", + gossh.VSTATUS: "status", + gossh.VDISCARD: "discard", + gossh.IGNPAR: "ignpar", + gossh.PARMRK: "parmrk", + gossh.INPCK: "inpck", + gossh.ISTRIP: "istrip", + gossh.INLCR: "inlcr", + gossh.IGNCR: "igncr", + gossh.ICRNL: "icrnl", + gossh.IUCLC: "iuclc", + gossh.IXON: "ixon", + gossh.IXANY: "ixany", + gossh.IXOFF: "ixoff", + gossh.IMAXBEL: "imaxbel", + gossh.IUTF8: "iutf8", + gossh.ISIG: "isig", + gossh.ICANON: "icanon", + gossh.XCASE: "xcase", + gossh.ECHO: "echo", + gossh.ECHOE: "echoe", + gossh.ECHOK: "echok", + gossh.ECHONL: "echonl", + gossh.NOFLSH: "noflsh", + gossh.TOSTOP: "tostop", + gossh.IEXTEN: "iexten", + gossh.ECHOCTL: "echoctl", + gossh.ECHOKE: "echoke", + gossh.PENDIN: "pendin", + gossh.OPOST: "opost", + gossh.OLCUC: "olcuc", + gossh.ONLCR: "onlcr", + gossh.OCRNL: "ocrnl", + gossh.ONOCR: "onocr", + gossh.ONLRET: "onlret", + gossh.CS7: "cs7", + gossh.CS8: "cs8", + gossh.PARENB: "parenb", + gossh.PARODD: "parodd", + gossh.TTY_OP_ISPEED: "tty_op_ispeed", + gossh.TTY_OP_OSPEED: "tty_op_ospeed", +} + +// applyTerminalModesToFd applies the terminal settings from the SSH +// request to the given fd. +// +// This is based on code from Tailscale's tailssh package: +// https://github.com/tailscale/tailscale/blob/main/ssh/tailssh/incubator.go +func applyTerminalModesToFd(logger *log.Logger, fd uintptr, req ssh.Pty) error { + // Get the current TTY configuration. + tios, err := termios.GTTY(int(fd)) + if err != nil { + return xerrors.Errorf("GTTY: %w", err) + } + + // Apply the modes from the SSH request. + tios.Row = req.Window.Height + tios.Col = req.Window.Width + + for c, v := range req.Modes { + if c == gossh.TTY_OP_ISPEED { + tios.Ispeed = int(v) + continue + } + if c == gossh.TTY_OP_OSPEED { + tios.Ospeed = int(v) + continue + } + k, ok := terminalModeFlagNames[c] + if !ok { + if logger != nil { + logger.Printf("unknown terminal mode: %d", c) + } + continue + } + if _, ok := tios.CC[k]; ok { + tios.CC[k] = uint8(v) + continue + } + if _, ok := tios.Opts[k]; ok { + tios.Opts[k] = v > 0 + continue + } + + if logger != nil { + logger.Printf("unsupported terminal mode: k=%s, c=%d, v=%d", k, c, v) + } + } + + // Save the new TTY configuration. + if _, err := tios.STTY(int(fd)); err != nil { + return xerrors.Errorf("STTY: %w", err) + } + + return nil +} diff --git a/pty/start.go b/pty/start.go index 385eddcd43..ea09cbb251 100644 --- a/pty/start.go +++ b/pty/start.go @@ -4,8 +4,22 @@ import ( "os/exec" ) +// StartOption represents a configuration option passed to Start. +type StartOption func(*startOptions) + +type startOptions struct { + ptyOpts []Option +} + +// WithPTYOption applies the given options to the underlying PTY. +func WithPTYOption(opts ...Option) StartOption { + return func(o *startOptions) { + o.ptyOpts = append(o.ptyOpts, opts...) + } +} + // Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and // instead rely on the returned Process to manage the command/process. -func Start(cmd *exec.Cmd) (PTY, Process, error) { - return startPty(cmd) +func Start(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { + return startPty(cmd, opt...) } diff --git a/pty/start_other.go b/pty/start_other.go index 918db86bd2..2bf3bdfebc 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package pty @@ -10,45 +9,49 @@ import ( "strings" "syscall" - "github.com/creack/pty" "golang.org/x/xerrors" ) -func startPty(cmd *exec.Cmd) (PTY, Process, error) { - ptty, tty, err := pty.Open() - if err != nil { - return nil, nil, xerrors.Errorf("open: %w", err) +func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) { + var opts startOptions + for _, o := range opt { + o(&opts) + } + + opty, err := newPty(opts.ptyOpts...) + if err != nil { + return nil, nil, xerrors.Errorf("newPty failed: %w", err) + } + + origEnv := cmd.Env + if opty.opts.sshReq != nil { + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name())) } - cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_PTY=%s", tty.Name())) cmd.SysProcAttr = &syscall.SysProcAttr{ Setsid: true, Setctty: true, } - cmd.Stdout = tty - cmd.Stderr = tty - cmd.Stdin = tty + cmd.Stdout = opty.tty + cmd.Stderr = opty.tty + cmd.Stdin = opty.tty err = cmd.Start() if err != nil { - _ = ptty.Close() - _ = tty.Close() + _ = opty.Close() if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "bad file descriptor") { // macOS has an obscure issue where the PTY occasionally closes // before it's used. It's unknown why this is, but creating a new // TTY resolves it. - return startPty(cmd) + cmd.Env = origEnv + return startPty(cmd, opt...) } return nil, nil, xerrors.Errorf("start: %w", err) } - oPty := &otherPty{ - pty: ptty, - tty: tty, - } oProcess := &otherProcess{ - pty: ptty, + pty: opty.pty, cmd: cmd, cmdDone: make(chan any), } go oProcess.waitInternal() - return oPty, oProcess, nil + return opty, oProcess, nil } diff --git a/pty/start_other_test.go b/pty/start_other_test.go index 25a0a52124..d1f11a419e 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -6,12 +6,13 @@ import ( "os/exec" "testing" + "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "golang.org/x/xerrors" - "go.uber.org/goleak" - + "github.com/coder/coder/pty" "github.com/coder/coder/pty/ptytest" ) @@ -40,10 +41,16 @@ func TestStart(t *testing.T) { assert.NotEqual(t, 0, exitErr.ExitCode()) }) - t.Run("SSH_PTY", func(t *testing.T) { + t.Run("SSH_TTY", func(t *testing.T) { t.Parallel() - pty, ps := ptytest.Start(t, exec.Command("env")) - pty.ExpectMatch("SSH_PTY=/dev/") + opts := pty.WithPTYOption(pty.WithSSHRequest(ssh.Pty{ + Window: ssh.Window{ + Width: 80, + Height: 24, + }, + })) + pty, ps := ptytest.Start(t, exec.Command("env"), opts) + pty.ExpectMatch("SSH_TTY=/dev/") err := ps.Wait() require.NoError(t, err) }) diff --git a/pty/start_windows.go b/pty/start_windows.go index d638e5cdd1..f9307cd364 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -4,6 +4,7 @@ package pty import ( + "fmt" "os" "os/exec" "strings" @@ -16,7 +17,12 @@ import ( // Allocates a PTY and starts the specified command attached to it. // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *exec.Cmd) (PTY, Process, error) { +func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) { + var opts startOptions + for _, o := range opt { + o(&opts) + } + fullPath, err := exec.LookPath(cmd.Path) if err != nil { return nil, nil, err @@ -39,11 +45,14 @@ func startPty(cmd *exec.Cmd) (PTY, Process, error) { if err != nil { return nil, nil, err } - pty, err := newPty() + pty, err := newPty(opts.ptyOpts...) if err != nil { return nil, nil, err } winPty := pty.(*ptyWindows) + if winPty.opts.sshReq != nil { + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", winPty.Name())) + } attrs, err := windows.NewProcThreadAttributeList(1) if err != nil {