feat: Add workspace agent for SSH (#318)

* feat: Add workspace agent for SSH

This adds the initial agent that supports TTY
and execution over SSH. It functions across MacOS,
Windows, and Linux.

This does not handle the coderd interaction yet,
but does setup a simple path forward.

* Fix pty tests on Windows

* Fix log race

* Lock around dial error to fix log output

* Fix context return early

* fix: Leaking yamux session after HTTP handler is closed

Closes #317. We depended on the context canceling the yamux connection,
but this isn't a sync operation. Explicitly calling close ensures the
handler waits for yamux to complete before exit.

* Lock around close return

* Force failure with log

* Fix failed handler

* Upgrade dep

* Fix defer inside loops

* Fix context cancel for HTTP requests

* Fix resize
This commit is contained in:
Kyle Carberry 2022-02-18 23:13:32 -06:00 committed by GitHub
parent 65de96c8b4
commit 91bf8636fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 574 additions and 39 deletions

329
agent/agent.go Normal file
View File

@ -0,0 +1,329 @@
package agent
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"io"
"net"
"os/exec"
"os/user"
"sync"
"time"
"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/pty"
"github.com/coder/retry"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
)
func DialSSH(conn *peer.Conn) (net.Conn, error) {
channel, err := conn.Dial(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
return nil, err
}
return channel.NetConn(), nil
}
func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) {
netConn, err := DialSSH(conn)
if err != nil {
return nil, err
}
sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{
Config: gossh.Config{
Ciphers: []string{"arcfour"},
},
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
})
if err != nil {
return nil, err
}
return gossh.NewClient(sshConn, channels, requests), nil
}
type Options struct {
Logger slog.Logger
}
type Dialer func(ctx context.Context) (*peerbroker.Listener, error)
func New(dialer Dialer, options *Options) io.Closer {
ctx, cancelFunc := context.WithCancel(context.Background())
server := &server{
clientDialer: dialer,
options: options,
closeCancel: cancelFunc,
closed: make(chan struct{}),
}
server.init(ctx)
return server
}
type server struct {
clientDialer Dialer
options *Options
closeCancel context.CancelFunc
closeMutex sync.Mutex
closed chan struct{}
sshServer *ssh.Server
}
func (s *server) init(ctx context.Context) {
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
if err != nil {
panic(err)
}
sshLogger := s.options.Logger.Named("ssh-server")
forwardHandler := &ssh.ForwardedTCPHandler{}
s.sshServer = &ssh.Server{
ChannelHandlers: ssh.DefaultChannelHandlers,
ConnectionFailedCallback: func(conn net.Conn, err error) {
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
},
Handler: func(session ssh.Session) {
err := s.handleSSHSession(session)
if err != nil {
s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err))
_ = session.Exit(1)
return
}
},
HostSigners: []ssh.Signer{randomSigner},
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
// Allow local port forwarding all!
sshLogger.Debug(ctx, "local port forward",
slog.F("destination-host", destinationHost),
slog.F("destination-port", destinationPort))
return true
},
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
return true
},
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
// Allow reverse port forwarding all!
sshLogger.Debug(ctx, "local port forward",
slog.F("bind-host", bindHost),
slog.F("bind-port", bindPort))
return true
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": forwardHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
},
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return &gossh.ServerConfig{
Config: gossh.Config{
// "arcfour" is the fastest SSH cipher. We prioritize throughput
// over encryption here, because the WebRTC connection is already
// encrypted. If possible, we'd disable encryption entirely here.
Ciphers: []string{"arcfour"},
},
NoClientAuth: true,
}
},
}
go s.run(ctx)
}
func (*server) handleSSHSession(session ssh.Session) error {
var (
command string
args = []string{}
err error
)
username := session.User()
if username == "" {
currentUser, err := user.Current()
if err != nil {
return xerrors.Errorf("get current user: %w", err)
}
username = currentUser.Username
}
// gliderlabs/ssh returns a command slice of zero
// when a shell is requested.
if len(session.Command()) == 0 {
command, err = usershell.Get(username)
if err != nil {
return xerrors.Errorf("get user shell: %w", err)
}
} else {
command = session.Command()[0]
if len(session.Command()) > 1 {
args = session.Command()[1:]
}
}
signals := make(chan ssh.Signal)
breaks := make(chan bool)
defer close(signals)
defer close(breaks)
go func() {
for {
select {
case <-session.Context().Done():
return
// Ignore signals and breaks for now!
case <-signals:
case <-breaks:
}
}
}()
cmd := exec.CommandContext(session.Context(), command, args...)
cmd.Env = session.Environ()
sshPty, windowSize, isPty := session.Pty()
if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
ptty, process, err := pty.Start(cmd)
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
go func() {
for win := range windowSize {
err := ptty.Resize(uint16(win.Width), uint16(win.Height))
if err != nil {
panic(err)
}
}
}()
go func() {
_, _ = io.Copy(ptty.Input(), session)
}()
go func() {
_, _ = io.Copy(session, ptty.Output())
}()
_, _ = process.Wait()
_ = ptty.Close()
return nil
}
cmd.Stdout = session
cmd.Stderr = session
// This blocks forever until stdin is received if we don't
// use StdinPipe. It's unknown what causes this.
stdinPipe, err := cmd.StdinPipe()
if err != nil {
return xerrors.Errorf("create stdin pipe: %w", err)
}
go func() {
_, _ = io.Copy(stdinPipe, session)
}()
err = cmd.Start()
if err != nil {
return xerrors.Errorf("start: %w", err)
}
_ = cmd.Wait()
return nil
}
func (s *server) run(ctx context.Context) {
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
peerListener, err = s.clientDialer(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
if s.isClosed() {
return
}
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
s.options.Logger.Debug(context.Background(), "connected")
break
}
select {
case <-ctx.Done():
return
default:
}
for {
conn, err := peerListener.Accept()
if err != nil {
if s.isClosed() {
return
}
s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
s.run(ctx)
return
}
go s.handlePeerConn(ctx, conn)
}
}
func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
for {
channel, err := conn.Accept(ctx)
if err != nil {
if errors.Is(err, peer.ErrClosed) || s.isClosed() {
return
}
s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
return
}
switch channel.Protocol() {
case "ssh":
s.sshServer.HandleConn(channel.NetConn())
default:
s.options.Logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
slog.F("label", channel.Label()),
)
}
}
}
// isClosed returns whether the API is closed or not.
func (s *server) isClosed() bool {
select {
case <-s.closed:
return true
default:
return false
}
}
func (s *server) Close() error {
s.closeMutex.Lock()
defer s.closeMutex.Unlock()
if s.isClosed() {
return nil
}
close(s.closed)
s.closeCancel()
_ = s.sshServer.Close()
return nil
}

110
agent/agent_test.go Normal file
View File

@ -0,0 +1,110 @@
package agent_test
import (
"context"
"runtime"
"strings"
"testing"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/pty/ptytest"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestAgent(t *testing.T) {
t.Parallel()
t.Run("SessionExec", func(t *testing.T) {
t.Parallel()
api := setup(t)
stream, err := api.NegotiateConnection(context.Background())
require.NoError(t, err)
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil),
})
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
sshClient, err := agent.DialSSHClient(conn)
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
command := "echo test"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo test"
}
output, err := session.Output(command)
require.NoError(t, err)
require.Equal(t, "test", strings.TrimSpace(string(output)))
})
t.Run("SessionTTY", func(t *testing.T) {
t.Parallel()
api := setup(t)
stream, err := api.NegotiateConnection(context.Background())
require.NoError(t, err)
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil),
})
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
sshClient, err := agent.DialSSHClient(conn)
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
prompt := "$"
command := "bash"
if runtime.GOOS == "windows" {
command = "cmd.exe"
prompt = ">"
}
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
require.NoError(t, err)
ptty := ptytest.New(t)
require.NoError(t, err)
session.Stdout = ptty.Output()
session.Stderr = ptty.Output()
session.Stdin = ptty.Input()
err = session.Start(command)
require.NoError(t, err)
ptty.ExpectMatch(prompt)
ptty.WriteLine("echo test")
ptty.ExpectMatch("test")
ptty.WriteLine("exit")
err = session.Wait()
require.NoError(t, err)
})
}
func setup(t *testing.T) proto.DRPCPeerBrokerClient {
client, server := provisionersdk.TransportPipe()
closer := agent.New(func(ctx context.Context) (*peerbroker.Listener, error) {
return peerbroker.Listen(server, &peer.ConnOptions{
Logger: slogtest.Make(t, nil),
})
}, &agent.Options{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
})
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
_ = closer.Close()
})
return proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
}

View File

@ -0,0 +1,10 @@
package usershell
import "os"
// Get returns the $SHELL environment variable.
// TODO: This should use "dscl" to fetch the proper value. See:
// https://stackoverflow.com/questions/16375519/how-to-get-the-default-shell
func Get(username string) (string, error) {
return os.Getenv("SHELL"), nil
}

View File

@ -0,0 +1,31 @@
//go:build !windows && !darwin
// +build !windows,!darwin
package usershell
import (
"os"
"strings"
"golang.org/x/xerrors"
)
// Get returns the /etc/passwd entry for the username provided.
func Get(username string) (string, error) {
contents, err := os.ReadFile("/etc/passwd")
if err != nil {
return "", xerrors.Errorf("read /etc/passwd: %w", err)
}
lines := strings.Split(string(contents), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, username+":") {
continue
}
parts := strings.Split(line, ":")
if len(parts) < 7 {
return "", xerrors.Errorf("malformed user entry: %q", line)
}
return parts[6], nil
}
return "", xerrors.New("user not found in /etc/passwd and $SHELL not set")
}

View File

@ -0,0 +1,27 @@
//go:build !windows && !darwin
// +build !windows,!darwin
package usershell_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/agent/usershell"
)
func TestGet(t *testing.T) {
t.Parallel()
t.Run("Has", func(t *testing.T) {
t.Parallel()
shell, err := usershell.Get("root")
require.NoError(t, err)
require.NotEmpty(t, shell)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
_, err := usershell.Get("notauser")
require.Error(t, err)
})
}

View File

@ -0,0 +1,6 @@
package usershell
// Get returns the command prompt binary name.
func Get(username string) (string, error) {
return "cmd.exe", nil
}

View File

@ -164,7 +164,7 @@ func AwaitProjectImportJob(t *testing.T, client *codersdk.Client, organization s
provisionerJob, err = client.ProjectImportJob(context.Background(), organization, job)
require.NoError(t, err)
return provisionerJob.Status.Completed()
}, 3*time.Second, 25*time.Millisecond)
}, 5*time.Second, 25*time.Millisecond)
return provisionerJob
}
@ -176,7 +176,7 @@ func AwaitWorkspaceProvisionJob(t *testing.T, client *codersdk.Client, organizat
provisionerJob, err = client.WorkspaceProvisionJob(context.Background(), organization, job)
require.NoError(t, err)
return provisionerJob.Status.Completed()
}, 3*time.Second, 25*time.Millisecond)
}, 5*time.Second, 25*time.Millisecond)
return provisionerJob
}

2
go.mod
View File

@ -20,6 +20,7 @@ require (
github.com/coder/retry v1.3.0
github.com/creack/pty v1.1.17
github.com/fatih/color v1.13.0
github.com/gliderlabs/ssh v0.3.3
github.com/go-chi/chi/v5 v5.0.7
github.com/go-chi/render v1.0.1
github.com/go-playground/validator/v10 v10.10.0
@ -64,6 +65,7 @@ require (
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/alecthomas/chroma v0.10.0 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect
github.com/cenkalti/backoff/v4 v4.1.2 // indirect
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect

4
go.sum
View File

@ -132,6 +132,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF
github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:CgnQgUtFrFz9mxFNtED3jI5tLDjKlOM+oUF/sTk6ps0=
github.com/andybalholm/crlf v0.0.0-20171020200849-670099aa064f/go.mod h1:k8feO4+kXDxro6ErPXBRTJ/ro2mf0SsFG8s7doP9kJE=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/apache/arrow/go/arrow v0.0.0-20210818145353-234c94e4ce64/go.mod h1:2qMFB56yOP3KzkB3PbYZ4AlUFg3a88F67TIx5lB/WwY=
github.com/apache/arrow/go/arrow v0.0.0-20211013220434-5962184e7a30/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
@ -442,6 +444,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm
github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14=
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
github.com/gliderlabs/ssh v0.3.3 h1:mBQ8NiOgDkINJrZtoizkC3nDNYgSaWtxyem6S2XHBtA=
github.com/gliderlabs/ssh v0.3.3/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8=
github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8=

View File

@ -45,7 +45,7 @@ func (p *otherPty) Output() io.ReadWriter {
func (p *otherPty) Resize(cols uint16, rows uint16) error {
p.mutex.Lock()
defer p.mutex.Unlock()
return pty.Setsize(p.tty, &pty.Winsize{
return pty.Setsize(p.pty, &pty.Winsize{
Rows: rows,
Cols: cols,
})

View File

@ -96,12 +96,15 @@ func (p *ptyWindows) Close() error {
return nil
}
p.closed = true
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
if ret != 0 {
return xerrors.Errorf("close pseudo console: %w", err)
}
_ = p.outputWrite.Close()
_ = p.outputRead.Close()
_ = p.inputWrite.Close()
_ = p.inputRead.Close()
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
if ret < 0 {
return xerrors.Errorf("close pseudo console: %w", err)
}
return nil
}

View File

@ -5,8 +5,10 @@ import (
"bytes"
"fmt"
"io"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"testing"
"unicode/utf8"
@ -28,10 +30,10 @@ func New(t *testing.T) *PTY {
return create(t, ptty)
}
func Start(t *testing.T, cmd *exec.Cmd) *PTY {
ptty, err := pty.Start(cmd)
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
ptty, ps, err := pty.Start(cmd)
require.NoError(t, err)
return create(t, ptty)
return create(t, ptty), ps
}
func create(t *testing.T, ptty pty.PTY) *PTY {
@ -86,10 +88,15 @@ func (p *PTY) ExpectMatch(str string) string {
break
}
}
p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), ""))
return buffer.String()
}
func (p *PTY) WriteLine(str string) {
_, err := fmt.Fprintf(p.PTY.Input(), "%s\n", str)
newline := "\n"
if runtime.GOOS == "windows" {
newline = "\r\n"
}
_, err := fmt.Fprintf(p.PTY.Input(), "%s%s", str, newline)
require.NoError(p.t, err)
}

View File

@ -1,7 +1,10 @@
package pty
import "os/exec"
import (
"os"
"os/exec"
)
func Start(cmd *exec.Cmd) (PTY, error) {
func Start(cmd *exec.Cmd) (PTY, *os.Process, error) {
return startPty(cmd)
}

View File

@ -4,6 +4,7 @@
package pty
import (
"os"
"os/exec"
"syscall"
@ -11,10 +12,10 @@ import (
"golang.org/x/xerrors"
)
func startPty(cmd *exec.Cmd) (PTY, error) {
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
ptty, tty, err := pty.Open()
if err != nil {
return nil, xerrors.Errorf("open: %w", err)
return nil, nil, xerrors.Errorf("open: %w", err)
}
defer func() {
_ = tty.Close()
@ -29,10 +30,11 @@ func startPty(cmd *exec.Cmd) (PTY, error) {
err = cmd.Start()
if err != nil {
_ = ptty.Close()
return nil, xerrors.Errorf("start: %w", err)
return nil, nil, xerrors.Errorf("start: %w", err)
}
return &otherPty{
oPty := &otherPty{
pty: ptty,
tty: tty,
}, nil
}
return oPty, cmd.Process, nil
}

View File

@ -7,8 +7,9 @@ import (
"os/exec"
"testing"
"github.com/coder/coder/pty/ptytest"
"go.uber.org/goleak"
"github.com/coder/coder/pty/ptytest"
)
func TestMain(m *testing.M) {
@ -19,7 +20,7 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty := ptytest.Start(t, exec.Command("echo", "test"))
pty, _ := ptytest.Start(t, exec.Command("echo", "test"))
pty.ExpectMatch("test")
})
}

View File

@ -11,47 +11,48 @@ import (
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/xerrors"
)
// Allocates a PTY and starts the specified command attached to it.
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
func startPty(cmd *exec.Cmd) (PTY, error) {
func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
fullPath, err := exec.LookPath(cmd.Path)
if err != nil {
return nil, err
return nil, nil, err
}
pathPtr, err := windows.UTF16PtrFromString(fullPath)
if err != nil {
return nil, err
return nil, nil, err
}
argsPtr, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(cmd.Args))
if err != nil {
return nil, err
return nil, nil, err
}
if cmd.Dir == "" {
cmd.Dir, err = os.Getwd()
if err != nil {
return nil, err
return nil, nil, err
}
}
dirPtr, err := windows.UTF16PtrFromString(cmd.Dir)
if err != nil {
return nil, err
return nil, nil, err
}
pty, err := newPty()
if err != nil {
return nil, err
return nil, nil, err
}
winPty := pty.(*ptyWindows)
attrs, err := windows.NewProcThreadAttributeList(1)
if err != nil {
return nil, err
return nil, nil, err
}
// Taken from: https://github.com/microsoft/hcsshim/blob/2314362e977aa03b3ed245a4beb12d00422af0e2/internal/winapi/process.go#L6
err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console))
if err != nil {
return nil, err
return nil, nil, err
}
startupInfo := &windows.StartupInfoEx{}
@ -73,12 +74,16 @@ func startPty(cmd *exec.Cmd) (PTY, error) {
&processInfo,
)
if err != nil {
return nil, err
return nil, nil, err
}
defer windows.CloseHandle(processInfo.Thread)
defer windows.CloseHandle(processInfo.Process)
return pty, nil
process, err := os.FindProcess(int(processInfo.ProcessId))
if err != nil {
return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err)
}
return pty, process, nil
}
// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476

View File

@ -20,12 +20,12 @@ func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
pty, _ := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
pty.ExpectMatch("test")
})
t.Run("Resize", func(t *testing.T) {
t.Parallel()
pty := ptytest.Start(t, exec.Command("cmd.exe"))
pty, _ := ptytest.Start(t, exec.Command("cmd.exe"))
err := pty.Resize(100, 50)
require.NoError(t, err)
})

View File

@ -1,5 +0,0 @@
variable "bananas" {
description = "hello!"
}
resource "null_resource" "example" {}