mirror of https://github.com/coder/coder.git
1983 lines
53 KiB
Go
1983 lines
53 KiB
Go
package agent_test
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/netip"
|
|
"os"
|
|
"os/exec"
|
|
"os/user"
|
|
"path"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
scp "github.com/bramvdbogaerde/go-scp"
|
|
"github.com/google/uuid"
|
|
"github.com/pion/udp"
|
|
"github.com/pkg/sftp"
|
|
"github.com/spf13/afero"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/exp/maps"
|
|
"golang.org/x/xerrors"
|
|
"tailscale.com/net/speedtest"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/agent"
|
|
"github.com/coder/coder/agent/agentssh"
|
|
"github.com/coder/coder/coderd/httpapi"
|
|
"github.com/coder/coder/codersdk"
|
|
"github.com/coder/coder/codersdk/agentsdk"
|
|
"github.com/coder/coder/pty"
|
|
"github.com/coder/coder/pty/ptytest"
|
|
"github.com/coder/coder/tailnet"
|
|
"github.com/coder/coder/tailnet/tailnettest"
|
|
"github.com/coder/coder/testutil"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m)
|
|
}
|
|
|
|
// NOTE: These tests only work when your default shell is bash for some reason.
|
|
|
|
func TestAgent_Stats_SSH(t *testing.T) {
|
|
t.Parallel()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
|
|
//nolint:dogsled
|
|
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
|
|
sshClient, err := conn.SSHClient(ctx)
|
|
require.NoError(t, err)
|
|
defer sshClient.Close()
|
|
session, err := sshClient.NewSession()
|
|
require.NoError(t, err)
|
|
defer session.Close()
|
|
stdin, err := session.StdinPipe()
|
|
require.NoError(t, err)
|
|
err = session.Shell()
|
|
require.NoError(t, err)
|
|
|
|
var s *agentsdk.Stats
|
|
require.Eventuallyf(t, func() bool {
|
|
var ok bool
|
|
s, ok = <-stats
|
|
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSSH == 1
|
|
}, testutil.WaitLong, testutil.IntervalFast,
|
|
"never saw stats: %+v", s,
|
|
)
|
|
_ = stdin.Close()
|
|
err = session.Wait()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
|
|
//nolint:dogsled
|
|
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
|
|
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash")
|
|
require.NoError(t, err)
|
|
defer ptyConn.Close()
|
|
|
|
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
|
|
Data: "echo test\r\n",
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = ptyConn.Write(data)
|
|
require.NoError(t, err)
|
|
|
|
var s *agentsdk.Stats
|
|
require.Eventuallyf(t, func() bool {
|
|
var ok bool
|
|
s, ok = <-stats
|
|
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPTY == 1
|
|
}, testutil.WaitLong, testutil.IntervalFast,
|
|
"never saw stats: %+v", s,
|
|
)
|
|
}
|
|
|
|
func TestAgent_Stats_Magic(t *testing.T) {
|
|
t.Parallel()
|
|
t.Run("StripsEnvironmentVariable", func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
sshClient, err := conn.SSHClient(ctx)
|
|
require.NoError(t, err)
|
|
defer sshClient.Close()
|
|
session, err := sshClient.NewSession()
|
|
require.NoError(t, err)
|
|
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
|
|
defer session.Close()
|
|
|
|
command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'"
|
|
expected := ""
|
|
if runtime.GOOS == "windows" {
|
|
expected = "%" + agentssh.MagicSessionTypeEnvironmentVariable + "%"
|
|
command = "cmd.exe /c echo " + expected
|
|
}
|
|
output, err := session.Output(command)
|
|
require.NoError(t, err)
|
|
require.Equal(t, expected, strings.TrimSpace(string(output)))
|
|
})
|
|
t.Run("Tracks", func(t *testing.T) {
|
|
t.Parallel()
|
|
if runtime.GOOS == "window" {
|
|
t.Skip("Sleeping for infinity doesn't work on Windows")
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
//nolint:dogsled
|
|
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
sshClient, err := conn.SSHClient(ctx)
|
|
require.NoError(t, err)
|
|
defer sshClient.Close()
|
|
session, err := sshClient.NewSession()
|
|
require.NoError(t, err)
|
|
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
|
|
defer session.Close()
|
|
stdin, err := session.StdinPipe()
|
|
require.NoError(t, err)
|
|
err = session.Shell()
|
|
require.NoError(t, err)
|
|
var s *agentsdk.Stats
|
|
require.Eventuallyf(t, func() bool {
|
|
var ok bool
|
|
s, ok = <-stats
|
|
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 &&
|
|
// Ensure that the connection didn't count as a "normal" SSH session.
|
|
// This was a special one, so it should be labeled specially in the stats!
|
|
s.SessionCountVSCode == 1 &&
|
|
// Ensure that connection latency is being counted!
|
|
// If it isn't, it's set to -1.
|
|
s.ConnectionMedianLatencyMS >= 0
|
|
}, testutil.WaitLong, testutil.IntervalFast,
|
|
"never saw stats: %+v", s,
|
|
)
|
|
// The shell will automatically exit if there is no stdin!
|
|
_ = stdin.Close()
|
|
err = session.Wait()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestAgent_SessionExec(t *testing.T) {
|
|
t.Parallel()
|
|
session := setupSSHSession(t, agentsdk.Manifest{})
|
|
|
|
command := "echo test"
|
|
if runtime.GOOS == "windows" {
|
|
command = "cmd.exe /c echo test"
|
|
}
|
|
output, err := session.Output(command)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "test", strings.TrimSpace(string(output)))
|
|
}
|
|
|
|
func TestAgent_GitSSH(t *testing.T) {
|
|
t.Parallel()
|
|
session := setupSSHSession(t, agentsdk.Manifest{})
|
|
command := "sh -c 'echo $GIT_SSH_COMMAND'"
|
|
if runtime.GOOS == "windows" {
|
|
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
|
|
}
|
|
output, err := session.Output(command)
|
|
require.NoError(t, err)
|
|
require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --"))
|
|
}
|
|
|
|
func TestAgent_SessionTTYShell(t *testing.T) {
|
|
t.Parallel()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
t.Cleanup(cancel)
|
|
if runtime.GOOS == "windows" {
|
|
// This might be our implementation, or ConPTY itself.
|
|
// It's difficult to find extensive tests for it, so
|
|
// it seems like it could be either.
|
|
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
|
}
|
|
session := setupSSHSession(t, agentsdk.Manifest{})
|
|
command := "sh"
|
|
if runtime.GOOS == "windows" {
|
|
command = "cmd.exe"
|
|
}
|
|
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
|
require.NoError(t, err)
|
|
ptty := ptytest.New(t)
|
|
session.Stdout = ptty.Output()
|
|
session.Stderr = ptty.Output()
|
|
session.Stdin = ptty.Input()
|
|
err = session.Start(command)
|
|
require.NoError(t, err)
|
|
_ = ptty.Peek(ctx, 1) // wait for the prompt
|
|
ptty.WriteLine("echo test")
|
|
ptty.ExpectMatch("test")
|
|
ptty.WriteLine("exit")
|
|
err = session.Wait()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestAgent_SessionTTYExitCode(t *testing.T) {
|
|
t.Parallel()
|
|
session := setupSSHSession(t, agentsdk.Manifest{})
|
|
command := "areallynotrealcommand"
|
|
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
|
require.NoError(t, err)
|
|
ptty := ptytest.New(t)
|
|
session.Stdout = ptty.Output()
|
|
session.Stderr = ptty.Output()
|
|
session.Stdin = ptty.Input()
|
|
err = session.Start(command)
|
|
require.NoError(t, err)
|
|
err = session.Wait()
|
|
exitErr := &ssh.ExitError{}
|
|
require.True(t, xerrors.As(err, &exitErr))
|
|
if runtime.GOOS == "windows" {
|
|
assert.Equal(t, 1, exitErr.ExitStatus())
|
|
} else {
|
|
assert.Equal(t, 127, exitErr.ExitStatus())
|
|
}
|
|
}
|
|
|
|
//nolint:paralleltest // This test sets an environment variable.
|
|
func TestAgent_Session_TTY_MOTD(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
// This might be our implementation, or ConPTY itself.
|
|
// It's difficult to find extensive tests for it, so
|
|
// it seems like it could be either.
|
|
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
|
}
|
|
|
|
wantMOTD := "Welcome to your Coder workspace!"
|
|
|
|
tmpdir := t.TempDir()
|
|
name := filepath.Join(tmpdir, "motd")
|
|
err := os.WriteFile(name, []byte(wantMOTD), 0o600)
|
|
require.NoError(t, err, "write motd file")
|
|
|
|
// Set HOME so we can ensure no ~/.hushlogin is present.
|
|
t.Setenv("HOME", tmpdir)
|
|
|
|
session := setupSSHSession(t, agentsdk.Manifest{
|
|
MOTDFile: name,
|
|
})
|
|
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
|
require.NoError(t, err)
|
|
|
|
ptty := ptytest.New(t)
|
|
var stdout bytes.Buffer
|
|
session.Stdout = &stdout
|
|
session.Stderr = ptty.Output()
|
|
session.Stdin = ptty.Input()
|
|
err = session.Shell()
|
|
require.NoError(t, err)
|
|
|
|
ptty.WriteLine("exit 0")
|
|
err = session.Wait()
|
|
require.NoError(t, err)
|
|
|
|
require.Contains(t, stdout.String(), wantMOTD, "should show motd")
|
|
}
|
|
|
|
//nolint:paralleltest // This test sets an environment variable.
|
|
func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
// This might be our implementation, or ConPTY itself.
|
|
// It's difficult to find extensive tests for it, so
|
|
// it seems like it could be either.
|
|
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
|
}
|
|
|
|
wantNotMOTD := "Welcome to your Coder workspace!"
|
|
|
|
tmpdir := t.TempDir()
|
|
name := filepath.Join(tmpdir, "motd")
|
|
err := os.WriteFile(name, []byte(wantNotMOTD), 0o600)
|
|
require.NoError(t, err, "write motd file")
|
|
|
|
// Create hushlogin to silence motd.
|
|
f, err := os.Create(filepath.Join(tmpdir, ".hushlogin"))
|
|
require.NoError(t, err, "create .hushlogin file")
|
|
err = f.Close()
|
|
require.NoError(t, err, "close .hushlogin file")
|
|
|
|
// Set HOME so we can ensure ~/.hushlogin is present.
|
|
t.Setenv("HOME", tmpdir)
|
|
|
|
session := setupSSHSession(t, agentsdk.Manifest{
|
|
MOTDFile: name,
|
|
})
|
|
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
|
require.NoError(t, err)
|
|
|
|
ptty := ptytest.New(t)
|
|
var stdout bytes.Buffer
|
|
session.Stdout = &stdout
|
|
session.Stderr = ptty.Output()
|
|
session.Stdin = ptty.Input()
|
|
err = session.Shell()
|
|
require.NoError(t, err)
|
|
|
|
ptty.WriteLine("exit 0")
|
|
err = session.Wait()
|
|
require.NoError(t, err)
|
|
|
|
require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd")
|
|
}
|
|
|
|
func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
|
|
t.Parallel()
|
|
// This test is here to prevent regressions where quickly executing
|
|
// commands (with TTY) don't sync their output to the SSH session.
|
|
//
|
|
// See: https://github.com/coder/coder/issues/6656
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
sshClient, err := conn.SSHClient(ctx)
|
|
require.NoError(t, err)
|
|
defer sshClient.Close()
|
|
|
|
ptty := ptytest.New(t)
|
|
|
|
var stdout bytes.Buffer
|
|
// NOTE(mafredri): Increase iterations to increase chance of failure,
|
|
// assuming bug is present. Limiting GOMAXPROCS further
|
|
// increases the chance of failure.
|
|
// Using 1000 iterations is basically a guaranteed failure (but let's
|
|
// not increase test times needlessly).
|
|
// Limit GOMAXPROCS (e.g. `export GOMAXPROCS=1`) to further increase
|
|
// chance of failure. Also -race helps.
|
|
for i := 0; i < 5; i++ {
|
|
func() {
|
|
stdout.Reset()
|
|
|
|
session, err := sshClient.NewSession()
|
|
require.NoError(t, err)
|
|
defer session.Close()
|
|
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
|
require.NoError(t, err)
|
|
|
|
session.Stdout = &stdout
|
|
session.Stderr = ptty.Output()
|
|
session.Stdin = ptty.Input()
|
|
err = session.Start("echo wazzup")
|
|
require.NoError(t, err)
|
|
|
|
err = session.Wait()
|
|
require.NoError(t, err)
|
|
require.Contains(t, stdout.String(), "wazzup", "should output greeting")
|
|
}()
|
|
}
|
|
}
|
|
|
|
func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// This test is here to prevent regressions where a command (with or
|
|
// without) a large amount of output would not be fully copied to the
|
|
// SSH session. On unix systems, this was fixed by duplicating the file
|
|
// descriptor of the PTY master and using it for copying the output.
|
|
//
|
|
// See: https://github.com/coder/coder/issues/6656
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
sshClient, err := conn.SSHClient(ctx)
|
|
require.NoError(t, err)
|
|
defer sshClient.Close()
|
|
|
|
ptty := ptytest.New(t)
|
|
|
|
var stdout bytes.Buffer
|
|
// NOTE(mafredri): Increase iterations to increase chance of failure,
|
|
// assuming bug is present.
|
|
// Using 10 iterations is basically a guaranteed failure (but let's
|
|
// not increase test times needlessly). Run with -race and do not
|
|
// limit parallelism (`export GOMAXPROCS=10`) to increase the chance
|
|
// of failure.
|
|
for i := 0; i < 1; i++ {
|
|
func() {
|
|
stdout.Reset()
|
|
|
|
session, err := sshClient.NewSession()
|
|
require.NoError(t, err)
|
|
defer session.Close()
|
|
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
|
require.NoError(t, err)
|
|
|
|
session.Stdout = &stdout
|
|
session.Stderr = ptty.Output()
|
|
session.Stdin = ptty.Input()
|
|
want := strings.Repeat("wazzup", 1024+1) // ~6KB, +1 because 1024 is a common buffer size.
|
|
err = session.Start("echo " + want)
|
|
require.NoError(t, err)
|
|
|
|
err = session.Wait()
|
|
require.NoError(t, err)
|
|
require.Contains(t, stdout.String(), want, "should output entire greeting")
|
|
}()
|
|
}
|
|
}
|
|
|
|
//nolint:paralleltest // This test reserves a port.
|
|
func TestAgent_TCPLocalForwarding(t *testing.T) {
|
|
random, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
_ = random.Close()
|
|
tcpAddr, valid := random.Addr().(*net.TCPAddr)
|
|
require.True(t, valid)
|
|
randomPort := tcpAddr.Port
|
|
|
|
local, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
defer local.Close()
|
|
tcpAddr, valid = local.Addr().(*net.TCPAddr)
|
|
require.True(t, valid)
|
|
remotePort := tcpAddr.Port
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
conn, err := local.Accept()
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
b := make([]byte, 4)
|
|
_, err = conn.Read(b)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
_, err = conn.Write(b)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
}()
|
|
|
|
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
|
|
|
|
go func() {
|
|
err := proc.Wait()
|
|
select {
|
|
case <-done:
|
|
default:
|
|
assert.NoError(t, err)
|
|
}
|
|
}()
|
|
|
|
require.Eventually(t, func() bool {
|
|
conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
defer conn.Close()
|
|
_, err = conn.Write([]byte("test"))
|
|
if !assert.NoError(t, err) {
|
|
return false
|
|
}
|
|
b := make([]byte, 4)
|
|
_, err = conn.Read(b)
|
|
if !assert.NoError(t, err) {
|
|
return false
|
|
}
|
|
if !assert.Equal(t, "test", string(b)) {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}, testutil.WaitLong, testutil.IntervalSlow)
|
|
|
|
<-done
|
|
|
|
_ = proc.Kill()
|
|
}
|
|
|
|
//nolint:paralleltest // This test reserves a port.
|
|
func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
|
random, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
_ = random.Close()
|
|
tcpAddr, valid := random.Addr().(*net.TCPAddr)
|
|
require.True(t, valid)
|
|
randomPort := tcpAddr.Port
|
|
|
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
defer l.Close()
|
|
tcpAddr, valid = l.Addr().(*net.TCPAddr)
|
|
require.True(t, valid)
|
|
localPort := tcpAddr.Port
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
b := make([]byte, 4)
|
|
_, err = conn.Read(b)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
_, err = conn.Write(b)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
}()
|
|
|
|
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
|
|
|
|
go func() {
|
|
err := proc.Wait()
|
|
select {
|
|
case <-done:
|
|
default:
|
|
assert.NoError(t, err)
|
|
}
|
|
}()
|
|
|
|
require.Eventually(t, func() bool {
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
|
|
if err != nil {
|
|
return false
|
|
}
|
|
defer conn.Close()
|
|
_, err = conn.Write([]byte("test"))
|
|
if !assert.NoError(t, err) {
|
|
return false
|
|
}
|
|
b := make([]byte, 4)
|
|
_, err = conn.Read(b)
|
|
if !assert.NoError(t, err) {
|
|
return false
|
|
}
|
|
if !assert.Equal(t, "test", string(b)) {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}, testutil.WaitLong, testutil.IntervalSlow)
|
|
|
|
<-done
|
|
|
|
_ = proc.Kill()
|
|
}
|
|
|
|
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
|
t.Parallel()
|
|
if runtime.GOOS == "windows" {
|
|
t.Skip("unix domain sockets are not fully supported on Windows")
|
|
}
|
|
|
|
tmpdir := tempDirUnixSocket(t)
|
|
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
|
localSocketPath := filepath.Join(tmpdir, "local-socket")
|
|
|
|
l, err := net.Listen("unix", remoteSocketPath)
|
|
require.NoError(t, err)
|
|
defer l.Close()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
b := make([]byte, 4)
|
|
_, err = conn.Read(b)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
_, err = conn.Write(b)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
}()
|
|
|
|
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
|
|
|
|
go func() {
|
|
err := proc.Wait()
|
|
select {
|
|
case <-done:
|
|
default:
|
|
assert.NoError(t, err)
|
|
}
|
|
}()
|
|
|
|
require.Eventually(t, func() bool {
|
|
_, err := os.Stat(localSocketPath)
|
|
return err == nil
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
|
|
conn, err := net.Dial("unix", localSocketPath)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
_, err = conn.Write([]byte("test"))
|
|
require.NoError(t, err)
|
|
b := make([]byte, 4)
|
|
_, err = conn.Read(b)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "test", string(b))
|
|
_ = conn.Close()
|
|
<-done
|
|
|
|
_ = proc.Kill()
|
|
}
|
|
|
|
func TestAgent_UnixRemoteForwarding(t *testing.T) {
|
|
t.Parallel()
|
|
if runtime.GOOS == "windows" {
|
|
t.Skip("unix domain sockets are not fully supported on Windows")
|
|
}
|
|
|
|
tmpdir := tempDirUnixSocket(t)
|
|
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
|
localSocketPath := filepath.Join(tmpdir, "local-socket")
|
|
|
|
l, err := net.Listen("unix", localSocketPath)
|
|
require.NoError(t, err)
|
|
defer l.Close()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
b := make([]byte, 4)
|
|
_, err = conn.Read(b)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
_, err = conn.Write(b)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
}()
|
|
|
|
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
|
|
|
|
go func() {
|
|
err := proc.Wait()
|
|
select {
|
|
case <-done:
|
|
default:
|
|
assert.NoError(t, err)
|
|
}
|
|
}()
|
|
|
|
// It's possible that the socket is created but the server is not ready to
|
|
// accept connections yet. We need to retry until we can connect.
|
|
//
|
|
// Note that we wait long here because if the tailnet connection has trouble
|
|
// connecting, it could take 5 seconds or more to reconnect.
|
|
var conn net.Conn
|
|
require.Eventually(t, func() bool {
|
|
var err error
|
|
conn, err = net.Dial("unix", remoteSocketPath)
|
|
return err == nil
|
|
}, testutil.WaitLong, testutil.IntervalFast)
|
|
defer conn.Close()
|
|
_, err = conn.Write([]byte("test"))
|
|
require.NoError(t, err)
|
|
b := make([]byte, 4)
|
|
_, err = conn.Read(b)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "test", string(b))
|
|
_ = conn.Close()
|
|
|
|
<-done
|
|
|
|
_ = proc.Kill()
|
|
}
|
|
|
|
func TestAgent_SFTP(t *testing.T) {
|
|
t.Parallel()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
u, err := user.Current()
|
|
require.NoError(t, err, "get current user")
|
|
home := u.HomeDir
|
|
if runtime.GOOS == "windows" {
|
|
home = "/" + strings.ReplaceAll(home, "\\", "/")
|
|
}
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
sshClient, err := conn.SSHClient(ctx)
|
|
require.NoError(t, err)
|
|
defer sshClient.Close()
|
|
client, err := sftp.NewClient(sshClient)
|
|
require.NoError(t, err)
|
|
defer client.Close()
|
|
wd, err := client.Getwd()
|
|
require.NoError(t, err, "get working directory")
|
|
require.Equal(t, home, wd, "working directory should be home user home")
|
|
tempFile := filepath.Join(t.TempDir(), "sftp")
|
|
// SFTP only accepts unix-y paths.
|
|
remoteFile := filepath.ToSlash(tempFile)
|
|
if !path.IsAbs(remoteFile) {
|
|
// On Windows, e.g. "/C:/Users/...".
|
|
remoteFile = path.Join("/", remoteFile)
|
|
}
|
|
file, err := client.Create(remoteFile)
|
|
require.NoError(t, err)
|
|
err = file.Close()
|
|
require.NoError(t, err)
|
|
_, err = os.Stat(tempFile)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestAgent_SCP(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
sshClient, err := conn.SSHClient(ctx)
|
|
require.NoError(t, err)
|
|
defer sshClient.Close()
|
|
scpClient, err := scp.NewClientBySSH(sshClient)
|
|
require.NoError(t, err)
|
|
defer scpClient.Close()
|
|
tempFile := filepath.Join(t.TempDir(), "scp")
|
|
content := "hello world"
|
|
err = scpClient.CopyFile(context.Background(), strings.NewReader(content), tempFile, "0755")
|
|
require.NoError(t, err)
|
|
_, err = os.Stat(tempFile)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestAgent_EnvironmentVariables(t *testing.T) {
|
|
t.Parallel()
|
|
key := "EXAMPLE"
|
|
value := "value"
|
|
session := setupSSHSession(t, agentsdk.Manifest{
|
|
EnvironmentVariables: map[string]string{
|
|
key: value,
|
|
},
|
|
})
|
|
command := "sh -c 'echo $" + key + "'"
|
|
if runtime.GOOS == "windows" {
|
|
command = "cmd.exe /c echo %" + key + "%"
|
|
}
|
|
output, err := session.Output(command)
|
|
require.NoError(t, err)
|
|
require.Equal(t, value, strings.TrimSpace(string(output)))
|
|
}
|
|
|
|
func TestAgent_EnvironmentVariableExpansion(t *testing.T) {
|
|
t.Parallel()
|
|
key := "EXAMPLE"
|
|
session := setupSSHSession(t, agentsdk.Manifest{
|
|
EnvironmentVariables: map[string]string{
|
|
key: "$SOMETHINGNOTSET",
|
|
},
|
|
})
|
|
command := "sh -c 'echo $" + key + "'"
|
|
if runtime.GOOS == "windows" {
|
|
command = "cmd.exe /c echo %" + key + "%"
|
|
}
|
|
output, err := session.Output(command)
|
|
require.NoError(t, err)
|
|
expect := ""
|
|
if runtime.GOOS == "windows" {
|
|
expect = "%EXAMPLE%"
|
|
}
|
|
// Output should be empty, because the variable is not set!
|
|
require.Equal(t, expect, strings.TrimSpace(string(output)))
|
|
}
|
|
|
|
func TestAgent_CoderEnvVars(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for _, key := range []string{"CODER"} {
|
|
key := key
|
|
t.Run(key, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
session := setupSSHSession(t, agentsdk.Manifest{})
|
|
command := "sh -c 'echo $" + key + "'"
|
|
if runtime.GOOS == "windows" {
|
|
command = "cmd.exe /c echo %" + key + "%"
|
|
}
|
|
output, err := session.Output(command)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, strings.TrimSpace(string(output)))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAgent_SSHConnectionEnvVars(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Note: the SSH_TTY environment variable should only be set for TTYs.
|
|
// For some reason this test produces a TTY locally and a non-TTY in CI
|
|
// so we don't test for the absence of SSH_TTY.
|
|
for _, key := range []string{"SSH_CONNECTION", "SSH_CLIENT"} {
|
|
key := key
|
|
t.Run(key, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
session := setupSSHSession(t, agentsdk.Manifest{})
|
|
command := "sh -c 'echo $" + key + "'"
|
|
if runtime.GOOS == "windows" {
|
|
command = "cmd.exe /c echo %" + key + "%"
|
|
}
|
|
output, err := session.Output(command)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, strings.TrimSpace(string(output)))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAgent_StartupScript(t *testing.T) {
|
|
t.Parallel()
|
|
output := "something"
|
|
command := "sh -c 'echo " + output + "'"
|
|
if runtime.GOOS == "windows" {
|
|
command = "cmd.exe /c echo " + output
|
|
}
|
|
t.Run("Success", func(t *testing.T) {
|
|
t.Parallel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
client := &client{
|
|
t: t,
|
|
agentID: uuid.New(),
|
|
manifest: agentsdk.Manifest{
|
|
StartupScript: command,
|
|
DERPMap: &tailcfg.DERPMap{},
|
|
},
|
|
statsChan: make(chan *agentsdk.Stats),
|
|
coordinator: tailnet.NewCoordinator(logger),
|
|
}
|
|
closer := agent.New(agent.Options{
|
|
Client: client,
|
|
Filesystem: afero.NewMemMapFs(),
|
|
Logger: logger.Named("agent"),
|
|
ReconnectingPTYTimeout: 0,
|
|
})
|
|
t.Cleanup(func() {
|
|
_ = closer.Close()
|
|
})
|
|
assert.Eventually(t, func() bool {
|
|
got := client.getLifecycleStates()
|
|
return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
require.Len(t, client.getStartupLogs(), 1)
|
|
require.Equal(t, output, client.getStartupLogs()[0].Output)
|
|
})
|
|
// This ensures that even when coderd sends back that the startup
|
|
// script has written too many lines it will still succeed!
|
|
t.Run("OverflowsAndSkips", func(t *testing.T) {
|
|
t.Parallel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
client := &client{
|
|
t: t,
|
|
agentID: uuid.New(),
|
|
manifest: agentsdk.Manifest{
|
|
StartupScript: command,
|
|
DERPMap: &tailcfg.DERPMap{},
|
|
},
|
|
patchWorkspaceLogs: func() error {
|
|
resp := httptest.NewRecorder()
|
|
httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{
|
|
Message: "Too many lines!",
|
|
})
|
|
res := resp.Result()
|
|
defer res.Body.Close()
|
|
return codersdk.ReadBodyAsError(res)
|
|
},
|
|
statsChan: make(chan *agentsdk.Stats),
|
|
coordinator: tailnet.NewCoordinator(logger),
|
|
}
|
|
closer := agent.New(agent.Options{
|
|
Client: client,
|
|
Filesystem: afero.NewMemMapFs(),
|
|
Logger: logger.Named("agent"),
|
|
ReconnectingPTYTimeout: 0,
|
|
})
|
|
t.Cleanup(func() {
|
|
_ = closer.Close()
|
|
})
|
|
assert.Eventually(t, func() bool {
|
|
got := client.getLifecycleStates()
|
|
return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
require.Len(t, client.getStartupLogs(), 0)
|
|
})
|
|
}
|
|
|
|
func TestAgent_Metadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
echoHello := "echo 'hello'"
|
|
|
|
t.Run("Once", func(t *testing.T) {
|
|
t.Parallel()
|
|
//nolint:dogsled
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
|
{
|
|
Key: "greeting",
|
|
Interval: 0,
|
|
Script: echoHello,
|
|
},
|
|
},
|
|
}, 0)
|
|
|
|
var gotMd map[string]agentsdk.PostMetadataRequest
|
|
require.Eventually(t, func() bool {
|
|
gotMd = client.getMetadata()
|
|
return len(gotMd) == 1
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
collectedAt := gotMd["greeting"].CollectedAt
|
|
|
|
require.Never(t, func() bool {
|
|
gotMd = client.getMetadata()
|
|
if len(gotMd) != 1 {
|
|
panic("unexpected number of metadata")
|
|
}
|
|
return !gotMd["greeting"].CollectedAt.Equal(collectedAt)
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
})
|
|
|
|
t.Run("Many", func(t *testing.T) {
|
|
t.Parallel()
|
|
//nolint:dogsled
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
|
{
|
|
Key: "greeting",
|
|
Interval: 1,
|
|
Timeout: 100,
|
|
Script: echoHello,
|
|
},
|
|
},
|
|
}, 0)
|
|
|
|
var gotMd map[string]agentsdk.PostMetadataRequest
|
|
require.Eventually(t, func() bool {
|
|
gotMd = client.getMetadata()
|
|
return len(gotMd) == 1
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
collectedAt1 := gotMd["greeting"].CollectedAt
|
|
if !assert.Equal(t, "hello", strings.TrimSpace(gotMd["greeting"].Value)) {
|
|
t.Errorf("got: %+v", gotMd)
|
|
}
|
|
|
|
if !assert.Eventually(t, func() bool {
|
|
gotMd = client.getMetadata()
|
|
return gotMd["greeting"].CollectedAt.After(collectedAt1)
|
|
}, testutil.WaitShort, testutil.IntervalMedium) {
|
|
t.Fatalf("expected metadata to be collected again")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAgentMetadata_Timing(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
// Shell scripting in Windows is a pain, and we have already tested
|
|
// that the OS logic works in the simpler tests.
|
|
t.Skip()
|
|
}
|
|
testutil.SkipIfNotTiming(t)
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
|
|
const reportInterval = 2
|
|
const intervalUnit = 100 * time.Millisecond
|
|
var (
|
|
greetingPath = filepath.Join(dir, "greeting")
|
|
script = "echo hello | tee -a " + greetingPath
|
|
)
|
|
//nolint:dogsled
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
|
{
|
|
Key: "greeting",
|
|
Interval: reportInterval,
|
|
Script: script,
|
|
},
|
|
{
|
|
Key: "bad",
|
|
Interval: reportInterval,
|
|
Script: "exit 1",
|
|
},
|
|
},
|
|
}, 0)
|
|
|
|
require.Eventually(t, func() bool {
|
|
return len(client.getMetadata()) == 2
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
for start := time.Now(); time.Since(start) < testutil.WaitMedium; time.Sleep(testutil.IntervalMedium) {
|
|
md := client.getMetadata()
|
|
require.Len(t, md, 2, "got: %+v", md)
|
|
|
|
require.Equal(t, "hello\n", md["greeting"].Value)
|
|
require.Equal(t, "run cmd: exit status 1", md["bad"].Error)
|
|
|
|
greetingByt, err := os.ReadFile(greetingPath)
|
|
require.NoError(t, err)
|
|
|
|
var (
|
|
numGreetings = bytes.Count(greetingByt, []byte("hello"))
|
|
idealNumGreetings = time.Since(start) / (reportInterval * intervalUnit)
|
|
// We allow a 50% error margin because the report loop may backlog
|
|
// in CI and other toasters. In production, there is no hard
|
|
// guarantee on timing either, and the frontend gives similar
|
|
// wiggle room to the staleness of the value.
|
|
upperBound = int(idealNumGreetings) + 1
|
|
lowerBound = (int(idealNumGreetings) / 2)
|
|
)
|
|
|
|
if idealNumGreetings < 50 {
|
|
// There is an insufficient sample size.
|
|
continue
|
|
}
|
|
|
|
t.Logf("numGreetings: %d, idealNumGreetings: %d", numGreetings, idealNumGreetings)
|
|
// The report loop may slow down on load, but it should never, ever
|
|
// speed up.
|
|
if numGreetings > upperBound {
|
|
t.Fatalf("too many greetings: %d > %d in %v", numGreetings, upperBound, time.Since(start))
|
|
} else if numGreetings < lowerBound {
|
|
t.Fatalf("too few greetings: %d < %d", numGreetings, lowerBound)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestAgent_Lifecycle(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("StartTimeout", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
StartupScript: "sleep 5",
|
|
StartupScriptTimeout: time.Nanosecond,
|
|
}, 0)
|
|
|
|
want := []codersdk.WorkspaceAgentLifecycle{
|
|
codersdk.WorkspaceAgentLifecycleStarting,
|
|
codersdk.WorkspaceAgentLifecycleStartTimeout,
|
|
}
|
|
|
|
var got []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
got = client.getLifecycleStates()
|
|
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
switch len(got) {
|
|
case 1:
|
|
// This can happen if lifecycle state updates are
|
|
// too fast, only the latest one is reported.
|
|
require.Equal(t, want[1:], got)
|
|
default:
|
|
// This is the expected case.
|
|
require.Equal(t, want, got)
|
|
}
|
|
})
|
|
|
|
t.Run("StartError", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
StartupScript: "false",
|
|
StartupScriptTimeout: 30 * time.Second,
|
|
}, 0)
|
|
|
|
want := []codersdk.WorkspaceAgentLifecycle{
|
|
codersdk.WorkspaceAgentLifecycleStarting,
|
|
codersdk.WorkspaceAgentLifecycleStartError,
|
|
}
|
|
|
|
var got []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
got = client.getLifecycleStates()
|
|
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
switch len(got) {
|
|
case 1:
|
|
// This can happen if lifecycle state updates are
|
|
// too fast, only the latest one is reported.
|
|
require.Equal(t, want[1:], got)
|
|
default:
|
|
// This is the expected case.
|
|
require.Equal(t, want, got)
|
|
}
|
|
})
|
|
|
|
t.Run("Ready", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
StartupScript: "true",
|
|
StartupScriptTimeout: 30 * time.Second,
|
|
}, 0)
|
|
|
|
want := []codersdk.WorkspaceAgentLifecycle{
|
|
codersdk.WorkspaceAgentLifecycleStarting,
|
|
codersdk.WorkspaceAgentLifecycleReady,
|
|
}
|
|
|
|
var got []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
got = client.getLifecycleStates()
|
|
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
switch len(got) {
|
|
case 1:
|
|
// This can happen if lifecycle state updates are
|
|
// too fast, only the latest one is reported.
|
|
require.Equal(t, want[1:], got)
|
|
default:
|
|
// This is the expected case.
|
|
require.Equal(t, want, got)
|
|
}
|
|
})
|
|
|
|
t.Run("ShuttingDown", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
|
ShutdownScript: "sleep 5",
|
|
StartupScriptTimeout: 30 * time.Second,
|
|
}, 0)
|
|
|
|
var ready []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
ready = client.getLifecycleStates()
|
|
return len(ready) > 0 && ready[len(ready)-1] == codersdk.WorkspaceAgentLifecycleReady
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
// Start close asynchronously so that we an inspect the state.
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
err := closer.Close()
|
|
assert.NoError(t, err)
|
|
}()
|
|
t.Cleanup(func() {
|
|
<-done
|
|
})
|
|
|
|
want := []codersdk.WorkspaceAgentLifecycle{
|
|
codersdk.WorkspaceAgentLifecycleShuttingDown,
|
|
}
|
|
|
|
var got []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
got = client.getLifecycleStates()[len(ready):]
|
|
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
require.Equal(t, want, got)
|
|
})
|
|
|
|
t.Run("ShutdownTimeout", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
|
ShutdownScript: "sleep 5",
|
|
ShutdownScriptTimeout: time.Nanosecond,
|
|
}, 0)
|
|
|
|
var ready []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
ready = client.getLifecycleStates()
|
|
return len(ready) > 0 && ready[len(ready)-1] == codersdk.WorkspaceAgentLifecycleReady
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
// Start close asynchronously so that we an inspect the state.
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
err := closer.Close()
|
|
assert.NoError(t, err)
|
|
}()
|
|
t.Cleanup(func() {
|
|
<-done
|
|
})
|
|
|
|
want := []codersdk.WorkspaceAgentLifecycle{
|
|
codersdk.WorkspaceAgentLifecycleShuttingDown,
|
|
codersdk.WorkspaceAgentLifecycleShutdownTimeout,
|
|
}
|
|
|
|
var got []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
got = client.getLifecycleStates()[len(ready):]
|
|
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
switch len(got) {
|
|
case 1:
|
|
// This can happen if lifecycle state updates are
|
|
// too fast, only the latest one is reported.
|
|
require.Equal(t, want[1:], got)
|
|
default:
|
|
// This is the expected case.
|
|
require.Equal(t, want, got)
|
|
}
|
|
})
|
|
|
|
t.Run("ShutdownError", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
|
ShutdownScript: "false",
|
|
ShutdownScriptTimeout: 30 * time.Second,
|
|
}, 0)
|
|
|
|
var ready []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
ready = client.getLifecycleStates()
|
|
return len(ready) > 0 && ready[len(ready)-1] == codersdk.WorkspaceAgentLifecycleReady
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
// Start close asynchronously so that we an inspect the state.
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
err := closer.Close()
|
|
assert.NoError(t, err)
|
|
}()
|
|
t.Cleanup(func() {
|
|
<-done
|
|
})
|
|
|
|
want := []codersdk.WorkspaceAgentLifecycle{
|
|
codersdk.WorkspaceAgentLifecycleShuttingDown,
|
|
codersdk.WorkspaceAgentLifecycleShutdownError,
|
|
}
|
|
|
|
var got []codersdk.WorkspaceAgentLifecycle
|
|
assert.Eventually(t, func() bool {
|
|
got = client.getLifecycleStates()[len(ready):]
|
|
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
switch len(got) {
|
|
case 1:
|
|
// This can happen if lifecycle state updates are
|
|
// too fast, only the latest one is reported.
|
|
require.Equal(t, want[1:], got)
|
|
default:
|
|
// This is the expected case.
|
|
require.Equal(t, want, got)
|
|
}
|
|
})
|
|
|
|
t.Run("ShutdownScriptOnce", func(t *testing.T) {
|
|
t.Parallel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
expected := "this-is-shutdown"
|
|
client := &client{
|
|
t: t,
|
|
agentID: uuid.New(),
|
|
manifest: agentsdk.Manifest{
|
|
DERPMap: tailnettest.RunDERPAndSTUN(t),
|
|
StartupScript: "echo 1",
|
|
ShutdownScript: "echo " + expected,
|
|
},
|
|
statsChan: make(chan *agentsdk.Stats),
|
|
coordinator: tailnet.NewCoordinator(logger),
|
|
}
|
|
|
|
fs := afero.NewMemMapFs()
|
|
agent := agent.New(agent.Options{
|
|
Client: client,
|
|
Logger: logger.Named("agent"),
|
|
Filesystem: fs,
|
|
})
|
|
|
|
// agent.Close() loads the shutdown script from the agent metadata.
|
|
// The metadata is populated just before execution of the startup script, so it's mandatory to wait
|
|
// until the startup starts.
|
|
require.Eventually(t, func() bool {
|
|
outputPath := filepath.Join(os.TempDir(), "coder-startup-script.log")
|
|
content, err := afero.ReadFile(fs, outputPath)
|
|
if err != nil {
|
|
t.Logf("read file %q: %s", outputPath, err)
|
|
return false
|
|
}
|
|
return len(content) > 0 // something is in the startup log file
|
|
}, testutil.WaitShort, testutil.IntervalMedium)
|
|
|
|
err := agent.Close()
|
|
require.NoError(t, err, "agent should be closed successfully")
|
|
|
|
outputPath := filepath.Join(os.TempDir(), "coder-shutdown-script.log")
|
|
logFirstRead, err := afero.ReadFile(fs, outputPath)
|
|
require.NoError(t, err, "log file should be present")
|
|
require.Equal(t, expected, string(bytes.TrimSpace(logFirstRead)))
|
|
|
|
// Make sure that script can't be executed twice.
|
|
err = agent.Close()
|
|
require.NoError(t, err, "don't need to close the agent twice, no effect")
|
|
|
|
logSecondRead, err := afero.ReadFile(fs, outputPath)
|
|
require.NoError(t, err, "log file should be present")
|
|
require.Equal(t, string(bytes.TrimSpace(logFirstRead)), string(bytes.TrimSpace(logSecondRead)))
|
|
})
|
|
}
|
|
|
|
func TestAgent_Startup(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("EmptyDirectory", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
StartupScript: "true",
|
|
StartupScriptTimeout: 30 * time.Second,
|
|
Directory: "",
|
|
}, 0)
|
|
assert.Eventually(t, func() bool {
|
|
return client.getStartup().Version != ""
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
require.Equal(t, "", client.getStartup().ExpandedDirectory)
|
|
})
|
|
|
|
t.Run("HomeDirectory", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
StartupScript: "true",
|
|
StartupScriptTimeout: 30 * time.Second,
|
|
Directory: "~",
|
|
}, 0)
|
|
assert.Eventually(t, func() bool {
|
|
return client.getStartup().Version != ""
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
homeDir, err := os.UserHomeDir()
|
|
require.NoError(t, err)
|
|
require.Equal(t, homeDir, client.getStartup().ExpandedDirectory)
|
|
})
|
|
|
|
t.Run("NotAbsoluteDirectory", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
StartupScript: "true",
|
|
StartupScriptTimeout: 30 * time.Second,
|
|
Directory: "coder/coder",
|
|
}, 0)
|
|
assert.Eventually(t, func() bool {
|
|
return client.getStartup().Version != ""
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
homeDir, err := os.UserHomeDir()
|
|
require.NoError(t, err)
|
|
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.getStartup().ExpandedDirectory)
|
|
})
|
|
|
|
t.Run("HomeEnvironmentVariable", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
StartupScript: "true",
|
|
StartupScriptTimeout: 30 * time.Second,
|
|
Directory: "$HOME",
|
|
}, 0)
|
|
assert.Eventually(t, func() bool {
|
|
return client.getStartup().Version != ""
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
homeDir, err := os.UserHomeDir()
|
|
require.NoError(t, err)
|
|
require.Equal(t, homeDir, client.getStartup().ExpandedDirectory)
|
|
})
|
|
}
|
|
|
|
func TestAgent_ReconnectingPTY(t *testing.T) {
|
|
t.Parallel()
|
|
if runtime.GOOS == "windows" {
|
|
// This might be our implementation, or ConPTY itself.
|
|
// It's difficult to find extensive tests for it, so
|
|
// it seems like it could be either.
|
|
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
id := uuid.New()
|
|
netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
|
|
require.NoError(t, err)
|
|
defer netConn.Close()
|
|
|
|
bufRead := bufio.NewReader(netConn)
|
|
|
|
// Brief pause to reduce the likelihood that we send keystrokes while
|
|
// the shell is simultaneously sending a prompt.
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
|
|
Data: "echo test\r\n",
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = netConn.Write(data)
|
|
require.NoError(t, err)
|
|
|
|
expectLine := func(matcher func(string) bool) {
|
|
for {
|
|
line, err := bufRead.ReadString('\n')
|
|
require.NoError(t, err)
|
|
if matcher(line) {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
matchEchoCommand := func(line string) bool {
|
|
return strings.Contains(line, "echo test")
|
|
}
|
|
matchEchoOutput := func(line string) bool {
|
|
return strings.Contains(line, "test") && !strings.Contains(line, "echo")
|
|
}
|
|
|
|
// Once for typing the command...
|
|
expectLine(matchEchoCommand)
|
|
// And another time for the actual output.
|
|
expectLine(matchEchoOutput)
|
|
|
|
_ = netConn.Close()
|
|
netConn, err = conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
|
|
require.NoError(t, err)
|
|
defer netConn.Close()
|
|
|
|
bufRead = bufio.NewReader(netConn)
|
|
|
|
// Same output again!
|
|
expectLine(matchEchoCommand)
|
|
expectLine(matchEchoOutput)
|
|
}
|
|
|
|
func TestAgent_Dial(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
name string
|
|
setup func(t *testing.T) net.Listener
|
|
}{
|
|
{
|
|
name: "TCP",
|
|
setup: func(t *testing.T) net.Listener {
|
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err, "create TCP listener")
|
|
return l
|
|
},
|
|
},
|
|
{
|
|
name: "UDP",
|
|
setup: func(t *testing.T) net.Listener {
|
|
addr := net.UDPAddr{
|
|
IP: net.ParseIP("127.0.0.1"),
|
|
Port: 0,
|
|
}
|
|
l, err := udp.Listen("udp", &addr)
|
|
require.NoError(t, err, "create UDP listener")
|
|
return l
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
c := c
|
|
t.Run(c.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Setup listener
|
|
l := c.setup(t)
|
|
defer l.Close()
|
|
go func() {
|
|
for {
|
|
c, err := l.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
go testAccept(t, c)
|
|
}
|
|
}()
|
|
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
require.True(t, conn.AwaitReachable(context.Background()))
|
|
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn1.Close()
|
|
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
|
require.NoError(t, err)
|
|
defer conn2.Close()
|
|
testDial(t, conn2)
|
|
testDial(t, conn1)
|
|
time.Sleep(150 * time.Millisecond)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAgent_Speedtest(t *testing.T) {
|
|
t.Parallel()
|
|
t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
derpMap := tailnettest.RunDERPAndSTUN(t)
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
|
DERPMap: derpMap,
|
|
}, 0)
|
|
defer conn.Close()
|
|
res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
|
|
require.NoError(t, err)
|
|
t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond())
|
|
}
|
|
|
|
func TestAgent_Reconnect(t *testing.T) {
|
|
t.Parallel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
// After the agent is disconnected from a coordinator, it's supposed
|
|
// to reconnect!
|
|
coordinator := tailnet.NewCoordinator(logger)
|
|
defer coordinator.Close()
|
|
|
|
agentID := uuid.New()
|
|
statsCh := make(chan *agentsdk.Stats)
|
|
derpMap := tailnettest.RunDERPAndSTUN(t)
|
|
client := &client{
|
|
t: t,
|
|
agentID: agentID,
|
|
manifest: agentsdk.Manifest{
|
|
DERPMap: derpMap,
|
|
},
|
|
statsChan: statsCh,
|
|
coordinator: coordinator,
|
|
}
|
|
initialized := atomic.Int32{}
|
|
closer := agent.New(agent.Options{
|
|
ExchangeToken: func(ctx context.Context) (string, error) {
|
|
initialized.Add(1)
|
|
return "", nil
|
|
},
|
|
Client: client,
|
|
Logger: logger.Named("agent"),
|
|
})
|
|
defer closer.Close()
|
|
|
|
require.Eventually(t, func() bool {
|
|
return coordinator.Node(agentID) != nil
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
client.lastWorkspaceAgent()
|
|
require.Eventually(t, func() bool {
|
|
return initialized.Load() == 2
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
|
t.Parallel()
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
coordinator := tailnet.NewCoordinator(logger)
|
|
defer coordinator.Close()
|
|
|
|
client := &client{
|
|
t: t,
|
|
agentID: uuid.New(),
|
|
manifest: agentsdk.Manifest{
|
|
GitAuthConfigs: 1,
|
|
DERPMap: &tailcfg.DERPMap{},
|
|
},
|
|
statsChan: make(chan *agentsdk.Stats),
|
|
coordinator: coordinator,
|
|
}
|
|
filesystem := afero.NewMemMapFs()
|
|
closer := agent.New(agent.Options{
|
|
ExchangeToken: func(ctx context.Context) (string, error) {
|
|
return "", nil
|
|
},
|
|
Client: client,
|
|
Logger: logger.Named("agent"),
|
|
Filesystem: filesystem,
|
|
})
|
|
defer closer.Close()
|
|
|
|
home, err := os.UserHomeDir()
|
|
require.NoError(t, err)
|
|
name := filepath.Join(home, ".vscode-server", "data", "Machine", "settings.json")
|
|
require.Eventually(t, func() bool {
|
|
_, err := filesystem.Stat(name)
|
|
return err == nil
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
}
|
|
|
|
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
|
|
//nolint:dogsled
|
|
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
waitGroup := sync.WaitGroup{}
|
|
go func() {
|
|
defer listener.Close()
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
ssh, err := agentConn.SSH(ctx)
|
|
cancel()
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return
|
|
}
|
|
waitGroup.Add(1)
|
|
go func() {
|
|
agentssh.Bicopy(context.Background(), conn, ssh)
|
|
waitGroup.Done()
|
|
}()
|
|
}
|
|
}()
|
|
t.Cleanup(func() {
|
|
_ = listener.Close()
|
|
waitGroup.Wait()
|
|
})
|
|
tcpAddr, valid := listener.Addr().(*net.TCPAddr)
|
|
require.True(t, valid)
|
|
args := append(beforeArgs,
|
|
"-o", "HostName "+tcpAddr.IP.String(),
|
|
"-o", "Port "+strconv.Itoa(tcpAddr.Port),
|
|
"-o", "StrictHostKeyChecking=no",
|
|
"-o", "UserKnownHostsFile=/dev/null",
|
|
"host",
|
|
)
|
|
args = append(args, afterArgs...)
|
|
cmd := exec.Command("ssh", args...)
|
|
return ptytest.Start(t, cmd)
|
|
}
|
|
|
|
func setupSSHSession(t *testing.T, options agentsdk.Manifest) *ssh.Session {
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
//nolint:dogsled
|
|
conn, _, _, _, _ := setupAgent(t, options, 0)
|
|
sshClient, err := conn.SSHClient(ctx)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_ = sshClient.Close()
|
|
})
|
|
session, err := sshClient.NewSession()
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_ = session.Close()
|
|
})
|
|
return session
|
|
}
|
|
|
|
// closeFunc adapts a plain function to a type with a Close method.
type closeFunc func() error

// Close invokes the wrapped function and returns its result unchanged.
func (fn closeFunc) Close() error {
	return fn()
}
|
|
|
|
func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration) (
|
|
*codersdk.WorkspaceAgentConn,
|
|
*client,
|
|
<-chan *agentsdk.Stats,
|
|
afero.Fs,
|
|
io.Closer,
|
|
) {
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
if metadata.DERPMap == nil {
|
|
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
|
}
|
|
coordinator := tailnet.NewCoordinator(logger)
|
|
t.Cleanup(func() {
|
|
_ = coordinator.Close()
|
|
})
|
|
agentID := uuid.New()
|
|
statsCh := make(chan *agentsdk.Stats, 50)
|
|
fs := afero.NewMemMapFs()
|
|
c := &client{
|
|
t: t,
|
|
agentID: agentID,
|
|
manifest: metadata,
|
|
statsChan: statsCh,
|
|
coordinator: coordinator,
|
|
}
|
|
closer := agent.New(agent.Options{
|
|
Client: c,
|
|
Filesystem: fs,
|
|
Logger: logger.Named("agent"),
|
|
ReconnectingPTYTimeout: ptyTimeout,
|
|
})
|
|
t.Cleanup(func() {
|
|
_ = closer.Close()
|
|
})
|
|
conn, err := tailnet.NewConn(&tailnet.Options{
|
|
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
|
DERPMap: metadata.DERPMap,
|
|
Logger: logger.Named("client"),
|
|
})
|
|
require.NoError(t, err)
|
|
clientConn, serverConn := net.Pipe()
|
|
serveClientDone := make(chan struct{})
|
|
t.Cleanup(func() {
|
|
_ = clientConn.Close()
|
|
_ = serverConn.Close()
|
|
_ = conn.Close()
|
|
<-serveClientDone
|
|
})
|
|
go func() {
|
|
defer close(serveClientDone)
|
|
coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
|
}()
|
|
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
|
return conn.UpdateNodes(node, false)
|
|
})
|
|
conn.SetNodeCallback(sendNode)
|
|
agentConn := &codersdk.WorkspaceAgentConn{
|
|
Conn: conn,
|
|
}
|
|
t.Cleanup(func() {
|
|
_ = agentConn.Close()
|
|
})
|
|
// Ideally we wouldn't wait too long here, but sometimes the the
|
|
// networking needs more time to resolve itself.
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
if !agentConn.AwaitReachable(ctx) {
|
|
t.Fatal("agent not reachable")
|
|
}
|
|
return agentConn, c, statsCh, fs, closer
|
|
}
|
|
|
|
// dialTestPayload is the fixed byte sequence that testDial and testAccept
// write to and expect to read from a connection.
var dialTestPayload = []byte("dean-was-here123")
|
|
|
|
func testDial(t *testing.T, c net.Conn) {
|
|
t.Helper()
|
|
|
|
assertWritePayload(t, c, dialTestPayload)
|
|
assertReadPayload(t, c, dialTestPayload)
|
|
}
|
|
|
|
func testAccept(t *testing.T, c net.Conn) {
|
|
t.Helper()
|
|
defer c.Close()
|
|
|
|
assertReadPayload(t, c, dialTestPayload)
|
|
assertWritePayload(t, c, dialTestPayload)
|
|
}
|
|
|
|
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
|
|
b := make([]byte, len(payload)+16)
|
|
n, err := r.Read(b)
|
|
assert.NoError(t, err, "read payload")
|
|
assert.Equal(t, len(payload), n, "read payload length does not match")
|
|
assert.Equal(t, payload, b[:n])
|
|
}
|
|
|
|
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
|
|
n, err := w.Write(payload)
|
|
assert.NoError(t, err, "write payload")
|
|
assert.Equal(t, len(payload), n, "payload length does not match")
|
|
}
|
|
|
|
type client struct {
|
|
t *testing.T
|
|
agentID uuid.UUID
|
|
manifest agentsdk.Manifest
|
|
metadata map[string]agentsdk.PostMetadataRequest
|
|
statsChan chan *agentsdk.Stats
|
|
coordinator tailnet.Coordinator
|
|
lastWorkspaceAgent func()
|
|
patchWorkspaceLogs func() error
|
|
|
|
mu sync.Mutex // Protects following.
|
|
lifecycleStates []codersdk.WorkspaceAgentLifecycle
|
|
startup agentsdk.PostStartupRequest
|
|
logs []agentsdk.StartupLog
|
|
}
|
|
|
|
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
|
|
return c.manifest, nil
|
|
}
|
|
|
|
func (c *client) Listen(_ context.Context) (net.Conn, error) {
|
|
clientConn, serverConn := net.Pipe()
|
|
closed := make(chan struct{})
|
|
c.lastWorkspaceAgent = func() {
|
|
_ = serverConn.Close()
|
|
_ = clientConn.Close()
|
|
<-closed
|
|
}
|
|
c.t.Cleanup(c.lastWorkspaceAgent)
|
|
go func() {
|
|
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
|
|
close(closed)
|
|
}()
|
|
return clientConn, nil
|
|
}
|
|
|
|
func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
|
|
doneCh := make(chan struct{})
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
go func() {
|
|
defer close(doneCh)
|
|
|
|
setInterval(500 * time.Millisecond)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case stat := <-statsChan:
|
|
select {
|
|
case c.statsChan <- stat:
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
// We don't want to send old stats.
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
return closeFunc(func() error {
|
|
cancel()
|
|
<-doneCh
|
|
close(c.statsChan)
|
|
return nil
|
|
}), nil
|
|
}
|
|
|
|
func (c *client) getLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.lifecycleStates
|
|
}
|
|
|
|
func (c *client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.lifecycleStates = append(c.lifecycleStates, req.State)
|
|
return nil
|
|
}
|
|
|
|
func (*client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
|
|
return nil
|
|
}
|
|
|
|
func (c *client) getStartup() agentsdk.PostStartupRequest {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.startup
|
|
}
|
|
|
|
func (c *client) getMetadata() map[string]agentsdk.PostMetadataRequest {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return maps.Clone(c.metadata)
|
|
}
|
|
|
|
func (c *client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.metadata == nil {
|
|
c.metadata = make(map[string]agentsdk.PostMetadataRequest)
|
|
}
|
|
c.metadata[key] = req
|
|
return nil
|
|
}
|
|
|
|
func (c *client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.startup = startup
|
|
return nil
|
|
}
|
|
|
|
func (c *client) getStartupLogs() []agentsdk.StartupLog {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.logs
|
|
}
|
|
|
|
func (c *client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.patchWorkspaceLogs != nil {
|
|
return c.patchWorkspaceLogs()
|
|
}
|
|
c.logs = append(c.logs, logs.Logs...)
|
|
return nil
|
|
}
|
|
|
|
// tempDirUnixSocket returns a temporary directory that can safely hold unix
|
|
// sockets (probably).
|
|
//
|
|
// During tests on darwin we hit the max path length limit for unix sockets
|
|
// pretty easily in the default location, so this function uses /tmp instead to
|
|
// get shorter paths.
|
|
func tempDirUnixSocket(t *testing.T) string {
|
|
t.Helper()
|
|
if runtime.GOOS == "darwin" {
|
|
testName := strings.ReplaceAll(t.Name(), "/", "_")
|
|
dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("coder-test-%s-", testName))
|
|
require.NoError(t, err, "create temp dir for gpg test")
|
|
|
|
t.Cleanup(func() {
|
|
err := os.RemoveAll(dir)
|
|
assert.NoError(t, err, "remove temp dir", dir)
|
|
})
|
|
return dir
|
|
}
|
|
|
|
return t.TempDir()
|
|
}
|