mirror of https://github.com/coder/coder.git
1458 lines
40 KiB
Go
1458 lines
40 KiB
Go
package agent
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/netip"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/armon/circbuf"
|
|
"github.com/google/uuid"
|
|
"github.com/spf13/afero"
|
|
"go.uber.org/atomic"
|
|
"golang.org/x/exp/slices"
|
|
"golang.org/x/xerrors"
|
|
"tailscale.com/net/speedtest"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/types/netlogtype"
|
|
|
|
"cdr.dev/slog"
|
|
"github.com/coder/coder/agent/agentssh"
|
|
"github.com/coder/coder/buildinfo"
|
|
"github.com/coder/coder/coderd/database"
|
|
"github.com/coder/coder/coderd/gitauth"
|
|
"github.com/coder/coder/codersdk"
|
|
"github.com/coder/coder/codersdk/agentsdk"
|
|
"github.com/coder/coder/pty"
|
|
"github.com/coder/coder/tailnet"
|
|
"github.com/coder/retry"
|
|
)
|
|
|
|
// ProtocolReconnectingPTY names the reconnecting PTY protocol.
const ProtocolReconnectingPTY = "reconnecting-pty"

// ProtocolSSH names the SSH protocol.
const ProtocolSSH = "ssh"

// ProtocolDial names the dial protocol.
const ProtocolDial = "dial"
|
|
|
|
// Options configures an agent created by New. New replaces several zero
// values with defaults; those defaults are noted per field.
type Options struct {
	// Filesystem is used for the startup/shutdown script log files, is
	// handed to the SSH server, and is passed to the git auth VS Code config
	// override. Defaults to the OS filesystem (afero.NewOsFs).
	Filesystem afero.Fs
	// LogDir is the directory the coder-<lifecycle>-script.log files are
	// written to. Defaults to TempDir.
	LogDir string
	// TempDir defaults to os.TempDir().
	TempDir string
	// ExchangeToken is called at the start of every connection attempt to
	// obtain the session token. Defaults to a function returning "" and no
	// error.
	ExchangeToken func(ctx context.Context) (string, error)
	// Client is used for all communication with coderd.
	Client Client
	// ReconnectingPTYTimeout is the duration after which a newly created
	// reconnecting PTY session's context is canceled by a timer (the timer
	// reset logic is not in view here). Defaults to 5 minutes.
	ReconnectingPTYTimeout time.Duration
	// EnvironmentVariables is assigned to the SSH server's Env.
	EnvironmentVariables map[string]string
	Logger               slog.Logger
	// IgnorePorts lists ports the API handler should leave out when
	// listing listening ports. The string values are not interpreted by the
	// code in view.
	IgnorePorts map[int]string
	// SSHMaxTimeout is passed through to agentssh.NewServer.
	SSHMaxTimeout time.Duration
	// TailnetListenPort is passed to tailnet.NewConn as ListenPort.
	TailnetListenPort uint16
}
|
|
|
|
type Client interface {
|
|
Manifest(ctx context.Context) (agentsdk.Manifest, error)
|
|
Listen(ctx context.Context) (net.Conn, error)
|
|
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
|
|
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
|
|
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
|
|
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
|
|
PostMetadata(ctx context.Context, key string, req agentsdk.PostMetadataRequest) error
|
|
PatchStartupLogs(ctx context.Context, req agentsdk.PatchStartupLogs) error
|
|
}
|
|
|
|
// Agent is the handle returned by New.
type Agent interface {
	// Closer shuts the agent down (implementation not in view).
	io.Closer

	// HTTPDebug returns an http.Handler; by its name it presumably serves
	// debugging endpoints — confirm against the implementation.
	HTTPDebug() http.Handler
}
|
|
|
|
func New(options Options) Agent {
|
|
if options.ReconnectingPTYTimeout == 0 {
|
|
options.ReconnectingPTYTimeout = 5 * time.Minute
|
|
}
|
|
if options.Filesystem == nil {
|
|
options.Filesystem = afero.NewOsFs()
|
|
}
|
|
if options.TempDir == "" {
|
|
options.TempDir = os.TempDir()
|
|
}
|
|
if options.LogDir == "" {
|
|
if options.TempDir != os.TempDir() {
|
|
options.Logger.Debug(context.Background(), "log dir not set, using temp dir", slog.F("temp_dir", options.TempDir))
|
|
}
|
|
options.LogDir = options.TempDir
|
|
}
|
|
if options.ExchangeToken == nil {
|
|
options.ExchangeToken = func(ctx context.Context) (string, error) {
|
|
return "", nil
|
|
}
|
|
}
|
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
|
a := &agent{
|
|
tailnetListenPort: options.TailnetListenPort,
|
|
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
|
|
logger: options.Logger,
|
|
closeCancel: cancelFunc,
|
|
closed: make(chan struct{}),
|
|
envVars: options.EnvironmentVariables,
|
|
client: options.Client,
|
|
exchangeToken: options.ExchangeToken,
|
|
filesystem: options.Filesystem,
|
|
logDir: options.LogDir,
|
|
tempDir: options.TempDir,
|
|
lifecycleUpdate: make(chan struct{}, 1),
|
|
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
|
|
ignorePorts: options.IgnorePorts,
|
|
connStatsChan: make(chan *agentsdk.Stats, 1),
|
|
sshMaxTimeout: options.SSHMaxTimeout,
|
|
}
|
|
a.init(ctx)
|
|
return a
|
|
}
|
|
|
|
// agent is the concrete Agent returned by New.
type agent struct {
	logger slog.Logger
	client Client
	// exchangeToken is called by run to obtain the session token.
	exchangeToken func(ctx context.Context) (string, error)
	// tailnetListenPort is passed to tailnet.NewConn as ListenPort.
	tailnetListenPort uint16
	filesystem        afero.Fs
	// logDir receives the coder-<lifecycle>-script.log files.
	logDir  string
	tempDir string
	// ignorePorts tells the api handler which ports to ignore when
	// listing all listening ports. This is helpful to hide ports that
	// are used by the agent, that the user does not care about.
	ignorePorts map[int]string

	// reconnectingPTYs maps a session ID to its *reconnectingPTY.
	reconnectingPTYs       sync.Map
	reconnectingPTYTimeout time.Duration

	// connCloseWait counts goroutines started via trackConnGoroutine.
	connCloseWait sync.WaitGroup
	// closeCancel cancels the context handed to init (and so runLoop).
	closeCancel context.CancelFunc
	// closeMutex is held while checking isClosed together with
	// registering tracked goroutines, and while reading/writing network.
	closeMutex sync.Mutex
	// closed is selected on by connection handlers to notice shutdown;
	// presumably closed by Close (not in view).
	closed chan struct{}

	// envVars is assigned to the SSH server's Env.
	envVars map[string]string
	// manifest is atomic because values can change after reconnection.
	manifest atomic.Pointer[agentsdk.Manifest]
	// sessionToken is stored by run and read by the SSH server's
	// AgentToken callback.
	sessionToken  atomic.Pointer[string]
	sshServer     *agentssh.Server
	sshMaxTimeout time.Duration

	// lifecycleUpdate (buffer 1) wakes reportLifecycleLoop.
	lifecycleUpdate chan struct{}
	// lifecycleReported (buffer 1) holds the most recently reported state.
	lifecycleReported chan codersdk.WorkspaceAgentLifecycle
	lifecycleMu       sync.RWMutex // Protects following.
	lifecycleState    codersdk.WorkspaceAgentLifecycle

	// network is created once by run and reused across reconnections.
	network       *tailnet.Conn
	connStatsChan chan *agentsdk.Stats
	latestStat    atomic.Pointer[agentsdk.Stats]

	// connCountReconnectingPTY is incremented for the lifetime of each
	// reconnecting PTY handler.
	connCountReconnectingPTY atomic.Int64
}
|
|
|
|
func (a *agent) init(ctx context.Context) {
|
|
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.filesystem, a.sshMaxTimeout, "")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
sshSrv.Env = a.envVars
|
|
sshSrv.AgentToken = func() string { return *a.sessionToken.Load() }
|
|
sshSrv.Manifest = &a.manifest
|
|
a.sshServer = sshSrv
|
|
|
|
go a.runLoop(ctx)
|
|
}
|
|
|
|
// runLoop attempts to start the agent in a retry loop.
|
|
// Coder may be offline temporarily, a connection issue
|
|
// may be happening, but regardless after the intermittent
|
|
// failure, you'll want the agent to reconnect.
|
|
func (a *agent) runLoop(ctx context.Context) {
|
|
go a.reportLifecycleLoop(ctx)
|
|
go a.reportMetadataLoop(ctx)
|
|
|
|
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
|
|
a.logger.Info(ctx, "connecting to coderd")
|
|
err := a.run(ctx)
|
|
// Cancel after the run is complete to clean up any leaked resources!
|
|
if err == nil {
|
|
continue
|
|
}
|
|
if errors.Is(err, context.Canceled) {
|
|
return
|
|
}
|
|
if a.isClosed() {
|
|
return
|
|
}
|
|
if errors.Is(err, io.EOF) {
|
|
a.logger.Info(ctx, "disconnected from coderd")
|
|
continue
|
|
}
|
|
a.logger.Warn(ctx, "run exited with error", slog.Error(err))
|
|
}
|
|
}
|
|
|
|
func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentMetadataDescription) *codersdk.WorkspaceAgentMetadataResult {
|
|
var out bytes.Buffer
|
|
result := &codersdk.WorkspaceAgentMetadataResult{
|
|
// CollectedAt is set here for testing purposes and overrode by
|
|
// coderd to the time of server receipt to solve clock skew.
|
|
//
|
|
// In the future, the server may accept the timestamp from the agent
|
|
// if it can guarantee the clocks are synchronized.
|
|
CollectedAt: time.Now(),
|
|
}
|
|
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
|
|
if err != nil {
|
|
result.Error = fmt.Sprintf("create cmd: %+v", err)
|
|
return result
|
|
}
|
|
cmd := cmdPty.AsExec()
|
|
|
|
cmd.Stdout = &out
|
|
cmd.Stderr = &out
|
|
cmd.Stdin = io.LimitReader(nil, 0)
|
|
|
|
// We split up Start and Wait instead of calling Run so that we can return a more precise error.
|
|
err = cmd.Start()
|
|
if err != nil {
|
|
result.Error = fmt.Sprintf("start cmd: %+v", err)
|
|
return result
|
|
}
|
|
|
|
// This error isn't mutually exclusive with useful output.
|
|
err = cmd.Wait()
|
|
const bufLimit = 10 << 10
|
|
if out.Len() > bufLimit {
|
|
err = errors.Join(
|
|
err,
|
|
xerrors.Errorf("output truncated from %v to %v bytes", out.Len(), bufLimit),
|
|
)
|
|
out.Truncate(bufLimit)
|
|
}
|
|
|
|
// Important: if the command times out, we may see a misleading error like
|
|
// "exit status 1", so it's important to include the context error.
|
|
err = errors.Join(err, ctx.Err())
|
|
|
|
if err != nil {
|
|
result.Error = fmt.Sprintf("run cmd: %+v", err)
|
|
}
|
|
result.Value = out.String()
|
|
return result
|
|
}
|
|
|
|
// adjustIntervalForTests converts i, a count of seconds, into a
// time.Duration. When running under `go test` (detected by the "test.v"
// flag being registered) each unit is 100ms instead of a second, so
// interval-driven loops finish quickly in tests.
func adjustIntervalForTests(i int64) time.Duration {
	unit := time.Second
	if inTest := flag.Lookup("test.v") != nil; inTest {
		unit = 100 * time.Millisecond
	}
	return unit * time.Duration(i)
}
|
|
|
|
// metadataResultAndKey pairs a collected metadata result with the metadata
// key it belongs to, as sent from collection goroutines back to
// reportMetadataLoop.
type metadataResultAndKey struct {
	result *codersdk.WorkspaceAgentMetadataResult
	key    string
}
|
|
|
|
// trySingleflight allows at most one in-flight call per key. Unlike
// golang.org/x/sync/singleflight, a caller that arrives while a call for the
// same key is running does not wait for it: Do returns immediately without
// running that caller's fn. The zero value is ready to use.
type trySingleflight struct {
	m sync.Map
}

// Do calls fn and blocks until it returns, unless another Do for key is
// already running, in which case Do returns immediately without calling fn.
func (t *trySingleflight) Do(key string, fn func()) {
	// LoadOrStore reports loaded == true when key was already present, i.e.
	// another goroutine currently owns this key.
	if _, loaded := t.m.LoadOrStore(key, struct{}{}); loaded {
		return
	}

	// Release the key even if fn panics so later calls are not blocked.
	defer t.m.Delete(key)
	fn()
}
|
|
|
|
func (a *agent) reportMetadataLoop(ctx context.Context) {
|
|
baseInterval := adjustIntervalForTests(1)
|
|
|
|
const metadataLimit = 128
|
|
|
|
var (
|
|
baseTicker = time.NewTicker(baseInterval)
|
|
lastCollectedAts = make(map[string]time.Time)
|
|
metadataResults = make(chan metadataResultAndKey, metadataLimit)
|
|
)
|
|
defer baseTicker.Stop()
|
|
|
|
// We use a custom singleflight that immediately returns if there is already
|
|
// a goroutine running for a given key. This is to prevent a build-up of
|
|
// goroutines waiting on Do when the script takes many multiples of
|
|
// baseInterval to run.
|
|
var flight trySingleflight
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case mr := <-metadataResults:
|
|
lastCollectedAts[mr.key] = mr.result.CollectedAt
|
|
err := a.client.PostMetadata(ctx, mr.key, *mr.result)
|
|
if err != nil {
|
|
a.logger.Error(ctx, "report metadata", slog.Error(err))
|
|
}
|
|
case <-baseTicker.C:
|
|
}
|
|
|
|
if len(metadataResults) > 0 {
|
|
// The inner collection loop expects the channel is empty before spinning up
|
|
// all the collection goroutines.
|
|
a.logger.Debug(
|
|
ctx, "metadata collection backpressured",
|
|
slog.F("queue_len", len(metadataResults)),
|
|
)
|
|
continue
|
|
}
|
|
|
|
manifest := a.manifest.Load()
|
|
if manifest == nil {
|
|
continue
|
|
}
|
|
|
|
if len(manifest.Metadata) > metadataLimit {
|
|
a.logger.Error(
|
|
ctx, "metadata limit exceeded",
|
|
slog.F("limit", metadataLimit), slog.F("got", len(manifest.Metadata)),
|
|
)
|
|
continue
|
|
}
|
|
|
|
// If the manifest changes (e.g. on agent reconnect) we need to
|
|
// purge old cache values to prevent lastCollectedAt from growing
|
|
// boundlessly.
|
|
for key := range lastCollectedAts {
|
|
if slices.IndexFunc(manifest.Metadata, func(md codersdk.WorkspaceAgentMetadataDescription) bool {
|
|
return md.Key == key
|
|
}) < 0 {
|
|
delete(lastCollectedAts, key)
|
|
}
|
|
}
|
|
|
|
// Spawn a goroutine for each metadata collection, and use a
|
|
// channel to synchronize the results and avoid both messy
|
|
// mutex logic and overloading the API.
|
|
for _, md := range manifest.Metadata {
|
|
collectedAt, ok := lastCollectedAts[md.Key]
|
|
if ok {
|
|
// If the interval is zero, we assume the user just wants
|
|
// a single collection at startup, not a spinning loop.
|
|
if md.Interval == 0 {
|
|
continue
|
|
}
|
|
// The last collected value isn't quite stale yet, so we skip it.
|
|
if collectedAt.Add(
|
|
adjustIntervalForTests(md.Interval),
|
|
).After(time.Now()) {
|
|
continue
|
|
}
|
|
}
|
|
|
|
md := md
|
|
// We send the result to the channel in the goroutine to avoid
|
|
// sending the same result multiple times. So, we don't care about
|
|
// the return values.
|
|
go flight.Do(md.Key, func() {
|
|
timeout := md.Timeout
|
|
if timeout == 0 {
|
|
timeout = md.Interval
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx,
|
|
time.Duration(timeout)*time.Second,
|
|
)
|
|
defer cancel()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
case metadataResults <- metadataResultAndKey{
|
|
key: md.Key,
|
|
result: a.collectMetadata(ctx, md),
|
|
}:
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// reportLifecycleLoop reports the current lifecycle state once.
|
|
// Only the latest state is reported, intermediate states may be
|
|
// lost if the agent can't communicate with the API.
|
|
func (a *agent) reportLifecycleLoop(ctx context.Context) {
|
|
var lastReported codersdk.WorkspaceAgentLifecycle
|
|
for {
|
|
select {
|
|
case <-a.lifecycleUpdate:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
|
|
for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); {
|
|
a.lifecycleMu.RLock()
|
|
state := a.lifecycleState
|
|
a.lifecycleMu.RUnlock()
|
|
|
|
if state == lastReported {
|
|
break
|
|
}
|
|
|
|
a.logger.Debug(ctx, "reporting lifecycle state", slog.F("state", state))
|
|
|
|
err := a.client.PostLifecycle(ctx, agentsdk.PostLifecycleRequest{
|
|
State: state,
|
|
})
|
|
if err == nil {
|
|
lastReported = state
|
|
select {
|
|
case a.lifecycleReported <- state:
|
|
case <-a.lifecycleReported:
|
|
a.lifecycleReported <- state
|
|
}
|
|
break
|
|
}
|
|
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
|
return
|
|
}
|
|
// If we fail to report the state we probably shouldn't exit, log only.
|
|
a.logger.Error(ctx, "post state", slog.Error(err))
|
|
}
|
|
}
|
|
}
|
|
|
|
// setLifecycle sets the lifecycle state and notifies the lifecycle loop.
|
|
// The state is only updated if it's a valid state transition.
|
|
func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentLifecycle) {
|
|
a.lifecycleMu.Lock()
|
|
lastState := a.lifecycleState
|
|
if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastState) > slices.Index(codersdk.WorkspaceAgentLifecycleOrder, state) {
|
|
a.logger.Warn(ctx, "attempted to set lifecycle state to a previous state", slog.F("last", lastState), slog.F("state", state))
|
|
a.lifecycleMu.Unlock()
|
|
return
|
|
}
|
|
a.lifecycleState = state
|
|
a.logger.Debug(ctx, "set lifecycle state", slog.F("state", state), slog.F("last", lastState))
|
|
a.lifecycleMu.Unlock()
|
|
|
|
select {
|
|
case a.lifecycleUpdate <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
// run performs one connection cycle against coderd:
//
//  1. exchanges the session token and stores it (read by the SSH server),
//  2. fetches the manifest, expands its directory and reports the agent
//     version and expanded directory with PostStartup,
//  3. on the first successful cycle only (no previous manifest): marks the
//     lifecycle as starting, applies the VS Code git auth overrides, and
//     runs the startup script in the background, setting the final
//     lifecycle state when it finishes or times out,
//  4. starts the app health reporter for the duration of this call,
//  5. creates the tailnet connection on the first cycle (later cycles only
//     update its DERP map), and
//  6. blocks in runCoordinator until coordination ends.
//
// It returns nil only when runCoordinator returns nil; other failures are
// wrapped with the step that failed.
func (a *agent) run(ctx context.Context) error {
	// This allows the agent to refresh its token if necessary.
	// For instance identity this is required, since the instance
	// may not have re-provisioned, but a new agent ID was created.
	sessionToken, err := a.exchangeToken(ctx)
	if err != nil {
		return xerrors.Errorf("exchange token: %w", err)
	}
	a.sessionToken.Store(&sessionToken)

	manifest, err := a.client.Manifest(ctx)
	if err != nil {
		return xerrors.Errorf("fetch metadata: %w", err)
	}
	// NOTE(review): the whole manifest is logged at info level; confirm it
	// carries nothing sensitive.
	a.logger.Info(ctx, "fetched manifest", slog.F("manifest", manifest))

	// Expand the directory and send it back to coderd so external
	// applications that rely on the directory can use it.
	//
	// An example is VS Code Remote, which must know the directory
	// before initializing a connection.
	manifest.Directory, err = expandDirectory(manifest.Directory)
	if err != nil {
		return xerrors.Errorf("expand directory: %w", err)
	}
	err = a.client.PostStartup(ctx, agentsdk.PostStartupRequest{
		Version:           buildinfo.Version(),
		ExpandedDirectory: manifest.Directory,
	})
	if err != nil {
		return xerrors.Errorf("update workspace agent version: %w", err)
	}

	// Publish the manifest (read by reportMetadataLoop and the SSH server).
	// A nil previous value means this is the first successful cycle.
	oldManifest := a.manifest.Swap(&manifest)

	// The startup script should only execute on the first run!
	if oldManifest == nil {
		a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStarting)

		// Perform overrides early so that Git auth can work even if users
		// connect to a workspace that is not yet ready. We don't run this
		// concurrently with the startup script to avoid conflicts between
		// them.
		if manifest.GitAuthConfigs > 0 {
			// If this fails, we should consider surfacing the error in the
			// startup log and setting the lifecycle state to be "start_error"
			// (after startup script completion), but for now we'll just log it.
			err := gitauth.OverrideVSCodeConfigs(a.filesystem)
			if err != nil {
				a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(err))
			}
		}

		lifecycleState := codersdk.WorkspaceAgentLifecycleReady
		scriptDone := make(chan error, 1)
		scriptStart := time.Now()
		err = a.trackConnGoroutine(func() {
			defer close(scriptDone)
			scriptDone <- a.runStartupScript(ctx, manifest.StartupScript)
		})
		if err != nil {
			return xerrors.Errorf("track startup script: %w", err)
		}
		// Supervise the script: apply the timeout (if any) and set the
		// final lifecycle state once the script returns.
		go func() {
			// A nil channel never fires, so without a timeout the select
			// below simply waits for the script.
			var timeout <-chan time.Time
			// If timeout is zero, an older version of the coder
			// provider was used. Otherwise a timeout is always > 0.
			if manifest.StartupScriptTimeout > 0 {
				t := time.NewTimer(manifest.StartupScriptTimeout)
				defer t.Stop()
				timeout = t.C
			}

			var err error
			select {
			case err = <-scriptDone:
			case <-timeout:
				a.logger.Warn(ctx, "startup script timed out")
				a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartTimeout)
				err = <-scriptDone // The script can still complete after a timeout.
			}
			if errors.Is(err, context.Canceled) {
				return
			}
			// Only log if there was a startup script.
			if manifest.StartupScript != "" {
				execTime := time.Since(scriptStart)
				if err != nil {
					a.logger.Warn(ctx, "startup script failed", slog.F("execution_time", execTime), slog.Error(err))
					lifecycleState = codersdk.WorkspaceAgentLifecycleStartError
				} else {
					a.logger.Info(ctx, "startup script completed", slog.F("execution_time", execTime))
				}
			}
			a.setLifecycle(ctx, lifecycleState)
		}()
	}

	// The app health reporter stops when ctx ends or when this call
	// returns (deferred cancel).
	appReporterCtx, appReporterCtxCancel := context.WithCancel(ctx)
	defer appReporterCtxCancel()
	go NewWorkspaceAppHealthReporter(
		a.logger, manifest.Apps, a.client.PostAppHealth)(appReporterCtx)

	a.closeMutex.Lock()
	network := a.network
	a.closeMutex.Unlock()
	if network == nil {
		network, err = a.createTailnet(ctx, manifest.DERPMap)
		if err != nil {
			return xerrors.Errorf("create tailnet: %w", err)
		}
		a.closeMutex.Lock()
		// Re-check if agent was closed while initializing the network.
		closed := a.isClosed()
		if !closed {
			a.network = network
		}
		a.closeMutex.Unlock()
		if closed {
			_ = network.Close()
			return xerrors.New("agent is closed")
		}

		a.startReportingConnectionStats(ctx)
	} else {
		// Update the DERP map!
		network.SetDERPMap(manifest.DERPMap)
	}

	a.logger.Debug(ctx, "running tailnet connection coordinator")
	err = a.runCoordinator(ctx, network)
	if err != nil {
		return xerrors.Errorf("run coordinator: %w", err)
	}
	return nil
}
|
|
|
|
func (a *agent) trackConnGoroutine(fn func()) error {
|
|
a.closeMutex.Lock()
|
|
defer a.closeMutex.Unlock()
|
|
if a.isClosed() {
|
|
return xerrors.New("track conn goroutine: agent is closed")
|
|
}
|
|
a.connCloseWait.Add(1)
|
|
go func() {
|
|
defer a.connCloseWait.Done()
|
|
fn()
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ *tailnet.Conn, err error) {
|
|
network, err := tailnet.NewConn(&tailnet.Options{
|
|
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
|
|
DERPMap: derpMap,
|
|
Logger: a.logger.Named("tailnet"),
|
|
ListenPort: a.tailnetListenPort,
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("create tailnet: %w", err)
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
network.Close()
|
|
}
|
|
}()
|
|
|
|
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.WorkspaceAgentSSHPort))
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
_ = sshListener.Close()
|
|
}
|
|
}()
|
|
if err = a.trackConnGoroutine(func() {
|
|
_ = a.sshServer.Serve(sshListener)
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.WorkspaceAgentReconnectingPTYPort))
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("listen for reconnecting pty: %w", err)
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
_ = reconnectingPTYListener.Close()
|
|
}
|
|
}()
|
|
if err = a.trackConnGoroutine(func() {
|
|
logger := a.logger.Named("reconnecting-pty")
|
|
var wg sync.WaitGroup
|
|
for {
|
|
conn, err := reconnectingPTYListener.Accept()
|
|
if err != nil {
|
|
if !a.isClosed() {
|
|
logger.Debug(ctx, "accept pty failed", slog.Error(err))
|
|
}
|
|
break
|
|
}
|
|
logger.Debug(ctx, "accepted conn", slog.F("remote", conn.RemoteAddr().String()))
|
|
wg.Add(1)
|
|
closed := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-closed:
|
|
case <-a.closed:
|
|
_ = conn.Close()
|
|
}
|
|
wg.Done()
|
|
}()
|
|
go func() {
|
|
defer close(closed)
|
|
// This cannot use a JSON decoder, since that can
|
|
// buffer additional data that is required for the PTY.
|
|
rawLen := make([]byte, 2)
|
|
_, err = conn.Read(rawLen)
|
|
if err != nil {
|
|
return
|
|
}
|
|
length := binary.LittleEndian.Uint16(rawLen)
|
|
data := make([]byte, length)
|
|
_, err = conn.Read(data)
|
|
if err != nil {
|
|
return
|
|
}
|
|
var msg codersdk.WorkspaceAgentReconnectingPTYInit
|
|
err = json.Unmarshal(data, &msg)
|
|
if err != nil {
|
|
logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
|
|
return
|
|
}
|
|
_ = a.handleReconnectingPTY(ctx, logger, msg, conn)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.WorkspaceAgentSpeedtestPort))
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("listen for speedtest: %w", err)
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
_ = speedtestListener.Close()
|
|
}
|
|
}()
|
|
if err = a.trackConnGoroutine(func() {
|
|
var wg sync.WaitGroup
|
|
for {
|
|
conn, err := speedtestListener.Accept()
|
|
if err != nil {
|
|
if !a.isClosed() {
|
|
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
|
|
}
|
|
break
|
|
}
|
|
wg.Add(1)
|
|
closed := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-closed:
|
|
case <-a.closed:
|
|
_ = conn.Close()
|
|
}
|
|
wg.Done()
|
|
}()
|
|
go func() {
|
|
defer close(closed)
|
|
_ = speedtest.ServeConn(conn)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
apiListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.WorkspaceAgentHTTPAPIServerPort))
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("api listener: %w", err)
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
_ = apiListener.Close()
|
|
}
|
|
}()
|
|
if err = a.trackConnGoroutine(func() {
|
|
defer apiListener.Close()
|
|
server := &http.Server{
|
|
Handler: a.apiHandler(),
|
|
ReadTimeout: 20 * time.Second,
|
|
ReadHeaderTimeout: 20 * time.Second,
|
|
WriteTimeout: 20 * time.Second,
|
|
ErrorLog: slog.Stdlib(ctx, a.logger.Named("http_api_server"), slog.LevelInfo),
|
|
}
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
case <-a.closed:
|
|
}
|
|
_ = server.Close()
|
|
}()
|
|
|
|
err := server.Serve(apiListener)
|
|
if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
|
|
a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
|
|
}
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return network, nil
|
|
}
|
|
|
|
// runCoordinator runs a coordinator and returns whether a reconnect
|
|
// should occur.
|
|
func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
coordinator, err := a.client.Listen(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer coordinator.Close()
|
|
a.logger.Info(ctx, "connected to coordination endpoint")
|
|
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error {
|
|
return network.UpdateNodes(nodes, false)
|
|
})
|
|
network.SetNodeCallback(sendNodes)
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case err := <-errChan:
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (a *agent) runStartupScript(ctx context.Context, script string) error {
|
|
return a.runScript(ctx, "startup", script)
|
|
}
|
|
|
|
func (a *agent) runShutdownScript(ctx context.Context, script string) error {
|
|
return a.runScript(ctx, "shutdown", script)
|
|
}
|
|
|
|
func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
|
|
if script == "" {
|
|
return nil
|
|
}
|
|
|
|
a.logger.Info(ctx, "running script", slog.F("lifecycle", lifecycle), slog.F("script", script))
|
|
fileWriter, err := a.filesystem.OpenFile(filepath.Join(a.logDir, fmt.Sprintf("coder-%s-script.log", lifecycle)), os.O_CREATE|os.O_RDWR, 0o600)
|
|
if err != nil {
|
|
return xerrors.Errorf("open %s script log file: %w", lifecycle, err)
|
|
}
|
|
defer func() {
|
|
_ = fileWriter.Close()
|
|
}()
|
|
|
|
var writer io.Writer = fileWriter
|
|
if lifecycle == "startup" {
|
|
// Create pipes for startup logs reader and writer
|
|
logsReader, logsWriter := io.Pipe()
|
|
defer func() {
|
|
_ = logsReader.Close()
|
|
}()
|
|
writer = io.MultiWriter(fileWriter, logsWriter)
|
|
flushedLogs, err := a.trackScriptLogs(ctx, logsReader)
|
|
if err != nil {
|
|
return xerrors.Errorf("track script logs: %w", err)
|
|
}
|
|
defer func() {
|
|
_ = logsWriter.Close()
|
|
<-flushedLogs
|
|
}()
|
|
}
|
|
|
|
cmdPty, err := a.sshServer.CreateCommand(ctx, script, nil)
|
|
if err != nil {
|
|
return xerrors.Errorf("create command: %w", err)
|
|
}
|
|
cmd := cmdPty.AsExec()
|
|
cmd.Stdout = writer
|
|
cmd.Stderr = writer
|
|
err = cmd.Run()
|
|
if err != nil {
|
|
// cmd.Run does not return a context canceled error, it returns "signal: killed".
|
|
if ctx.Err() != nil {
|
|
return ctx.Err()
|
|
}
|
|
|
|
return xerrors.Errorf("run: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *agent) trackScriptLogs(ctx context.Context, reader io.Reader) (chan struct{}, error) {
|
|
// Initialize variables for log management
|
|
queuedLogs := make([]agentsdk.StartupLog, 0)
|
|
var flushLogsTimer *time.Timer
|
|
var logMutex sync.Mutex
|
|
logsFlushed := sync.NewCond(&sync.Mutex{})
|
|
var logsSending bool
|
|
defer func() {
|
|
logMutex.Lock()
|
|
if flushLogsTimer != nil {
|
|
flushLogsTimer.Stop()
|
|
}
|
|
logMutex.Unlock()
|
|
}()
|
|
|
|
// sendLogs function uploads the queued logs to the server
|
|
sendLogs := func() {
|
|
// Lock logMutex and check if logs are already being sent
|
|
logMutex.Lock()
|
|
if logsSending {
|
|
logMutex.Unlock()
|
|
return
|
|
}
|
|
if flushLogsTimer != nil {
|
|
flushLogsTimer.Stop()
|
|
}
|
|
if len(queuedLogs) == 0 {
|
|
logMutex.Unlock()
|
|
return
|
|
}
|
|
// Move the current queued logs to logsToSend and clear the queue
|
|
logsToSend := queuedLogs
|
|
logsSending = true
|
|
queuedLogs = make([]agentsdk.StartupLog, 0)
|
|
logMutex.Unlock()
|
|
|
|
// Retry uploading logs until successful or a specific error occurs
|
|
for r := retry.New(time.Second, 5*time.Second); r.Wait(ctx); {
|
|
err := a.client.PatchStartupLogs(ctx, agentsdk.PatchStartupLogs{
|
|
Logs: logsToSend,
|
|
})
|
|
if err == nil {
|
|
break
|
|
}
|
|
var sdkErr *codersdk.Error
|
|
if errors.As(err, &sdkErr) {
|
|
if sdkErr.StatusCode() == http.StatusRequestEntityTooLarge {
|
|
a.logger.Warn(ctx, "startup logs too large, dropping logs")
|
|
break
|
|
}
|
|
}
|
|
a.logger.Error(ctx, "upload startup logs", slog.Error(err), slog.F("to_send", logsToSend))
|
|
}
|
|
// Reset logsSending flag
|
|
logMutex.Lock()
|
|
logsSending = false
|
|
flushLogsTimer.Reset(100 * time.Millisecond)
|
|
logMutex.Unlock()
|
|
logsFlushed.Broadcast()
|
|
}
|
|
// queueLog function appends a log to the queue and triggers sendLogs if necessary
|
|
queueLog := func(log agentsdk.StartupLog) {
|
|
logMutex.Lock()
|
|
defer logMutex.Unlock()
|
|
|
|
// Append log to the queue
|
|
queuedLogs = append(queuedLogs, log)
|
|
|
|
// If there are more than 100 logs, send them immediately
|
|
if len(queuedLogs) > 100 {
|
|
// Don't early return after this, because we still want
|
|
// to reset the timer just in case logs come in while
|
|
// we're sending.
|
|
go sendLogs()
|
|
}
|
|
// Reset or set the flushLogsTimer to trigger sendLogs after 100 milliseconds
|
|
if flushLogsTimer != nil {
|
|
flushLogsTimer.Reset(100 * time.Millisecond)
|
|
return
|
|
}
|
|
flushLogsTimer = time.AfterFunc(100*time.Millisecond, sendLogs)
|
|
}
|
|
|
|
// It's important that we either flush or drop all logs before returning
|
|
// because the startup state is reported after flush.
|
|
//
|
|
// It'd be weird for the startup state to be ready, but logs are still
|
|
// coming in.
|
|
logsFinished := make(chan struct{})
|
|
err := a.trackConnGoroutine(func() {
|
|
scanner := bufio.NewScanner(reader)
|
|
for scanner.Scan() {
|
|
queueLog(agentsdk.StartupLog{
|
|
CreatedAt: database.Now(),
|
|
Output: scanner.Text(),
|
|
})
|
|
}
|
|
defer close(logsFinished)
|
|
logsFlushed.L.Lock()
|
|
for {
|
|
logMutex.Lock()
|
|
if len(queuedLogs) == 0 {
|
|
logMutex.Unlock()
|
|
break
|
|
}
|
|
logMutex.Unlock()
|
|
logsFlushed.Wait()
|
|
}
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("track conn goroutine: %w", err)
|
|
}
|
|
return logsFinished, nil
|
|
}
|
|
|
|
func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) {
|
|
defer conn.Close()
|
|
|
|
a.connCountReconnectingPTY.Add(1)
|
|
defer a.connCountReconnectingPTY.Add(-1)
|
|
|
|
connectionID := uuid.NewString()
|
|
logger = logger.With(slog.F("id", msg.ID), slog.F("connection_id", connectionID))
|
|
logger.Debug(ctx, "starting handler")
|
|
|
|
defer func() {
|
|
if err := retErr; err != nil {
|
|
a.closeMutex.Lock()
|
|
closed := a.isClosed()
|
|
a.closeMutex.Unlock()
|
|
|
|
// If the agent is closed, we don't want to
|
|
// log this as an error since it's expected.
|
|
if closed {
|
|
logger.Debug(ctx, "session error after agent close", slog.Error(err))
|
|
} else {
|
|
logger.Error(ctx, "session error", slog.Error(err))
|
|
}
|
|
}
|
|
logger.Debug(ctx, "session closed")
|
|
}()
|
|
|
|
var rpty *reconnectingPTY
|
|
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
|
|
if ok {
|
|
logger.Debug(ctx, "connecting to existing session")
|
|
rpty, ok = rawRPTY.(*reconnectingPTY)
|
|
if !ok {
|
|
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
|
|
}
|
|
} else {
|
|
logger.Debug(ctx, "creating new session")
|
|
|
|
// Empty command will default to the users shell!
|
|
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
|
|
if err != nil {
|
|
return xerrors.Errorf("create command: %w", err)
|
|
}
|
|
cmd.Env = append(cmd.Env, "TERM=xterm-256color")
|
|
|
|
// Default to buffer 64KiB.
|
|
circularBuffer, err := circbuf.NewBuffer(64 << 10)
|
|
if err != nil {
|
|
return xerrors.Errorf("create circular buffer: %w", err)
|
|
}
|
|
|
|
ptty, process, err := pty.Start(cmd)
|
|
if err != nil {
|
|
return xerrors.Errorf("start command: %w", err)
|
|
}
|
|
|
|
ctx, cancelFunc := context.WithCancel(ctx)
|
|
rpty = &reconnectingPTY{
|
|
activeConns: map[string]net.Conn{
|
|
// We have to put the connection in the map instantly otherwise
|
|
// the connection won't be closed if the process instantly dies.
|
|
connectionID: conn,
|
|
},
|
|
ptty: ptty,
|
|
// Timeouts created with an after func can be reset!
|
|
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
|
|
circularBuffer: circularBuffer,
|
|
}
|
|
a.reconnectingPTYs.Store(msg.ID, rpty)
|
|
// We don't need to separately monitor for the process exiting.
|
|
// When it exits, our ptty.OutputReader() will return EOF after
|
|
// reading all process output.
|
|
if err = a.trackConnGoroutine(func() {
|
|
buffer := make([]byte, 1024)
|
|
for {
|
|
read, err := rpty.ptty.OutputReader().Read(buffer)
|
|
if err != nil {
|
|
// When the PTY is closed, this is triggered.
|
|
// Error is typically a benign EOF, so only log for debugging.
|
|
logger.Debug(ctx, "unable to read pty output, command exited?", slog.Error(err))
|
|
break
|
|
}
|
|
part := buffer[:read]
|
|
rpty.circularBufferMutex.Lock()
|
|
_, err = rpty.circularBuffer.Write(part)
|
|
rpty.circularBufferMutex.Unlock()
|
|
if err != nil {
|
|
logger.Error(ctx, "write to circular buffer", slog.Error(err))
|
|
break
|
|
}
|
|
rpty.activeConnsMutex.Lock()
|
|
for cid, conn := range rpty.activeConns {
|
|
_, err = conn.Write(part)
|
|
if err != nil {
|
|
logger.Debug(ctx,
|
|
"error writing to active conn",
|
|
slog.F("other_conn_id", cid),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
rpty.activeConnsMutex.Unlock()
|
|
}
|
|
|
|
// Cleanup the process, PTY, and delete it's
|
|
// ID from memory.
|
|
_ = process.Kill()
|
|
rpty.Close()
|
|
a.reconnectingPTYs.Delete(msg.ID)
|
|
}); err != nil {
|
|
return xerrors.Errorf("start routine: %w", err)
|
|
}
|
|
}
|
|
// Resize the PTY to initial height + width.
|
|
err := rpty.ptty.Resize(msg.Height, msg.Width)
|
|
if err != nil {
|
|
// We can continue after this, it's not fatal!
|
|
logger.Error(ctx, "resize", slog.Error(err))
|
|
}
|
|
// Write any previously stored data for the TTY.
|
|
rpty.circularBufferMutex.RLock()
|
|
prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
|
|
rpty.circularBufferMutex.RUnlock()
|
|
// Note that there is a small race here between writing buffered
|
|
// data and storing conn in activeConns. This is likely a very minor
|
|
// edge case, but we should look into ways to avoid it. Holding
|
|
// activeConnsMutex would be one option, but holding this mutex
|
|
// while also holding circularBufferMutex seems dangerous.
|
|
_, err = conn.Write(prevBuf)
|
|
if err != nil {
|
|
return xerrors.Errorf("write buffer to conn: %w", err)
|
|
}
|
|
// Multiple connections to the same TTY are permitted.
|
|
// This could easily be used for terminal sharing, but
|
|
// we do it because it's a nice user experience to
|
|
// copy/paste a terminal URL and have it _just work_.
|
|
rpty.activeConnsMutex.Lock()
|
|
rpty.activeConns[connectionID] = conn
|
|
rpty.activeConnsMutex.Unlock()
|
|
// Resetting this timeout prevents the PTY from exiting.
|
|
rpty.timeout.Reset(a.reconnectingPTYTimeout)
|
|
|
|
ctx, cancelFunc := context.WithCancel(ctx)
|
|
defer cancelFunc()
|
|
heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2)
|
|
defer heartbeat.Stop()
|
|
go func() {
|
|
// Keep updating the activity while this
|
|
// connection is alive!
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-heartbeat.C:
|
|
}
|
|
rpty.timeout.Reset(a.reconnectingPTYTimeout)
|
|
}
|
|
}()
|
|
defer func() {
|
|
// After this connection ends, remove it from
|
|
// the PTYs active connections. If it isn't
|
|
// removed, all PTY data will be sent to it.
|
|
rpty.activeConnsMutex.Lock()
|
|
delete(rpty.activeConns, connectionID)
|
|
rpty.activeConnsMutex.Unlock()
|
|
}()
|
|
decoder := json.NewDecoder(conn)
|
|
var req codersdk.ReconnectingPTYRequest
|
|
for {
|
|
err = decoder.Decode(&req)
|
|
if xerrors.Is(err, io.EOF) {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
logger.Warn(ctx, "read conn", slog.Error(err))
|
|
return nil
|
|
}
|
|
_, err = rpty.ptty.InputWriter().Write([]byte(req.Data))
|
|
if err != nil {
|
|
logger.Warn(ctx, "write to pty", slog.Error(err))
|
|
return nil
|
|
}
|
|
// Check if a resize needs to happen!
|
|
if req.Height == 0 || req.Width == 0 {
|
|
continue
|
|
}
|
|
err = rpty.ptty.Resize(req.Height, req.Width)
|
|
if err != nil {
|
|
// We can continue after this, it's not fatal!
|
|
logger.Error(ctx, "resize", slog.Error(err))
|
|
}
|
|
}
|
|
}
|
|
|
|
// startReportingConnectionStats runs the connection stats reporting goroutine.
|
|
func (a *agent) startReportingConnectionStats(ctx context.Context) {
|
|
reportStats := func(networkStats map[netlogtype.Connection]netlogtype.Counts) {
|
|
stats := &agentsdk.Stats{
|
|
ConnectionCount: int64(len(networkStats)),
|
|
ConnectionsByProto: map[string]int64{},
|
|
}
|
|
for conn, counts := range networkStats {
|
|
stats.ConnectionsByProto[conn.Proto.String()]++
|
|
stats.RxBytes += int64(counts.RxBytes)
|
|
stats.RxPackets += int64(counts.RxPackets)
|
|
stats.TxBytes += int64(counts.TxBytes)
|
|
stats.TxPackets += int64(counts.TxPackets)
|
|
}
|
|
|
|
// The count of active sessions.
|
|
sshStats := a.sshServer.ConnStats()
|
|
stats.SessionCountSSH = sshStats.Sessions
|
|
stats.SessionCountVSCode = sshStats.VSCode
|
|
stats.SessionCountJetBrains = sshStats.JetBrains
|
|
|
|
stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load()
|
|
|
|
// Compute the median connection latency!
|
|
var wg sync.WaitGroup
|
|
var mu sync.Mutex
|
|
status := a.network.Status()
|
|
durations := []float64{}
|
|
ctx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancelFunc()
|
|
for nodeID, peer := range status.Peer {
|
|
if !peer.Active {
|
|
continue
|
|
}
|
|
addresses, found := a.network.NodeAddresses(nodeID)
|
|
if !found {
|
|
continue
|
|
}
|
|
if len(addresses) == 0 {
|
|
continue
|
|
}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
duration, _, _, err := a.network.Ping(ctx, addresses[0].Addr())
|
|
if err != nil {
|
|
return
|
|
}
|
|
mu.Lock()
|
|
durations = append(durations, float64(duration.Microseconds()))
|
|
mu.Unlock()
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
sort.Float64s(durations)
|
|
durationsLength := len(durations)
|
|
if durationsLength == 0 {
|
|
stats.ConnectionMedianLatencyMS = -1
|
|
} else if durationsLength%2 == 0 {
|
|
stats.ConnectionMedianLatencyMS = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2
|
|
} else {
|
|
stats.ConnectionMedianLatencyMS = durations[durationsLength/2]
|
|
}
|
|
// Convert from microseconds to milliseconds.
|
|
stats.ConnectionMedianLatencyMS /= 1000
|
|
|
|
// Collect agent metrics.
|
|
// Agent metrics are changing all the time, so there is no need to perform
|
|
// reflect.DeepEqual to see if stats should be transferred.
|
|
stats.Metrics = collectMetrics()
|
|
|
|
a.latestStat.Store(stats)
|
|
|
|
select {
|
|
case a.connStatsChan <- stats:
|
|
case <-a.closed:
|
|
}
|
|
}
|
|
|
|
// Report statistics from the created network.
|
|
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, func(d time.Duration) {
|
|
a.network.SetConnStatsCallback(d, 2048,
|
|
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
|
|
reportStats(virtual)
|
|
},
|
|
)
|
|
})
|
|
if err != nil {
|
|
a.logger.Error(ctx, "report stats", slog.Error(err))
|
|
} else {
|
|
if err = a.trackConnGoroutine(func() {
|
|
// This is OK because the agent never re-creates the tailnet
|
|
// and the only shutdown indicator is agent.Close().
|
|
<-a.closed
|
|
_ = cl.Close()
|
|
}); err != nil {
|
|
a.logger.Debug(ctx, "report stats goroutine", slog.Error(err))
|
|
_ = cl.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
// isClosed returns whether the API is closed or not.
|
|
func (a *agent) isClosed() bool {
|
|
select {
|
|
case <-a.closed:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (a *agent) HTTPDebug() http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
a.closeMutex.Lock()
|
|
network := a.network
|
|
a.closeMutex.Unlock()
|
|
|
|
if network == nil {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("network is not ready yet"))
|
|
return
|
|
}
|
|
|
|
if r.URL.Path == "/debug/magicsock" {
|
|
network.MagicsockServeHTTPDebug(w, r)
|
|
} else {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
_, _ = w.Write([]byte("404 not found"))
|
|
}
|
|
})
|
|
}
|
|
|
|
// Close shuts the agent down. Calls after the agent is already closed return
// nil immediately, and Close itself always returns nil; failures are logged.
//
// The shutdown sequence is:
//  1. set the ShuttingDown lifecycle state and gracefully shut down the SSH
//     server;
//  2. if the manifest has a shutdown script, run it; when it outlives
//     ShutdownScriptTimeout the ShutdownTimeout state is set, and when it
//     returns an error the final state becomes ShutdownError;
//  3. set the final lifecycle state (Off or ShutdownError) and wait up to
//     5 seconds for that state to arrive on a.lifecycleReported;
//  4. close a.closed, call closeCancel, close the SSH server and the network,
//     and wait for goroutines tracked by connCloseWait.
//
// closeMutex is held for the entire call, including while the shutdown
// script runs, so anything else that takes closeMutex (e.g. HTTPDebug)
// blocks until Close returns.
func (a *agent) Close() error {
	a.closeMutex.Lock()
	defer a.closeMutex.Unlock()
	if a.isClosed() {
		return nil
	}

	ctx := context.Background()
	a.logger.Info(ctx, "shutting down agent")
	a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown)

	// Attempt to gracefully shut down all active SSH connections and
	// stop accepting new ones.
	err := a.sshServer.Shutdown(ctx)
	if err != nil {
		a.logger.Error(ctx, "ssh server shutdown", slog.Error(err))
	}

	// Final lifecycle state; downgraded to ShutdownError if the script fails.
	lifecycleState := codersdk.WorkspaceAgentLifecycleOff
	if manifest := a.manifest.Load(); manifest != nil && manifest.ShutdownScript != "" {
		// Receives the script's result; the goroutine closes it after sending.
		scriptDone := make(chan error, 1)
		scriptStart := time.Now()
		go func() {
			defer close(scriptDone)
			scriptDone <- a.runShutdownScript(ctx, manifest.ShutdownScript)
		}()

		// Left nil when there is no timeout; receiving from a nil channel
		// never proceeds, so the select below then just waits for the script.
		var timeout <-chan time.Time
		// If timeout is zero, an older version of the coder
		// provider was used. Otherwise a timeout is always > 0.
		if manifest.ShutdownScriptTimeout > 0 {
			t := time.NewTimer(manifest.ShutdownScriptTimeout)
			defer t.Stop()
			timeout = t.C
		}

		// Shadows the outer err from sshServer.Shutdown; only the script
		// result is inspected below.
		var err error
		select {
		case err = <-scriptDone:
		case <-timeout:
			// NOTE(review): the timeout only changes the reported lifecycle
			// state; Close still blocks until runShutdownScript returns. If
			// the script then succeeds, lifecycleState stays Off and the
			// final setLifecycle below replaces ShutdownTimeout. Confirm
			// this is intended.
			a.logger.Warn(ctx, "shutdown script timed out")
			a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShutdownTimeout)
			err = <-scriptDone // The script can still complete after a timeout.
		}
		execTime := time.Since(scriptStart)
		if err != nil {
			a.logger.Warn(ctx, "shutdown script failed", slog.F("execution_time", execTime), slog.Error(err))
			lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
		} else {
			a.logger.Info(ctx, "shutdown script completed", slog.F("execution_time", execTime))
		}
	}

	// Set final state and wait for it to be reported because context
	// cancellation will stop the report loop.
	a.setLifecycle(ctx, lifecycleState)

	// Wait for the lifecycle to be reported, but don't wait forever so
	// that we don't break user expectations.
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
lifecycleWaitLoop:
	for {
		select {
		case <-ctx.Done():
			break lifecycleWaitLoop
		case s := <-a.lifecycleReported:
			// Intermediate states (e.g. ShuttingDown) may arrive first;
			// keep waiting until the final state shows up.
			if s == lifecycleState {
				break lifecycleWaitLoop
			}
		}
	}

	// Tear down: closing a.closed makes isClosed report true and releases
	// goroutines blocked on <-a.closed. Then close the SSH server and network
	// and wait for every goroutine tracked by connCloseWait.
	close(a.closed)
	a.closeCancel()
	_ = a.sshServer.Close()
	if a.network != nil {
		_ = a.network.Close()
	}
	a.connCloseWait.Wait()

	return nil
}
|
|
|
|
type reconnectingPTY struct {
|
|
activeConnsMutex sync.Mutex
|
|
activeConns map[string]net.Conn
|
|
|
|
circularBuffer *circbuf.Buffer
|
|
circularBufferMutex sync.RWMutex
|
|
timeout *time.Timer
|
|
ptty pty.PTYCmd
|
|
}
|
|
|
|
// Close ends all connections to the reconnecting
|
|
// PTY and clear the circular buffer.
|
|
func (r *reconnectingPTY) Close() {
|
|
r.activeConnsMutex.Lock()
|
|
defer r.activeConnsMutex.Unlock()
|
|
for _, conn := range r.activeConns {
|
|
_ = conn.Close()
|
|
}
|
|
_ = r.ptty.Close()
|
|
r.circularBufferMutex.Lock()
|
|
r.circularBuffer.Reset()
|
|
r.circularBufferMutex.Unlock()
|
|
r.timeout.Stop()
|
|
}
|
|
|
|
// userHomeDir returns the home directory of the current user, giving
|
|
// priority to the $HOME environment variable.
|
|
func userHomeDir() (string, error) {
|
|
// First we check the environment.
|
|
homedir, err := os.UserHomeDir()
|
|
if err == nil {
|
|
return homedir, nil
|
|
}
|
|
|
|
// As a fallback, we try the user information.
|
|
u, err := user.Current()
|
|
if err != nil {
|
|
return "", xerrors.Errorf("current user: %w", err)
|
|
}
|
|
return u.HomeDir, nil
|
|
}
|
|
|
|
// expandDirectory converts a directory path to an absolute path.
|
|
// It primarily resolves the home directory and any environment
|
|
// variables that may be set
|
|
func expandDirectory(dir string) (string, error) {
|
|
if dir == "" {
|
|
return "", nil
|
|
}
|
|
if dir[0] == '~' {
|
|
home, err := userHomeDir()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
dir = filepath.Join(home, dir[1:])
|
|
}
|
|
dir = os.ExpandEnv(dir)
|
|
|
|
if !filepath.IsAbs(dir) {
|
|
home, err := userHomeDir()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
dir = filepath.Join(home, dir)
|
|
}
|
|
return dir, nil
|
|
}
|