// Source: mirror of https://github.com/coder/coder.git (458 lines, 12 KiB, Go).
package agent_test
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/pion/udp"
|
|
"github.com/pion/webrtc/v3"
|
|
"github.com/pkg/sftp"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/text/encoding/unicode"
|
|
"golang.org/x/text/transform"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/agent"
|
|
"github.com/coder/coder/peer"
|
|
"github.com/coder/coder/peerbroker"
|
|
"github.com/coder/coder/peerbroker/proto"
|
|
"github.com/coder/coder/provisionersdk"
|
|
"github.com/coder/coder/pty/ptytest"
|
|
)
|
|
|
|
// TestMain runs every test in the package through goleak, which fails the
// run if goroutines started by the tests are still alive once they finish.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}
|
|
|
|
// TestAgent exercises the agent end-to-end over an in-memory peerbroker
// transport. It covers SSH exec and TTY sessions, local port forwarding via
// the system ssh binary, SFTP, environment variables, startup scripts,
// reconnecting PTYs, and dialing TCP/UDP/Unix listeners through the
// peer connection.
func TestAgent(t *testing.T) {
	t.Parallel()
	t.Run("SessionExec", func(t *testing.T) {
		t.Parallel()
		session := setupSSHSession(t, agent.Metadata{})

		command := "echo test"
		if runtime.GOOS == "windows" {
			command = "cmd.exe /c echo test"
		}
		output, err := session.Output(command)
		require.NoError(t, err)
		require.Equal(t, "test", strings.TrimSpace(string(output)))
	})

	t.Run("GitSSH", func(t *testing.T) {
		t.Parallel()
		session := setupSSHSession(t, agent.Metadata{})
		command := "sh -c 'echo $GIT_SSH_COMMAND'"
		if runtime.GOOS == "windows" {
			command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
		}
		output, err := session.Output(command)
		require.NoError(t, err)
		require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --"))
	})

	t.Run("SessionTTY", func(t *testing.T) {
		t.Parallel()
		if runtime.GOOS == "windows" {
			// This might be our implementation, or ConPTY itself.
			// It's difficult to find extensive tests for it, so
			// it seems like it could be either.
			t.Skip("ConPTY appears to be inconsistent on Windows.")
		}
		session := setupSSHSession(t, agent.Metadata{})
		// The Windows branches below are unreachable while the skip above is
		// in place; they are kept so the test works once it is removed.
		command := "bash"
		if runtime.GOOS == "windows" {
			command = "cmd.exe"
		}
		err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
		require.NoError(t, err)
		ptty := ptytest.New(t)
		session.Stdout = ptty.Output()
		session.Stderr = ptty.Output()
		session.Stdin = ptty.Input()
		err = session.Start(command)
		require.NoError(t, err)
		caret := "$"
		if runtime.GOOS == "windows" {
			caret = ">"
		}
		ptty.ExpectMatch(caret)
		ptty.WriteLine("echo test")
		ptty.ExpectMatch("test")
		ptty.WriteLine("exit")
		err = session.Wait()
		require.NoError(t, err)
	})

	t.Run("LocalForwarding", func(t *testing.T) {
		t.Parallel()
		// Reserve a free port number for the ssh -L side, then release it so
		// ssh can bind it.
		random, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		_ = random.Close()
		tcpAddr, valid := random.Addr().(*net.TCPAddr)
		require.True(t, valid)
		randomPort := tcpAddr.Port

		local, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		defer local.Close()
		tcpAddr, valid = local.Addr().(*net.TCPAddr)
		require.True(t, valid)
		localPort := tcpAddr.Port
		done := make(chan struct{})
		go func() {
			// Always signal completion, even if Accept fails, so the test
			// does not block forever on <-done.
			defer close(done)
			conn, err := local.Accept()
			if !assert.NoError(t, err) {
				return
			}
			_ = conn.Close()
		}()

		err = setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, localPort)}, []string{"echo", "test"}).Start()
		require.NoError(t, err)

		// NOTE(review): this dials localPort directly, so the connection never
		// traverses the -L forward on randomPort. The test currently only
		// proves the ssh command starts. TODO: dial randomPort (with retries)
		// once the remote command keeps the session open long enough.
		conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(localPort))
		require.NoError(t, err)
		_ = conn.Close()
		<-done
	})

	t.Run("SFTP", func(t *testing.T) {
		t.Parallel()
		sshClient, err := setupAgent(t, agent.Metadata{}, 0).SSHClient()
		require.NoError(t, err)
		defer func() {
			_ = sshClient.Close()
		}()
		client, err := sftp.NewClient(sshClient)
		require.NoError(t, err)
		defer func() {
			_ = client.Close()
		}()
		tempFile := filepath.Join(t.TempDir(), "sftp")
		file, err := client.Create(tempFile)
		require.NoError(t, err)
		err = file.Close()
		require.NoError(t, err)
		_, err = os.Stat(tempFile)
		require.NoError(t, err)
	})

	t.Run("EnvironmentVariables", func(t *testing.T) {
		t.Parallel()
		key := "EXAMPLE"
		value := "value"
		session := setupSSHSession(t, agent.Metadata{
			EnvironmentVariables: map[string]string{
				key: value,
			},
		})
		command := "sh -c 'echo $" + key + "'"
		if runtime.GOOS == "windows" {
			command = "cmd.exe /c echo %" + key + "%"
		}
		output, err := session.Output(command)
		require.NoError(t, err)
		require.Equal(t, value, strings.TrimSpace(string(output)))
	})

	t.Run("StartupScript", func(t *testing.T) {
		t.Parallel()
		// Use a per-test directory: a fixed name under os.TempDir() collides
		// with concurrent test runs and is never cleaned up.
		tempPath := filepath.Join(t.TempDir(), "content.txt")
		content := "somethingnice"
		setupAgent(t, agent.Metadata{
			StartupScript: "echo " + content + " > " + tempPath,
		}, 0)

		var (
			gotContent string
			decodeErr  error
		)
		require.Eventually(t, func() bool {
			raw, err := os.ReadFile(tempPath)
			if err != nil || len(raw) == 0 {
				return false
			}
			if runtime.GOOS == "windows" {
				// Windows uses UTF16! 🪟🪟🪟
				// The condition runs off the test goroutine, so record the
				// error and assert it afterwards instead of calling require here.
				raw, _, decodeErr = transform.Bytes(unicode.UTF16(unicode.LittleEndian, unicode.UseBOM).NewDecoder(), raw)
			}
			gotContent = string(raw)
			return true
		}, 15*time.Second, 100*time.Millisecond)
		require.NoError(t, decodeErr)
		require.Equal(t, content, strings.TrimSpace(gotContent))
	})

	t.Run("ReconnectingPTY", func(t *testing.T) {
		t.Parallel()
		if runtime.GOOS == "windows" {
			// This might be our implementation, or ConPTY itself.
			// It's difficult to find extensive tests for it, so
			// it seems like it could be either.
			t.Skip("ConPTY appears to be inconsistent on Windows.")
		}

		conn := setupAgent(t, agent.Metadata{}, 0)
		id := uuid.NewString()
		netConn, err := conn.ReconnectingPTY(id, 100, 100)
		require.NoError(t, err)
		bufRead := bufio.NewReader(netConn)

		// Brief pause to reduce the likelihood that we send keystrokes while
		// the shell is simultaneously sending a prompt.
		time.Sleep(100 * time.Millisecond)

		data, err := json.Marshal(agent.ReconnectingPTYRequest{
			Data: "echo test\r\n",
		})
		require.NoError(t, err)
		_, err = netConn.Write(data)
		require.NoError(t, err)

		// expectLine consumes lines from the current reader until one
		// satisfies matcher.
		expectLine := func(matcher func(string) bool) {
			for {
				line, err := bufRead.ReadString('\n')
				require.NoError(t, err)
				if matcher(line) {
					break
				}
			}
		}
		matchEchoCommand := func(line string) bool {
			return strings.Contains(line, "echo test")
		}
		matchEchoOutput := func(line string) bool {
			return strings.Contains(line, "test") && !strings.Contains(line, "echo")
		}

		// Once for typing the command...
		expectLine(matchEchoCommand)
		// And another time for the actual output.
		expectLine(matchEchoOutput)

		_ = netConn.Close()
		netConn, err = conn.ReconnectingPTY(id, 100, 100)
		require.NoError(t, err)
		bufRead = bufio.NewReader(netConn)

		// Same output again!
		expectLine(matchEchoCommand)
		expectLine(matchEchoOutput)
	})

	t.Run("Dial", func(t *testing.T) {
		t.Parallel()

		cases := []struct {
			name  string
			setup func(t *testing.T) net.Listener
		}{
			{
				name: "TCP",
				setup: func(t *testing.T) net.Listener {
					l, err := net.Listen("tcp", "127.0.0.1:0")
					require.NoError(t, err, "create TCP listener")
					return l
				},
			},
			{
				name: "UDP",
				setup: func(t *testing.T) net.Listener {
					addr := net.UDPAddr{
						IP:   net.ParseIP("127.0.0.1"),
						Port: 0,
					}
					l, err := udp.Listen("udp", &addr)
					require.NoError(t, err, "create UDP listener")
					return l
				},
			},
			{
				name: "Unix",
				setup: func(t *testing.T) net.Listener {
					if runtime.GOOS == "windows" {
						t.Skip("Unix socket forwarding isn't supported on Windows")
					}

					// os.MkdirTemp keeps the socket path short; t.TempDir
					// paths can exceed the sun_path limit on some platforms.
					tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
					require.NoError(t, err, "create temp dir for unix listener")
					t.Cleanup(func() {
						_ = os.RemoveAll(tmpDir)
					})

					l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
					require.NoError(t, err, "create unix listener")
					return l
				},
			},
		}

		for _, c := range cases {
			c := c
			t.Run(c.name, func(t *testing.T) {
				t.Parallel()

				// Setup listener
				l := c.setup(t)
				defer l.Close()
				go func() {
					for {
						accepted, err := l.Accept()
						if err != nil {
							return
						}

						go testAccept(t, accepted)
					}
				}()

				// Dial the listener over WebRTC twice and test out of order
				conn := setupAgent(t, agent.Metadata{}, 0)
				conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
				require.NoError(t, err)
				defer conn1.Close()
				conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
				require.NoError(t, err)
				defer conn2.Close()
				testDial(t, conn2)
				testDial(t, conn1)
			})
		}
	})

	t.Run("DialError", func(t *testing.T) {
		t.Parallel()

		if runtime.GOOS == "windows" {
			// This test uses Unix listeners so we can very easily ensure that
			// no other tests decide to listen on the same random port we
			// picked.
			t.Skip("this test is unsupported on Windows")
		}

		tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
		require.NoError(t, err, "create temp dir")
		t.Cleanup(func() {
			_ = os.RemoveAll(tmpDir)
		})

		// Try to dial the non-existent Unix socket over WebRTC
		conn := setupAgent(t, agent.Metadata{}, 0)
		netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
		require.Error(t, err)
		require.ErrorContains(t, err, "remote dial error")
		require.ErrorContains(t, err, "no such file")
		require.Nil(t, netConn)
	})
}
|
|
|
|
// setupSSHCommand starts an agent and exposes its SSH server on a local TCP
// listener. Each accepted TCP connection is proxied to a fresh agent SSH
// connection. It returns an unstarted `ssh` command aimed at that listener.
//
// beforeArgs are placed ahead of the generated connection options. afterArgs
// follow the host name, so they form the remote command. Neither caller slice
// is modified.
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
	agentConn := setupAgent(t, agent.Metadata{}, 0)
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				return
			}
			sshConn, err := agentConn.SSH()
			if !assert.NoError(t, err) {
				// Nothing to proxy to: drop the client instead of handing a
				// nil conn to io.Copy (which would panic).
				_ = conn.Close()
				continue
			}
			// Each direction closes the opposite end's peer when it finishes,
			// so a disconnect on either side unblocks the other copy. Every
			// conn is closed exactly once.
			go func() {
				defer func() {
					_ = sshConn.Close()
				}()
				_, _ = io.Copy(sshConn, conn)
			}()
			go func() {
				defer func() {
					_ = conn.Close()
				}()
				_, _ = io.Copy(conn, sshConn)
			}()
		}
	}()
	t.Cleanup(func() {
		_ = listener.Close()
	})
	tcpAddr, valid := listener.Addr().(*net.TCPAddr)
	require.True(t, valid)

	// Build into a fresh slice: appending directly to beforeArgs could write
	// into the caller's backing array if it has spare capacity.
	args := make([]string, 0, len(beforeArgs)+7+len(afterArgs))
	args = append(args, beforeArgs...)
	args = append(args,
		"-o", "HostName "+tcpAddr.IP.String(),
		"-o", "Port "+strconv.Itoa(tcpAddr.Port),
		"-o", "StrictHostKeyChecking=no", "host")
	args = append(args, afterArgs...)
	return exec.Command("ssh", args...)
}
|
|
|
|
// setupSSHSession starts an agent with the given metadata and opens a new SSH
// session on it. The SSH client and session are closed during test cleanup.
// t.Cleanup runs in LIFO order, so they close before the agent connection that
// setupAgent tears down.
func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
	sshClient, err := setupAgent(t, options, 0).SSHClient()
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = sshClient.Close()
	})
	session, err := sshClient.NewSession()
	require.NoError(t, err)
	t.Cleanup(func() {
		// Close may report io.EOF if the test already waited on the session.
		_ = session.Close()
	})
	return session
}
|
|
|
|
// setupAgent starts an agent served over an in-memory provisionersdk
// transport pipe and returns a client connection negotiated through the
// peerbroker.
//
// metadata is handed to the agent as-is. ptyTimeout becomes the agent's
// ReconnectingPTYTimeout. The pipe, the agent, and the peer connection are
// all closed during test cleanup.
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
	client, server := provisionersdk.TransportPipe()
	closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
		listener, err := peerbroker.Listen(server, nil)
		return metadata, listener, err
	}, &agent.Options{
		Logger:                 slogtest.Make(t, nil).Leveled(slog.LevelDebug),
		ReconnectingPTYTimeout: ptyTimeout,
	})
	t.Cleanup(func() {
		_ = client.Close()
		_ = server.Close()
		_ = closer.Close()
	})
	api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
	stream, err := api.NegotiateConnection(context.Background())
	// require, not assert: on failure stream is nil and Dial below would
	// fail in a far less obvious way.
	require.NoError(t, err)
	conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
		Logger: slogtest.Make(t, nil),
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = conn.Close()
	})

	return &agent.Conn{
		Negotiator: api,
		Conn:       conn,
	}
}
|
|
|
|
var dialTestPayload = []byte("dean-was-here123")
|
|
|
|
func testDial(t *testing.T, c net.Conn) {
|
|
t.Helper()
|
|
|
|
assertWritePayload(t, c, dialTestPayload)
|
|
assertReadPayload(t, c, dialTestPayload)
|
|
}
|
|
|
|
// testAccept drives the server half of the round-trip. It expects to receive
// dialTestPayload on c, echoes it back, and closes c on return.
func testAccept(t *testing.T, c net.Conn) {
	t.Helper()
	defer func() {
		_ = c.Close()
	}()

	payload := dialTestPayload
	assertReadPayload(t, c, payload)
	assertWritePayload(t, c, payload)
}
|
|
|
|
// assertReadPayload reads exactly len(payload) bytes from r and asserts they
// equal payload. Failures are recorded with assert, so the caller keeps
// running.
//
// io.ReadFull is used because a single Read on a byte stream (e.g. TCP) may
// legally return fewer bytes than the peer wrote.
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
	t.Helper()

	b := make([]byte, len(payload))
	n, err := io.ReadFull(r, b)
	assert.NoError(t, err, "read payload")
	assert.Equal(t, len(payload), n, "read payload length does not match")
	assert.Equal(t, payload, b[:n])
}
|
|
|
|
// assertWritePayload writes payload to w and asserts the write succeeded in
// full. Failures are recorded with assert, so the caller keeps running.
// t.Helper makes failures point at the calling line, matching testDial and
// testAccept.
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
	t.Helper()

	n, err := w.Write(payload)
	assert.NoError(t, err, "write payload")
	assert.Equal(t, len(payload), n, "payload length does not match")
}
|