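// Package agent implements the Coder workspace agent. It fetches metadata
// from coderd, runs the startup script, and serves SSH, reconnecting PTYs,
// speedtest, and raw dials over WebRTC and tailnet connections.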
package agent
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/netip"
"net/url"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"

"github.com/armon/circbuf"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/pkg/sftp"
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"

"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/pty"
"github.com/coder/coder/tailnet"
"github.com/coder/retry"
)
const (
ProtocolReconnectingPTY = "reconnecting-pty"
ProtocolSSH = "ssh"
ProtocolDial = "dial"
// MagicSessionErrorCode indicates that something went wrong with the session, rather than the
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
MagicSessionErrorCode = 229
)
var (
// tailnetIP is a static IPv6 address with the Tailscale prefix that is used to route
// connections from clients to this node. A dynamic address is not required because a Tailnet
// client only dials a single agent at a time.
tailnetIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
tailnetSSHPort = 1
tailnetReconnectingPTYPort = 2
tailnetSpeedtestPort = 3
)
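// Options configures the behavior of a new agent.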
type Options struct {
CoordinatorDialer CoordinatorDialer
WebRTCDialer WebRTCDialer
FetchMetadata FetchMetadata
StatsReporter StatsReporter
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
}
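// Metadata is the agent startup information fetched from coderd, such as
// the DERP map, environment variables, and startup script.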
type Metadata struct {
DERPMap *tailcfg.DERPMap `json:"derpmap"`
EnvironmentVariables map[string]string `json:"environment_variables"`
StartupScript string `json:"startup_script"`
Directory string `json:"directory"`
}
// WebRTCDialer is a function that dials the WebRTC broker and returns a
// listener for incoming peer connections.
type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
// CoordinatorDialer is a function that dials the coordinator.
// A dialer must be passed in to allow for reconnects.
type CoordinatorDialer func(ctx context.Context) (net.Conn, error)
// FetchMetadata is a function to obtain metadata for the agent.
type FetchMetadata func(ctx context.Context) (Metadata, error)
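// New starts an agent with the provided options and returns an io.Closer
// that shuts it down.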
func New(options Options) io.Closer {
if options.ReconnectingPTYTimeout == 0 {
options.ReconnectingPTYTimeout = 5 * time.Minute
}
ctx, cancelFunc := context.WithCancel(context.Background())
server := &agent{
webrtcDialer: options.WebRTCDialer,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
envVars: options.EnvironmentVariables,
coordinatorDialer: options.CoordinatorDialer,
fetchMetadata: options.FetchMetadata,
stats: &Stats{},
statsReporter: options.StatsReporter,
}
server.init(ctx)
return server
}
type agent struct {
webrtcDialer WebRTCDialer
logger slog.Logger
reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration
connCloseWait sync.WaitGroup
closeCancel context.CancelFunc
closeMutex sync.Mutex
closed chan struct{}
envVars map[string]string
// metadata is atomic because values can change after reconnection.
metadata atomic.Value
fetchMetadata FetchMetadata
sshServer *ssh.Server
network *tailnet.Conn
coordinatorDialer CoordinatorDialer
stats *Stats
statsReporter StatsReporter
}
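// run fetches metadata (retrying until it succeeds or the context is
// canceled), runs the startup script, and starts the networking stacks.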
func (a *agent) run(ctx context.Context) {
var metadata Metadata
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
a.logger.Info(ctx, "connecting")
metadata, err = a.fetchMetadata(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
if a.isClosed() {
return
}
a.logger.Warn(context.Background(), "failed to fetch metadata", slog.Error(err))
continue
}
a.logger.Info(context.Background(), "fetched metadata")
break
}
select {
case <-ctx.Done():
return
default:
}
a.metadata.Store(metadata)
// The startup script has not run yet!
go func() {
err := a.runStartupScript(ctx, metadata.StartupScript)
if errors.Is(err, context.Canceled) {
return
}
if err != nil {
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
}
}()
if a.webrtcDialer != nil {
go a.runWebRTCNetworking(ctx)
}
if metadata.DERPMap != nil {
go a.runTailnet(ctx, metadata.DERPMap)
}
}
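// runTailnet constructs the tailnet connection and serves SSH,
// reconnecting PTY, and speedtest traffic on their well-known ports.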
func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
return
}
if a.network != nil {
a.network.SetDERPMap(derpMap)
return
}
var err error
a.network, err = tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnetIP, 128)},
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
})
if err != nil {
a.logger.Critical(ctx, "create tailnet", slog.Error(err))
return
}
a.network.SetForwardTCPCallback(func(conn net.Conn, listenerExists bool) net.Conn {
if listenerExists {
// If a listener already exists, we would double-wrap the conn.
return conn
}
return a.stats.wrapConn(conn)
})
go a.runCoordinator(ctx)
sshListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetSSHPort))
if err != nil {
a.logger.Critical(ctx, "listen for ssh", slog.Error(err))
return
}
go func() {
for {
conn, err := sshListener.Accept()
if err != nil {
return
}
go a.sshServer.HandleConn(a.stats.wrapConn(conn))
}
}()
reconnectingPTYListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetReconnectingPTYPort))
if err != nil {
a.logger.Critical(ctx, "listen for reconnecting pty", slog.Error(err))
return
}
go func() {
for {
conn, err := reconnectingPTYListener.Accept()
if err != nil {
a.logger.Debug(ctx, "accept pty failed", slog.Error(err))
return
}
conn = a.stats.wrapConn(conn)
// This cannot use a JSON decoder, since that can
// buffer additional data that is required for the PTY.
rawLen := make([]byte, 2)
// Use io.ReadFull so a short read from the conn doesn't truncate the
// length prefix or the JSON payload.
_, err = io.ReadFull(conn, rawLen)
if err != nil {
continue
}
length := binary.LittleEndian.Uint16(rawLen)
data := make([]byte, length)
_, err = io.ReadFull(conn, data)
if err != nil {
continue
}
var msg reconnectingPTYInit
err = json.Unmarshal(data, &msg)
if err != nil {
continue
}
go a.handleReconnectingPTY(ctx, msg, conn)
}
}()
speedtestListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetSpeedtestPort))
if err != nil {
a.logger.Critical(ctx, "listen for speedtest", slog.Error(err))
return
}
go func() {
for {
conn, err := speedtestListener.Accept()
if err != nil {
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
return
}
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
go func() {
defer a.connCloseWait.Done()
_ = speedtest.ServeConn(conn)
}()
}
}()
}
// runCoordinator listens for nodes and updates the self-node as it changes.
func (a *agent) runCoordinator(ctx context.Context) {
var coordinator net.Conn
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
coordinator, err = a.coordinatorDialer(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
if a.isClosed() {
return
}
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
a.logger.Info(context.Background(), "connected to coordination server")
break
}
select {
case <-ctx.Done():
return
default:
}
defer coordinator.Close()
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, a.network.UpdateNodes)
a.network.SetNodeCallback(sendNodes)
select {
case <-ctx.Done():
return
case err := <-errChan:
if a.isClosed() {
return
}
if errors.Is(err, context.Canceled) {
return
}
a.logger.Debug(ctx, "node broker accept exited; restarting connection", slog.Error(err))
a.runCoordinator(ctx)
return
}
}
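// runWebRTCNetworking dials the WebRTC broker and accepts peer
// connections until the agent is closed.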
func (a *agent) runWebRTCNetworking(ctx context.Context) {
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
peerListener, err = a.webrtcDialer(ctx, a.logger)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
if a.isClosed() {
return
}
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
a.logger.Info(context.Background(), "connected to webrtc broker")
break
}
select {
case <-ctx.Done():
return
default:
}
for {
conn, err := peerListener.Accept()
if err != nil {
if a.isClosed() {
return
}
a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
a.runWebRTCNetworking(ctx)
return
}
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
go a.handlePeerConn(ctx, conn)
}
}
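// runStartupScript executes the script with the user's shell, writing
// combined output to a log file in the OS temporary directory.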
func (a *agent) runStartupScript(ctx context.Context, script string) error {
if script == "" {
return nil
}
writer, err := os.OpenFile(filepath.Join(os.TempDir(), "coder-startup-script.log"), os.O_CREATE|os.O_RDWR, 0o600)
if err != nil {
return xerrors.Errorf("open startup script log file: %w", err)
}
defer func() {
_ = writer.Close()
}()
cmd, err := a.createCommand(ctx, script, nil)
if err != nil {
return xerrors.Errorf("create command: %w", err)
}
cmd.Stdout = writer
cmd.Stderr = writer
err = cmd.Run()
if err != nil {
// cmd.Run does not return a context canceled error; it returns "signal: killed".
if ctx.Err() != nil {
return ctx.Err()
}
return xerrors.Errorf("run: %w", err)
}
return nil
}
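// handlePeerConn accepts channels on a peer connection and routes them
// by protocol: SSH, reconnecting PTY, or dial.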
func (a *agent) handlePeerConn(ctx context.Context, peerConn *peer.Conn) {
go func() {
select {
case <-a.closed:
case <-peerConn.Closed():
}
_ = peerConn.Close()
a.connCloseWait.Done()
}()
for {
channel, err := peerConn.Accept(ctx)
if err != nil {
if errors.Is(err, peer.ErrClosed) || a.isClosed() {
return
}
a.logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
return
}
conn := channel.NetConn()
switch channel.Protocol() {
case ProtocolSSH:
go a.sshServer.HandleConn(a.stats.wrapConn(conn))
case ProtocolReconnectingPTY:
rawID := channel.Label()
// The ID format is referenced in conn.go.
// <uuid>:<height>:<width>:<command>
idParts := strings.SplitN(rawID, ":", 4)
if len(idParts) != 4 {
a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
continue
}
id := idParts[0]
// Enforce a consistent format for IDs.
_, err := uuid.Parse(id)
if err != nil {
a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
continue
}
// Parse the initial terminal dimensions.
height, err := strconv.Atoi(idParts[1])
if err != nil {
a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
continue
}
width, err := strconv.Atoi(idParts[2])
if err != nil {
a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
continue
}
go a.handleReconnectingPTY(ctx, reconnectingPTYInit{
ID: id,
Height: uint16(height),
Width: uint16(width),
Command: idParts[3],
}, a.stats.wrapConn(conn))
case ProtocolDial:
go a.handleDial(ctx, channel.Label(), a.stats.wrapConn(conn))
default:
a.logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
slog.F("label", channel.Label()),
)
}
}
}
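// init generates a random host key, configures the SSH server, and
// starts the agent and stats reporting in the background.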
func (a *agent) init(ctx context.Context) {
a.logger.Info(ctx, "generating host key")
// Clients should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
if err != nil {
panic(err)
}
sshLogger := a.logger.Named("ssh-server")
forwardHandler := &ssh.ForwardedTCPHandler{}
a.sshServer = &ssh.Server{
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": ssh.DirectTCPIPHandler,
"session": ssh.DefaultSessionHandler,
},
ConnectionFailedCallback: func(conn net.Conn, err error) {
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
},
Handler: func(session ssh.Session) {
err := a.handleSSHSession(session)
var exitError *exec.ExitError
if xerrors.As(err, &exitError) {
a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
_ = session.Exit(exitError.ExitCode())
return
}
if err != nil {
a.logger.Warn(ctx, "ssh session failed", slog.Error(err))
// This exit code is designed to be unlikely to be confused for a legit exit code
// from the process.
_ = session.Exit(MagicSessionErrorCode)
return
}
},
HostSigners: []ssh.Signer{randomSigner},
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
// Allow all local port forwarding!
sshLogger.Debug(ctx, "local port forward",
slog.F("destination-host", destinationHost),
slog.F("destination-port", destinationPort))
return true
},
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
return true
},
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
// Allow all reverse port forwarding!
sshLogger.Debug(ctx, "reverse port forward",
slog.F("bind-host", bindHost),
slog.F("bind-port", bindPort))
return true
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": forwardHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
},
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return &gossh.ServerConfig{
NoClientAuth: true,
}
},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": func(session ssh.Session) {
session.DisablePTYEmulation()
server, err := sftp.NewServer(session)
if err != nil {
a.logger.Debug(session.Context(), "initialize sftp server", slog.Error(err))
return
}
defer server.Close()
err = server.Serve()
if errors.Is(err, io.EOF) {
return
}
a.logger.Debug(session.Context(), "sftp server exited with error", slog.Error(err))
},
},
}
go a.run(ctx)
if a.statsReporter != nil {
cl, err := a.statsReporter(ctx, a.logger, func() *Stats {
return a.stats.Copy()
})
if err != nil {
a.logger.Error(ctx, "report stats", slog.Error(err))
return
}
a.connCloseWait.Add(1)
go func() {
defer a.connCloseWait.Done()
<-a.closed
cl.Close()
}()
}
}
// createCommand processes raw command input with OpenSSH-like behavior.
// If the rawCommand provided is empty, it will default to the user's shell.
// This injects environment variables specified by the user at launch too.
func (a *agent) createCommand(ctx context.Context, rawCommand string, env []string) (*exec.Cmd, error) {
currentUser, err := user.Current()
if err != nil {
return nil, xerrors.Errorf("get current user: %w", err)
}
username := currentUser.Username
shell, err := usershell.Get(username)
if err != nil {
return nil, xerrors.Errorf("get user shell: %w", err)
}
rawMetadata := a.metadata.Load()
if rawMetadata == nil {
return nil, xerrors.New("no metadata was provided")
}
metadata, valid := rawMetadata.(Metadata)
if !valid {
return nil, xerrors.Errorf("metadata is the wrong type: %T", rawMetadata)
}
// gliderlabs/ssh returns an empty command string
// when a shell is requested.
command := rawCommand
if len(command) == 0 {
command = shell
if runtime.GOOS != "windows" {
// On Linux and macOS, we should start a login
// shell to consume juicy environment variables!
command += " -l"
}
}
// OpenSSH executes all commands with the user's current shell.
// We replicate that behavior for IDE support.
caller := "-c"
if runtime.GOOS == "windows" {
caller = "/c"
}
cmd := exec.CommandContext(ctx, shell, caller, command)
cmd.Dir = metadata.Directory
if cmd.Dir == "" {
// Default to $HOME if a directory is not set!
cmd.Dir = os.Getenv("HOME")
}
cmd.Env = append(os.Environ(), env...)
executablePath, err := os.Executable()
if err != nil {
return nil, xerrors.Errorf("getting os executable: %w", err)
}
// Set environment variables for reliable detection of being inside a
// Coder workspace.
cmd.Env = append(cmd.Env, "CODER=true")
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
// Set SSH connection environment variables (these are also set by OpenSSH
// and thus expected to be present by SSH clients). Since the agent does
// networking in-memory, trying to provide accurate values here would be
// nonsensical. For now, we hard code these values so that they're present.
srcAddr, srcPort := "0.0.0.0", "0"
dstAddr, dstPort := "0.0.0.0", "0"
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort))
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort))
// Load environment variables passed via the agent.
// These should override all variables we manually specify.
for envKey, value := range metadata.EnvironmentVariables {
// Expanding environment variables allows for customization
// of the $PATH, among other variables. Customers can prepend
// or append to the $PATH, so allowing expand is required!
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value)))
}
// Agent-level environment variables should override all others!
// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
for envKey, value := range a.envVars {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
}
return cmd, nil
}
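// handleSSHSession runs the command requested by the session, allocating
// a PTY when one is requested.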
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
ctx := session.Context()
cmd, err := a.createCommand(ctx, session.RawCommand(), session.Environ())
if err != nil {
return err
}
if ssh.AgentRequested(session) {
l, err := ssh.NewAgentListener()
if err != nil {
return xerrors.Errorf("new agent listener: %w", err)
}
defer l.Close()
go ssh.ForwardAgentConnections(l, session)
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
}
sshPty, windowSize, isPty := session.Pty()
if isPty {
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
// See https://github.com/coder/coder/issues/3371.
session.DisablePTYEmulation()
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
// The pty package sets `SSH_TTY` on supported platforms.
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
pty.WithSSHRequest(sshPty),
pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)),
))
if err != nil {
return xerrors.Errorf("start command: %w", err)
}
defer func() {
closeErr := ptty.Close()
if closeErr != nil {
a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}
}()
go func() {
for win := range windowSize {
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
if resizeErr != nil {
a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
}
}
}()
go func() {
_, _ = io.Copy(ptty.Input(), session)
}()
go func() {
_, _ = io.Copy(session, ptty.Output())
}()
err = process.Wait()
var exitErr *exec.ExitError
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
// and not something to be concerned about. But, if it's something else, we should log it.
if err != nil && !xerrors.As(err, &exitErr) {
a.logger.Warn(ctx, "wait error", slog.Error(err))
}
return err
}
cmd.Stdout = session
cmd.Stderr = session.Stderr()
// This blocks forever until stdin is received if we don't
// use StdinPipe. It's unknown what causes this.
stdinPipe, err := cmd.StdinPipe()
if err != nil {
return xerrors.Errorf("create stdin pipe: %w", err)
}
go func() {
_, _ = io.Copy(stdinPipe, session)
_ = stdinPipe.Close()
}()
err = cmd.Start()
if err != nil {
return xerrors.Errorf("start: %w", err)
}
return cmd.Wait()
}
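// handleReconnectingPTY attaches a connection to the reconnecting PTY
// identified in msg, creating the PTY and its circular output buffer if
// they don't exist, and replaying buffered output on reconnect.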
func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYInit, conn net.Conn) {
defer conn.Close()
var rpty *reconnectingPTY
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
if ok {
rpty, ok = rawRPTY.(*reconnectingPTY)
if !ok {
a.logger.Error(ctx, "found invalid type in reconnecting pty map", slog.F("id", msg.ID))
return
}
} else {
// An empty command will default to the user's shell!
cmd, err := a.createCommand(ctx, msg.Command, nil)
if err != nil {
a.logger.Error(ctx, "create reconnecting pty command", slog.Error(err))
return
}
cmd.Env = append(cmd.Env, "TERM=xterm-256color")
// Default to buffer 64KiB.
circularBuffer, err := circbuf.NewBuffer(64 << 10)
if err != nil {
a.logger.Error(ctx, "create circular buffer", slog.Error(err))
return
}
ptty, process, err := pty.Start(cmd)
if err != nil {
a.logger.Error(ctx, "start reconnecting pty command", slog.F("id", msg.ID), slog.Error(err))
return
}
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
ctx, cancelFunc := context.WithCancel(ctx)
rpty = &reconnectingPTY{
activeConns: make(map[string]net.Conn),
ptty: ptty,
// Timeouts created with an after func can be reset!
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
circularBuffer: circularBuffer,
}
a.reconnectingPTYs.Store(msg.ID, rpty)
go func() {
// CommandContext isn't respected for Windows PTYs right now,
// so we need to manually track the lifecycle.
// The context completes when either:
// 1. The timeout elapsed.
// 2. The parent context was canceled.
<-ctx.Done()
_ = process.Kill()
}()
go func() {
// If the process dies randomly, we should
// close the pty.
_ = process.Wait()
rpty.Close()
}()
go func() {
buffer := make([]byte, 1024)
for {
read, err := rpty.ptty.Output().Read(buffer)
if err != nil {
// When the PTY is closed, this is triggered.
break
}
part := buffer[:read]
rpty.circularBufferMutex.Lock()
_, err = rpty.circularBuffer.Write(part)
rpty.circularBufferMutex.Unlock()
if err != nil {
a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", msg.ID))
break
}
rpty.activeConnsMutex.Lock()
for _, conn := range rpty.activeConns {
_, _ = conn.Write(part)
}
rpty.activeConnsMutex.Unlock()
}
// Clean up the process and PTY, and delete its
// ID from memory.
_ = process.Kill()
rpty.Close()
a.reconnectingPTYs.Delete(msg.ID)
a.connCloseWait.Done()
}()
}
// Resize the PTY to initial height + width.
err := rpty.ptty.Resize(msg.Height, msg.Width)
if err != nil {
// We can continue after this, it's not fatal!
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
}
// Write any previously stored data for the TTY.
rpty.circularBufferMutex.RLock()
_, err = conn.Write(rpty.circularBuffer.Bytes())
rpty.circularBufferMutex.RUnlock()
if err != nil {
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", msg.ID), slog.Error(err))
return
}
connectionID := uuid.NewString()
// Multiple connections to the same TTY are permitted.
// This could easily be used for terminal sharing, but
// we do it because it's a nice user experience to
// copy/paste a terminal URL and have it _just work_.
rpty.activeConnsMutex.Lock()
rpty.activeConns[connectionID] = conn
rpty.activeConnsMutex.Unlock()
// Resetting this timeout prevents the PTY from exiting.
rpty.timeout.Reset(a.reconnectingPTYTimeout)
ctx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2)
defer heartbeat.Stop()
go func() {
// Keep updating the activity while this
// connection is alive!
for {
select {
case <-ctx.Done():
return
case <-heartbeat.C:
}
rpty.timeout.Reset(a.reconnectingPTYTimeout)
}
}()
defer func() {
// After this connection ends, remove it from
// the PTYs active connections. If it isn't
// removed, all PTY data will be sent to it.
rpty.activeConnsMutex.Lock()
delete(rpty.activeConns, connectionID)
rpty.activeConnsMutex.Unlock()
}()
decoder := json.NewDecoder(conn)
var req ReconnectingPTYRequest
for {
err = decoder.Decode(&req)
if xerrors.Is(err, io.EOF) {
return
}
if err != nil {
a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", msg.ID), slog.Error(err))
return
}
_, err = rpty.ptty.Input().Write([]byte(req.Data))
if err != nil {
a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
return
}
// Check if a resize needs to happen!
if req.Height == 0 || req.Width == 0 {
continue
}
err = rpty.ptty.Resize(req.Height, req.Width)
if err != nil {
// We can continue after this, it's not fatal!
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
}
}
}
// dialResponse is written to datachannels with protocol "dial" by the agent as
// the first packet to signify whether the dial succeeded or failed.
type dialResponse struct {
Error string `json:"error,omitempty"`
}
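// handleDial dials the network and address encoded in the channel label
// as a URL (the scheme is the network) and proxies data between the two
// connections.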
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
defer conn.Close()
writeError := func(responseError error) error {
msg := ""
if responseError != nil {
msg = responseError.Error()
if !xerrors.Is(responseError, io.EOF) {
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
}
}
b, err := json.Marshal(dialResponse{
Error: msg,
})
if err != nil {
a.logger.Warn(ctx, "marshal dial response", slog.F("label", label), slog.Error(err))
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
}
_, err = conn.Write(b)
return err
}
u, err := url.Parse(label)
if err != nil {
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
return
}
network := u.Scheme
addr := u.Host + u.Path
if strings.HasPrefix(network, "unix") {
if runtime.GOOS == "windows" {
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
return
}
addr, err = ExpandRelativeHomePath(addr)
if err != nil {
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
return
}
}
d := net.Dialer{Timeout: 3 * time.Second}
nconn, err := d.DialContext(ctx, network, addr)
if err != nil {
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
return
}
err = writeError(nil)
if err != nil {
return
}
Bicopy(ctx, conn, nconn)
}
// isClosed returns whether the agent is closed or not.
func (a *agent) isClosed() bool {
select {
case <-a.closed:
return true
default:
return false
}
}
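// Close shuts the agent down, closing the network and SSH server and
// waiting for all connections to drain.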
func (a *agent) Close() error {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
return nil
}
close(a.closed)
a.closeCancel()
if a.network != nil {
_ = a.network.Close()
}
_ = a.sshServer.Close()
a.connCloseWait.Wait()
return nil
}
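// reconnectingPTY is a PTY that survives client disconnects. Output is
// mirrored to all active connections and buffered so it can be replayed
// when a client reconnects.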
type reconnectingPTY struct {
activeConnsMutex sync.Mutex
activeConns map[string]net.Conn
circularBuffer *circbuf.Buffer
circularBufferMutex sync.RWMutex
timeout *time.Timer
ptty pty.PTY
}
// Close ends all connections to the reconnecting
// PTY and clears the circular buffer.
func (r *reconnectingPTY) Close() {
r.activeConnsMutex.Lock()
defer r.activeConnsMutex.Unlock()
for _, conn := range r.activeConns {
_ = conn.Close()
}
_ = r.ptty.Close()
r.circularBufferMutex.Lock()
r.circularBuffer.Reset()
r.circularBufferMutex.Unlock()
r.timeout.Stop()
}
// Bicopy copies all of the data between the two connections and will close them
// after one or both of them are done writing. If the context is canceled, both
// of the connections will be closed.
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
defer c1.Close()
defer c2.Close()
var wg sync.WaitGroup
copyFunc := func(dst io.WriteCloser, src io.Reader) {
defer wg.Done()
_, _ = io.Copy(dst, src)
}
wg.Add(2)
go copyFunc(c1, c2)
go copyFunc(c2, c1)
// Convert waitgroup to a channel so we can also wait on the context.
done := make(chan struct{})
go func() {
defer close(done)
wg.Wait()
}()
select {
case <-ctx.Done():
case <-done:
}
}
// ExpandRelativeHomePath expands the tilde at the beginning of a path to the
// current user's home directory and returns a full absolute path.
func ExpandRelativeHomePath(in string) (string, error) {
usr, err := user.Current()
if err != nil {
return "", xerrors.Errorf("get current user details: %w", err)
}
if in == "~" {
in = usr.HomeDir
} else if strings.HasPrefix(in, "~/") {
in = filepath.Join(usr.HomeDir, in[2:])
}
return filepath.Abs(in)
}