fix: Terminal emulation used by SSH sessions (#3473)

Fixes #3371
This commit is contained in:
Mathias Fredriksson 2022-09-12 19:27:51 +03:00 committed by GitHub
parent b4c29f34c3
commit 09da3858ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 388 additions and 83 deletions

View File

@ -31,6 +31,7 @@
"gonet",
"gossh",
"gsyslog",
"GTTY",
"hashicorp",
"hclsyntax",
"httpapi",
@ -67,6 +68,7 @@
"ntqry",
"OIDC",
"oneof",
"opty",
"paralleltest",
"parameterscopeid",
"pqtype",
@ -76,6 +78,7 @@
"provisionerd",
"provisionersdk",
"ptty",
"ptys",
"ptytest",
"reconfig",
"retrier",
@ -87,6 +90,7 @@
"sourcemapped",
"Srcs",
"stretchr",
"STTY",
"stuntest",
"tailbroker",
"tailcfg",
@ -105,6 +109,7 @@
"tfjson",
"tfplan",
"tfstate",
"tios",
"tparallel",
"trimprefix",
"tsdial",

View File

@ -374,7 +374,7 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error {
return nil
}
writer, err := os.OpenFile(filepath.Join(os.TempDir(), "coder-startup-script.log"), os.O_CREATE|os.O_RDWR, 0600)
writer, err := os.OpenFile(filepath.Join(os.TempDir(), "coder-startup-script.log"), os.O_CREATE|os.O_RDWR, 0o600)
if err != nil {
return xerrors.Errorf("open startup script log file: %w", err)
}
@ -537,6 +537,8 @@ func (a *agent) init(ctx context.Context) {
},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": func(session ssh.Session) {
session.DisablePTYEmulation()
server, err := sftp.NewServer(session)
if err != nil {
a.logger.Debug(session.Context(), "initialize sftp server", slog.Error(err))
@ -661,7 +663,8 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
}
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
ctx := session.Context()
cmd, err := a.createCommand(ctx, session.RawCommand(), session.Environ())
if err != nil {
return err
}
@ -678,32 +681,34 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
sshPty, windowSize, isPty := session.Pty()
if isPty {
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
// See https://github.com/coder/coder/issues/3371.
session.DisablePTYEmulation()
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
// The pty package sets `SSH_TTY` on supported platforms.
ptty, process, err := pty.Start(cmd)
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
pty.WithSSHRequest(sshPty),
pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)),
))
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
defer func() {
closeErr := ptty.Close()
if closeErr != nil {
a.logger.Warn(context.Background(), "failed to close tty",
slog.Error(closeErr))
a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}
}()
err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width))
if err != nil {
return xerrors.Errorf("resize ptty: %w", err)
}
go func() {
for win := range windowSize {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
if resizeErr != nil {
a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(resizeErr))
a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
}
}
}()
@ -718,8 +723,7 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
// and not something to be concerned about. But, if it's something else, we should log it.
if err != nil && !xerrors.As(err, &exitErr) {
a.logger.Warn(context.Background(), "wait error",
slog.Error(err))
a.logger.Warn(ctx, "wait error", slog.Error(err))
}
return err
}

View File

@ -195,6 +195,13 @@ func ssh() *cobra.Command {
// shutdown of services.
defer cancel()
if validOut {
// Set initial window size.
width, height, err := term.GetSize(int(stdoutFile.Fd()))
if err == nil {
_ = sshSession.WindowChange(height, width)
}
}
err = sshSession.Wait()
if err != nil {
// If the connection drops unexpectedly, we get an ExitMissingError but no other

10
go.mod
View File

@ -51,6 +51,15 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0
// https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main
replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220907193453-fb5ba5ab658d
// Switch to our fork that imports fixes from http://github.com/tailscale/ssh.
// See: https://github.com/coder/coder/issues/3371
//
// Note that http://github.com/tailscale/ssh has been merged into the Tailscale
// repo as tailscale.com/tempfork/gliderlabs/ssh, however, we can't replace the
// subpath and it includes changes to golang.org/x/crypto/ssh as well which
// makes importing it directly a bit messy.
replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20220811105153-fcea99919338
require (
cdr.dev/slog v1.4.2-0.20220525200111-18dce5c2cd5f
cloud.google.com/go/compute v1.7.0
@ -126,6 +135,7 @@ require (
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.0
github.com/tabbed/pqtype v0.1.1
github.com/u-root/u-root v0.9.0
github.com/unrolled/secure v1.12.0
go.mozilla.org/pkcs7 v0.0.0-20200128120323-432b2356ecb1
go.opentelemetry.io/otel v1.8.0

7
go.sum
View File

@ -352,6 +352,8 @@ github.com/coder/glog v1.0.1-0.20220322161911-7365fe7f2cd1 h1:UqBrPWSYvRI2s5RtOu
github.com/coder/glog v1.0.1-0.20220322161911-7365fe7f2cd1/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4=
github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY=
github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY=
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko=
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
github.com/coder/tailscale v1.1.1-0.20220907193453-fb5ba5ab658d h1:IQ8wJn8MfDS+sesYPpn3EDAyvoGMxFvyyE9uWtcfU6w=
github.com/coder/tailscale v1.1.1-0.20220907193453-fb5ba5ab658d/go.mod h1:MO+tWkQp2YIF3KBnnej/mQvgYccRS5Xk/IrEpZ4Z3BU=
github.com/coder/wireguard-go/tun/netstack v0.0.0-20220823170024-a78136eb0cab h1:9yEvRWXXfyKzXu8AqywCi+tFZAoqCy4wVcsXwuvZNMc=
@ -638,9 +640,6 @@ github.com/gin-gonic/gin v1.7.0 h1:jGB9xAJQ12AIGNB4HguylppmDK1Am9ppF7XnGXXJuoU=
github.com/gin-gonic/gin v1.7.0/go.mod h1:jD2toBW3GZUr5UMcdrwQA10I7RuaFOl/SGeDjXkfUtY=
github.com/github/fakeca v0.1.0 h1:Km/MVOFvclqxPM9dZBC4+QE564nU4gz4iZ0D9pMw28I=
github.com/github/fakeca v0.1.0/go.mod h1:+bormgoGMMuamOscx7N91aOuUST7wdaJ2rNjeohylyo=
github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
github.com/gliderlabs/ssh v0.3.4 h1:+AXBtim7MTKaLVPgvE+3mhewYRawNLTd+jEEz/wExZw=
github.com/gliderlabs/ssh v0.3.4/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs=
github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg=
github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8=
@ -1810,6 +1809,8 @@ github.com/tomarrell/wrapcheck/v2 v2.4.0/go.mod h1:68bQ/eJg55BROaRTbMjC7vuhL2Ogf
github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4=
github.com/tommy-muehle/go-mnd/v2 v2.4.0/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw=
github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM=
github.com/u-root/u-root v0.9.0 h1:1dpUzrE0FyKrNEjxpKFOkyveuV1f3T0Ko5CQg4gTkCg=
github.com/u-root/u-root v0.9.0/go.mod h1:ewc9w6JF1ayZCVC9Y5wsrUiCBw3nMmPC3QItvrEwmew=
github.com/u-root/uio v0.0.0-20210528114334-82958018845c/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA=
github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4 h1:hl6sK6aFgTLISijk6xIzeqnPzQcsLqqvL6vEfTPinME=
github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA=

View File

@ -2,13 +2,19 @@ package pty
import (
"io"
"log"
"os"
"github.com/gliderlabs/ssh"
)
// PTY is a minimal interface for interacting with a TTY.
type PTY interface {
io.Closer
// Name of the TTY. Example on Linux would be "/dev/pts/1".
Name() string
// Output handles TTY output.
//
// cmd.SetOutput(pty.Output()) would be used to specify a command
@ -35,7 +41,6 @@ type PTY interface {
// to Wait() on a process, this abstraction provides a goroutine-safe interface for interacting with
// the process.
type Process interface {
// Wait for the command to complete. Returned error is as for exec.Cmd.Wait()
Wait() error
@ -52,9 +57,33 @@ type WithFlags interface {
EchoEnabled() (bool, error)
}
// Options represents a an option for a PTY.
type Option func(*ptyOptions)
type ptyOptions struct {
logger *log.Logger
sshReq *ssh.Pty
}
// WithSSHRequest applies the ssh.Pty request to the PTY.
//
// Only partially supported on Windows (e.g. window size).
func WithSSHRequest(req ssh.Pty) Option {
return func(opts *ptyOptions) {
opts.sshReq = &req
}
}
// WithLogger sets a logger for logging errors.
func WithLogger(logger *log.Logger) Option {
return func(opts *ptyOptions) {
opts.logger = logger
}
}
// New constructs a new Pty.
func New() (PTY, error) {
return newPty()
func New(opts ...Option) (PTY, error) {
return newPty(opts...)
}
// ReadWriter is an implementation of io.ReadWriter that wraps two separate

View File

@ -2,12 +2,23 @@
package pty
import "golang.org/x/sys/unix"
import (
"github.com/u-root/u-root/pkg/termios"
"golang.org/x/sys/unix"
)
func (p *otherPty) EchoEnabled() (bool, error) {
termios, err := unix.IoctlGetTermios(int(p.pty.Fd()), unix.TCGETS)
func (p *otherPty) EchoEnabled() (echo bool, err error) {
err = p.control(p.pty, func(fd uintptr) error {
t, err := termios.GetTermios(fd)
if err != nil {
return err
}
echo = (t.Lflag & unix.ECHO) != 0
return nil
})
if err != nil {
return false, err
}
return (termios.Lflag & unix.ECHO) != 0, nil
return echo, nil
}

View File

@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package pty
@ -10,19 +9,42 @@ import (
"sync"
"github.com/creack/pty"
"github.com/u-root/u-root/pkg/termios"
"golang.org/x/sys/unix"
"golang.org/x/xerrors"
)
func newPty() (PTY, error) {
func newPty(opt ...Option) (retPTY *otherPty, err error) {
var opts ptyOptions
for _, o := range opt {
o(&opts)
}
ptyFile, ttyFile, err := pty.Open()
if err != nil {
return nil, err
}
opty := &otherPty{
pty: ptyFile,
tty: ttyFile,
opts: opts,
}
defer func() {
if err != nil {
_ = opty.Close()
}
}()
return &otherPty{
pty: ptyFile,
tty: ttyFile,
}, nil
if opts.sshReq != nil {
err = opty.control(opty.tty, func(fd uintptr) error {
return applyTerminalModesToFd(opts.logger, fd, *opts.sshReq)
})
if err != nil {
return nil, err
}
}
return opty, nil
}
type otherPty struct {
@ -30,15 +52,40 @@ type otherPty struct {
closed bool
err error
pty, tty *os.File
opts ptyOptions
}
type otherProcess struct {
pty *os.File
cmd *exec.Cmd
func (p *otherPty) control(tty *os.File, fn func(fd uintptr) error) (err error) {
defer func() {
// Always echo the close error for closed ptys.
p.mutex.Lock()
defer p.mutex.Unlock()
if p.closed {
err = p.err
}
}()
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
cmdDone chan any
cmdErr error
rawConn, err := tty.SyscallConn()
if err != nil {
return err
}
var ctlErr error
err = rawConn.Control(func(fd uintptr) {
ctlErr = fn(fd)
})
switch {
case err != nil:
return err
case ctlErr != nil:
return ctlErr
default:
return nil
}
}
func (p *otherPty) Name() string {
return p.tty.Name()
}
func (p *otherPty) Input() ReadWriter {
@ -56,14 +103,13 @@ func (p *otherPty) Output() ReadWriter {
}
func (p *otherPty) Resize(height uint16, width uint16) error {
p.mutex.Lock()
defer p.mutex.Unlock()
if p.closed {
return p.err
}
return pty.Setsize(p.pty, &pty.Winsize{
Rows: height,
Cols: width,
return p.control(p.pty, func(fd uintptr) error {
return termios.SetWinSize(fd, &termios.Winsize{
Winsize: unix.Winsize{
Row: height,
Col: width,
},
})
})
}
@ -91,6 +137,15 @@ func (p *otherPty) Close() error {
return err
}
type otherProcess struct {
pty *os.File
cmd *exec.Cmd
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
cmdDone chan any
cmdErr error
}
func (p *otherProcess) Wait() error {
<-p.cmdDone
return p.cmdErr

View File

@ -1,5 +1,4 @@
//go:build windows
// +build windows
package pty
@ -22,7 +21,12 @@ var (
)
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session
func newPty() (PTY, error) {
func newPty(opt ...Option) (PTY, error) {
var opts ptyOptions
for _, o := range opt {
o(&opts)
}
// We use the CreatePseudoConsole API which was introduced in build 17763
vsn := windows.RtlGetVersion()
if vsn.MajorVersion < 10 ||
@ -32,30 +36,42 @@ func newPty() (PTY, error) {
return nil, xerrors.Errorf("pty not supported")
}
ptyWindows := &ptyWindows{}
pty := &ptyWindows{
opts: opts,
}
var err error
ptyWindows.inputRead, ptyWindows.inputWrite, err = os.Pipe()
pty.inputRead, pty.inputWrite, err = os.Pipe()
if err != nil {
return nil, err
}
ptyWindows.outputRead, ptyWindows.outputWrite, err = os.Pipe()
pty.outputRead, pty.outputWrite, err = os.Pipe()
if err != nil {
_ = pty.inputRead.Close()
_ = pty.inputWrite.Close()
return nil, err
}
consoleSize := uintptr(80) + (uintptr(80) << 16)
if opts.sshReq != nil {
consoleSize = uintptr(opts.sshReq.Window.Width) + (uintptr(opts.sshReq.Window.Height) << 16)
}
ret, _, err := procCreatePseudoConsole.Call(
consoleSize,
uintptr(ptyWindows.inputRead.Fd()),
uintptr(ptyWindows.outputWrite.Fd()),
uintptr(pty.inputRead.Fd()),
uintptr(pty.outputWrite.Fd()),
0,
uintptr(unsafe.Pointer(&ptyWindows.console)),
uintptr(unsafe.Pointer(&pty.console)),
)
if int32(ret) < 0 {
_ = pty.Close()
return nil, xerrors.Errorf("create pseudo console (%d): %w", int32(ret), err)
}
return ptyWindows, nil
return pty, nil
}
type ptyWindows struct {
opts ptyOptions
console windows.Handle
outputWrite *os.File
@ -74,6 +90,13 @@ type windowsProcess struct {
proc *os.Process
}
// Name returns the TTY name on Windows.
//
// Not implemented.
func (p *ptyWindows) Name() string {
return ""
}
func (p *ptyWindows) Output() ReadWriter {
return ReadWriter{
Reader: p.outputRead,

View File

@ -21,19 +21,19 @@ import (
"github.com/coder/coder/testutil"
)
func New(t *testing.T) *PTY {
func New(t *testing.T, opts ...pty.Option) *PTY {
t.Helper()
ptty, err := pty.New()
ptty, err := pty.New(opts...)
require.NoError(t, err)
return create(t, ptty, "cmd")
}
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, pty.Process) {
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Process) {
t.Helper()
ptty, ps, err := pty.Start(cmd)
ptty, ps, err := pty.Start(cmd, opts...)
require.NoError(t, err)
t.Cleanup(func() {
_ = ps.Kill()

127
pty/ssh_other.go Normal file
View File

@ -0,0 +1,127 @@
//go:build !windows
package pty
import (
"log"
"github.com/gliderlabs/ssh"
"github.com/u-root/u-root/pkg/termios"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
)
// terminalModeFlagNames maps the SSH terminal mode flags to mnemonic
// names used by the termios package.
var terminalModeFlagNames = map[uint8]string{
gossh.VINTR: "intr",
gossh.VQUIT: "quit",
gossh.VERASE: "erase",
gossh.VKILL: "kill",
gossh.VEOF: "eof",
gossh.VEOL: "eol",
gossh.VEOL2: "eol2",
gossh.VSTART: "start",
gossh.VSTOP: "stop",
gossh.VSUSP: "susp",
gossh.VDSUSP: "dsusp",
gossh.VREPRINT: "rprnt",
gossh.VWERASE: "werase",
gossh.VLNEXT: "lnext",
gossh.VFLUSH: "flush",
gossh.VSWTCH: "swtch",
gossh.VSTATUS: "status",
gossh.VDISCARD: "discard",
gossh.IGNPAR: "ignpar",
gossh.PARMRK: "parmrk",
gossh.INPCK: "inpck",
gossh.ISTRIP: "istrip",
gossh.INLCR: "inlcr",
gossh.IGNCR: "igncr",
gossh.ICRNL: "icrnl",
gossh.IUCLC: "iuclc",
gossh.IXON: "ixon",
gossh.IXANY: "ixany",
gossh.IXOFF: "ixoff",
gossh.IMAXBEL: "imaxbel",
gossh.IUTF8: "iutf8",
gossh.ISIG: "isig",
gossh.ICANON: "icanon",
gossh.XCASE: "xcase",
gossh.ECHO: "echo",
gossh.ECHOE: "echoe",
gossh.ECHOK: "echok",
gossh.ECHONL: "echonl",
gossh.NOFLSH: "noflsh",
gossh.TOSTOP: "tostop",
gossh.IEXTEN: "iexten",
gossh.ECHOCTL: "echoctl",
gossh.ECHOKE: "echoke",
gossh.PENDIN: "pendin",
gossh.OPOST: "opost",
gossh.OLCUC: "olcuc",
gossh.ONLCR: "onlcr",
gossh.OCRNL: "ocrnl",
gossh.ONOCR: "onocr",
gossh.ONLRET: "onlret",
gossh.CS7: "cs7",
gossh.CS8: "cs8",
gossh.PARENB: "parenb",
gossh.PARODD: "parodd",
gossh.TTY_OP_ISPEED: "tty_op_ispeed",
gossh.TTY_OP_OSPEED: "tty_op_ospeed",
}
// applyTerminalModesToFd applies the terminal settings from the SSH
// request to the given fd.
//
// This is based on code from Tailscale's tailssh package:
// https://github.com/tailscale/tailscale/blob/main/ssh/tailssh/incubator.go
func applyTerminalModesToFd(logger *log.Logger, fd uintptr, req ssh.Pty) error {
// Get the current TTY configuration.
tios, err := termios.GTTY(int(fd))
if err != nil {
return xerrors.Errorf("GTTY: %w", err)
}
// Apply the modes from the SSH request.
tios.Row = req.Window.Height
tios.Col = req.Window.Width
for c, v := range req.Modes {
if c == gossh.TTY_OP_ISPEED {
tios.Ispeed = int(v)
continue
}
if c == gossh.TTY_OP_OSPEED {
tios.Ospeed = int(v)
continue
}
k, ok := terminalModeFlagNames[c]
if !ok {
if logger != nil {
logger.Printf("unknown terminal mode: %d", c)
}
continue
}
if _, ok := tios.CC[k]; ok {
tios.CC[k] = uint8(v)
continue
}
if _, ok := tios.Opts[k]; ok {
tios.Opts[k] = v > 0
continue
}
if logger != nil {
logger.Printf("unsupported terminal mode: k=%s, c=%d, v=%d", k, c, v)
}
}
// Save the new TTY configuration.
if _, err := tios.STTY(int(fd)); err != nil {
return xerrors.Errorf("STTY: %w", err)
}
return nil
}

View File

@ -4,8 +4,22 @@ import (
"os/exec"
)
// StartOption represents a configuration option passed to Start.
type StartOption func(*startOptions)
type startOptions struct {
ptyOpts []Option
}
// WithPTYOption applies the given options to the underlying PTY.
func WithPTYOption(opts ...Option) StartOption {
return func(o *startOptions) {
o.ptyOpts = append(o.ptyOpts, opts...)
}
}
// Start the command in a TTY. The calling code must not use cmd after passing it to the PTY, and
// instead rely on the returned Process to manage the command/process.
func Start(cmd *exec.Cmd) (PTY, Process, error) {
return startPty(cmd)
func Start(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
return startPty(cmd, opt...)
}

View File

@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package pty
@ -10,45 +9,49 @@ import (
"strings"
"syscall"
"github.com/creack/pty"
"golang.org/x/xerrors"
)
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
ptty, tty, err := pty.Open()
if err != nil {
return nil, nil, xerrors.Errorf("open: %w", err)
func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process, err error) {
var opts startOptions
for _, o := range opt {
o(&opts)
}
opty, err := newPty(opts.ptyOpts...)
if err != nil {
return nil, nil, xerrors.Errorf("newPty failed: %w", err)
}
origEnv := cmd.Env
if opty.opts.sshReq != nil {
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
}
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_PTY=%s", tty.Name()))
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
Setctty: true,
}
cmd.Stdout = tty
cmd.Stderr = tty
cmd.Stdin = tty
cmd.Stdout = opty.tty
cmd.Stderr = opty.tty
cmd.Stdin = opty.tty
err = cmd.Start()
if err != nil {
_ = ptty.Close()
_ = tty.Close()
_ = opty.Close()
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "bad file descriptor") {
// macOS has an obscure issue where the PTY occasionally closes
// before it's used. It's unknown why this is, but creating a new
// TTY resolves it.
return startPty(cmd)
cmd.Env = origEnv
return startPty(cmd, opt...)
}
return nil, nil, xerrors.Errorf("start: %w", err)
}
oPty := &otherPty{
pty: ptty,
tty: tty,
}
oProcess := &otherProcess{
pty: ptty,
pty: opty.pty,
cmd: cmd,
cmdDone: make(chan any),
}
go oProcess.waitInternal()
return oPty, oProcess, nil
return opty, oProcess, nil
}

View File

@ -6,12 +6,13 @@ import (
"os/exec"
"testing"
"github.com/gliderlabs/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/xerrors"
"go.uber.org/goleak"
"github.com/coder/coder/pty"
"github.com/coder/coder/pty/ptytest"
)
@ -40,10 +41,16 @@ func TestStart(t *testing.T) {
assert.NotEqual(t, 0, exitErr.ExitCode())
})
t.Run("SSH_PTY", func(t *testing.T) {
t.Run("SSH_TTY", func(t *testing.T) {
t.Parallel()
pty, ps := ptytest.Start(t, exec.Command("env"))
pty.ExpectMatch("SSH_PTY=/dev/")
opts := pty.WithPTYOption(pty.WithSSHRequest(ssh.Pty{
Window: ssh.Window{
Width: 80,
Height: 24,
},
}))
pty, ps := ptytest.Start(t, exec.Command("env"), opts)
pty.ExpectMatch("SSH_TTY=/dev/")
err := ps.Wait()
require.NoError(t, err)
})

View File

@ -4,6 +4,7 @@
package pty
import (
"fmt"
"os"
"os/exec"
"strings"
@ -16,7 +17,12 @@ import (
// Allocates a PTY and starts the specified command attached to it.
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
func startPty(cmd *exec.Cmd) (PTY, Process, error) {
func startPty(cmd *exec.Cmd, opt ...StartOption) (PTY, Process, error) {
var opts startOptions
for _, o := range opt {
o(&opts)
}
fullPath, err := exec.LookPath(cmd.Path)
if err != nil {
return nil, nil, err
@ -39,11 +45,14 @@ func startPty(cmd *exec.Cmd) (PTY, Process, error) {
if err != nil {
return nil, nil, err
}
pty, err := newPty()
pty, err := newPty(opts.ptyOpts...)
if err != nil {
return nil, nil, err
}
winPty := pty.(*ptyWindows)
if winPty.opts.sshReq != nil {
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", winPty.Name()))
}
attrs, err := windows.NewProcThreadAttributeList(1)
if err != nil {