fix: Refactor agent to consume API client (#4715)

* fix: Refactor agent to consume API client

This simplifies a lot of code by creating an interface for
the codersdk client into the agent. It also moves agent
authentication code so instance identity will work between
restarts.

Fixes #3485 and #4082.

* Fix client reconnections
This commit is contained in:
Kyle Carberry 2022-10-23 22:35:08 -05:00 committed by GitHub
parent c9bf2a9099
commit bf3224e373
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 379 additions and 364 deletions

View File

@ -34,6 +34,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/pty"
"github.com/coder/coder/tailnet"
@ -52,22 +53,20 @@ const (
)
type Options struct {
CoordinatorDialer CoordinatorDialer
FetchMetadata FetchMetadata
StatsReporter StatsReporter
WorkspaceAgentApps WorkspaceAgentApps
PostWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
ExchangeToken func(ctx context.Context) error
Client Client
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
}
// CoordinatorDialer is a function that constructs a new broker.
// A dialer must be passed in to allow for reconnects.
type CoordinatorDialer func(context.Context) (net.Conn, error)
// FetchMetadata is a function to obtain metadata for the agent.
type FetchMetadata func(context.Context) (codersdk.WorkspaceAgentMetadata, error)
type Client interface {
WorkspaceAgentMetadata(ctx context.Context) (codersdk.WorkspaceAgentMetadata, error)
ListenWorkspaceAgent(ctx context.Context) (net.Conn, error)
AgentReportStats(ctx context.Context, log slog.Logger, stats func() *codersdk.AgentStats) (io.Closer, error)
PostWorkspaceAgentAppHealth(ctx context.Context, req codersdk.PostWorkspaceAppHealthsRequest) error
PostWorkspaceAgentVersion(ctx context.Context, version string) error
}
func New(options Options) io.Closer {
if options.ReconnectingPTYTimeout == 0 {
@ -75,24 +74,23 @@ func New(options Options) io.Closer {
}
ctx, cancelFunc := context.WithCancel(context.Background())
server := &agent{
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
envVars: options.EnvironmentVariables,
coordinatorDialer: options.CoordinatorDialer,
fetchMetadata: options.FetchMetadata,
stats: &Stats{},
statsReporter: options.StatsReporter,
workspaceAgentApps: options.WorkspaceAgentApps,
postWorkspaceAgentAppHealth: options.PostWorkspaceAgentAppHealth,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
envVars: options.EnvironmentVariables,
client: options.Client,
exchangeToken: options.ExchangeToken,
stats: &Stats{},
}
server.init(ctx)
return server
}
type agent struct {
logger slog.Logger
logger slog.Logger
client Client
exchangeToken func(ctx context.Context) error
reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration
@ -104,100 +102,130 @@ type agent struct {
envVars map[string]string
// metadata is atomic because values can change after reconnection.
metadata atomic.Value
fetchMetadata FetchMetadata
sshServer *ssh.Server
metadata atomic.Value
sshServer *ssh.Server
network *tailnet.Conn
coordinatorDialer CoordinatorDialer
stats *Stats
statsReporter StatsReporter
workspaceAgentApps WorkspaceAgentApps
postWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth
network *tailnet.Conn
stats *Stats
}
func (a *agent) run(ctx context.Context) {
var metadata codersdk.WorkspaceAgentMetadata
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
a.logger.Info(ctx, "connecting")
metadata, err = a.fetchMetadata(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
if a.isClosed() {
return
}
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
// runLoop attempts to start the agent in a retry loop.
// Coder may be offline temporarily, a connection issue
// may be happening, but regardless after the intermittent
// failure, you'll want the agent to reconnect.
func (a *agent) runLoop(ctx context.Context) {
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
a.logger.Info(ctx, "running loop")
err := a.run(ctx)
// Cancel after the run is complete to clean up any leaked resources!
if err == nil {
continue
}
a.logger.Info(context.Background(), "fetched metadata")
break
}
select {
case <-ctx.Done():
return
default:
}
a.metadata.Store(metadata)
// The startup script has not ran yet!
go func() {
err := a.runStartupScript(ctx, metadata.StartupScript)
if errors.Is(err, context.Canceled) {
return
}
if err != nil {
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
if a.isClosed() {
return
}
}()
if metadata.DERPMap != nil {
go a.runTailnet(ctx, metadata.DERPMap)
}
if a.workspaceAgentApps != nil && a.postWorkspaceAgentAppHealth != nil {
go NewWorkspaceAppHealthReporter(a.logger, a.workspaceAgentApps, a.postWorkspaceAgentAppHealth)(ctx)
if errors.Is(err, io.EOF) {
a.logger.Info(ctx, "likely disconnected from coder", slog.Error(err))
continue
}
a.logger.Warn(ctx, "run exited with error", slog.Error(err))
}
}
func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
func (a *agent) run(ctx context.Context) error {
// This allows the agent to refresh it's token if necessary.
// For instance identity this is required, since the instance
// may not have re-provisioned, but a new agent ID was created.
if a.exchangeToken != nil {
err := a.exchangeToken(ctx)
if err != nil {
return xerrors.Errorf("exchange token: %w", err)
}
}
err := a.client.PostWorkspaceAgentVersion(ctx, buildinfo.Version())
if err != nil {
return xerrors.Errorf("update workspace agent version: %w", err)
}
metadata, err := a.client.WorkspaceAgentMetadata(ctx)
if err != nil {
return xerrors.Errorf("fetch metadata: %w", err)
}
a.logger.Info(context.Background(), "fetched metadata")
oldMetadata := a.metadata.Swap(metadata)
// The startup script should only execute on the first run!
if oldMetadata == nil {
go func() {
err := a.runStartupScript(ctx, metadata.StartupScript)
if errors.Is(err, context.Canceled) {
return
}
if err != nil {
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
}
}()
}
// This automatically closes when the context ends!
appReporterCtx, appReporterCtxCancel := context.WithCancel(ctx)
defer appReporterCtxCancel()
go NewWorkspaceAppHealthReporter(
a.logger, metadata.Apps, a.client.PostWorkspaceAgentAppHealth)(appReporterCtx)
a.logger.Debug(ctx, "running tailnet with derpmap", slog.F("derpmap", metadata.DERPMap))
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
return
network := a.network
a.closeMutex.Unlock()
if a.network == nil {
a.logger.Debug(ctx, "creating tailnet")
network, err = a.createTailnet(ctx, metadata.DERPMap)
if err != nil {
return xerrors.Errorf("create tailnet: %w", err)
}
a.closeMutex.Lock()
a.network = network
a.closeMutex.Unlock()
} else {
// Update the DERP map!
network.SetDERPMap(metadata.DERPMap)
}
a.logger.Debug(ctx, "running tailnet with derpmap", slog.F("derpmap", derpMap))
if a.network != nil {
a.network.SetDERPMap(derpMap)
return
a.logger.Debug(ctx, "running coordinator")
err = a.runCoordinator(ctx, network)
if err != nil {
a.logger.Debug(ctx, "coordinator exited", slog.Error(err))
return xerrors.Errorf("run coordinator: %w", err)
}
var err error
a.network, err = tailnet.NewConn(&tailnet.Options{
return nil
}
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*tailnet.Conn, error) {
network, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)},
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
})
if err != nil {
a.logger.Critical(ctx, "create tailnet", slog.Error(err))
return
return nil, xerrors.Errorf("create tailnet: %w", err)
}
a.network.SetForwardTCPCallback(func(conn net.Conn, listenerExists bool) net.Conn {
a.network = network
network.SetForwardTCPCallback(func(conn net.Conn, listenerExists bool) net.Conn {
if listenerExists {
// If a listener already exists, we would double-wrap the conn.
return conn
}
return a.stats.wrapConn(conn)
})
go a.runCoordinator(ctx)
sshListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSSHPort))
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSSHPort))
if err != nil {
a.logger.Critical(ctx, "listen for ssh", slog.Error(err))
return
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
}
go func() {
for {
@ -209,10 +237,9 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
}
}()
reconnectingPTYListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort))
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort))
if err != nil {
a.logger.Critical(ctx, "listen for reconnecting pty", slog.Error(err))
return
return nil, xerrors.Errorf("listen for reconnecting pty: %w", err)
}
go func() {
for {
@ -244,10 +271,9 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
}
}()
speedtestListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort))
speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort))
if err != nil {
a.logger.Critical(ctx, "listen for speedtest", slog.Error(err))
return
return nil, xerrors.Errorf("listen for speedtest: %w", err)
}
go func() {
for {
@ -266,10 +292,9 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
}
}()
statisticsListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort))
statisticsListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort))
if err != nil {
a.logger.Critical(ctx, "listen for statistics", slog.Error(err))
return
return nil, xerrors.Errorf("listen for statistics: %w", err)
}
go func() {
defer statisticsListener.Close()
@ -290,59 +315,26 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
a.logger.Critical(ctx, "serve statistics HTTP server", slog.Error(err))
}
}()
return network, nil
}
// runCoordinator listens for nodes and updates the self-node as it changes.
func (a *agent) runCoordinator(ctx context.Context) {
for {
reconnect := a.runCoordinatorWithRetry(ctx)
if !reconnect {
return
}
}
}
func (a *agent) runCoordinatorWithRetry(ctx context.Context) (reconnect bool) {
var coordinator net.Conn
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
coordinator, err = a.coordinatorDialer(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return false
}
if a.isClosed() {
return false
}
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
//nolint:revive // Defer is ok because we're exiting this loop.
defer coordinator.Close()
a.logger.Info(context.Background(), "connected to coordination server")
break
// runCoordinator runs a coordinator and returns whether a reconnect
// should occur.
func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error {
coordinator, err := a.client.ListenWorkspaceAgent(ctx)
if err != nil {
return err
}
defer coordinator.Close()
a.logger.Info(context.Background(), "connected to coordination server")
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, network.UpdateNodes)
network.SetNodeCallback(sendNodes)
select {
case <-ctx.Done():
return false
default:
}
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, a.network.UpdateNodes)
a.network.SetNodeCallback(sendNodes)
select {
case <-ctx.Done():
return false
return ctx.Err()
case err := <-errChan:
if a.isClosed() {
return false
}
if errors.Is(err, context.Canceled) {
return false
}
a.logger.Debug(ctx, "node broker accept exited; restarting connection", slog.Error(err))
return true
return err
}
}
@ -474,22 +466,20 @@ func (a *agent) init(ctx context.Context) {
},
}
go a.run(ctx)
if a.statsReporter != nil {
cl, err := a.statsReporter(ctx, a.logger, func() *codersdk.AgentStats {
return a.stats.Copy()
})
if err != nil {
a.logger.Error(ctx, "report stats", slog.Error(err))
return
}
a.connCloseWait.Add(1)
go func() {
defer a.connCloseWait.Done()
<-a.closed
cl.Close()
}()
go a.runLoop(ctx)
cl, err := a.client.AgentReportStats(ctx, a.logger, func() *codersdk.AgentStats {
return a.stats.Copy()
})
if err != nil {
a.logger.Error(ctx, "report stats", slog.Error(err))
return
}
a.connCloseWait.Add(1)
go func() {
defer a.connCloseWait.Done()
<-a.closed
cl.Close()
}()
}
// createCommand processes raw command input with OpenSSH-like behavior.
@ -696,7 +686,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
ptty, process, err := pty.Start(cmd)
if err != nil {
a.logger.Error(ctx, "start reconnecting pty command", slog.F("id", msg.ID))
a.logger.Error(ctx, "start reconnecting pty command", slog.F("id", msg.ID), slog.Error(err))
return
}

View File

@ -16,6 +16,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@ -503,6 +504,45 @@ func TestAgent(t *testing.T) {
require.NoError(t, err)
t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond())
})
t.Run("Reconnect", func(t *testing.T) {
t.Parallel()
// After the agent is disconnected from a coordinator, it's supposed
// to reconnect!
coordinator := tailnet.NewCoordinator()
agentID := uuid.New()
statsCh := make(chan *codersdk.AgentStats)
derpMap := tailnettest.RunDERPAndSTUN(t)
client := &client{
t: t,
agentID: agentID,
metadata: codersdk.WorkspaceAgentMetadata{
DERPMap: derpMap,
},
statsChan: statsCh,
coordinator: coordinator,
}
initialized := atomic.Int32{}
closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) error {
initialized.Add(1)
return nil
},
Client: client,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelInfo),
})
t.Cleanup(func() {
_ = closer.Close()
})
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
client.lastWorkspaceAgent()
require.Eventually(t, func() bool {
return initialized.Load() == 2
}, testutil.WaitShort, testutil.IntervalFast)
})
}
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
@ -572,57 +612,15 @@ func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeo
agentID := uuid.New()
statsCh := make(chan *codersdk.AgentStats)
closer := agent.New(agent.Options{
FetchMetadata: func(ctx context.Context) (codersdk.WorkspaceAgentMetadata, error) {
return metadata, nil
},
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
closed := make(chan struct{})
t.Cleanup(func() {
_ = serverConn.Close()
_ = clientConn.Close()
<-closed
})
go func() {
_ = coordinator.ServeAgent(serverConn, agentID)
close(closed)
}()
return clientConn, nil
Client: &client{
t: t,
agentID: agentID,
metadata: metadata,
statsChan: statsCh,
coordinator: coordinator,
},
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
ReconnectingPTYTimeout: ptyTimeout,
StatsReporter: func(ctx context.Context, log slog.Logger, statsFn func() *codersdk.AgentStats) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
go func() {
defer close(doneCh)
t := time.NewTicker(time.Millisecond * 100)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
}
select {
case statsCh <- statsFn():
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}()
return closeFunc(func() error {
cancel()
<-doneCh
close(statsCh)
return nil
}), nil
},
})
t.Cleanup(func() {
_ = closer.Close()
@ -679,3 +677,73 @@ func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
assert.NoError(t, err, "write payload")
assert.Equal(t, len(payload), n, "payload length does not match")
}
type client struct {
t *testing.T
agentID uuid.UUID
metadata codersdk.WorkspaceAgentMetadata
statsChan chan *codersdk.AgentStats
coordinator tailnet.Coordinator
lastWorkspaceAgent func()
}
func (c *client) WorkspaceAgentMetadata(_ context.Context) (codersdk.WorkspaceAgentMetadata, error) {
return c.metadata, nil
}
func (c *client) ListenWorkspaceAgent(_ context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
closed := make(chan struct{})
c.lastWorkspaceAgent = func() {
_ = serverConn.Close()
_ = clientConn.Close()
<-closed
}
c.t.Cleanup(c.lastWorkspaceAgent)
go func() {
_ = c.coordinator.ServeAgent(serverConn, c.agentID)
close(closed)
}()
return clientConn, nil
}
func (c *client) AgentReportStats(ctx context.Context, _ slog.Logger, stats func() *codersdk.AgentStats) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
go func() {
defer close(doneCh)
t := time.NewTicker(time.Millisecond * 100)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
}
select {
case c.statsChan <- stats():
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}()
return closeFunc(func() error {
cancel()
<-doneCh
close(c.statsChan)
return nil
}), nil
}
func (*client) PostWorkspaceAgentAppHealth(_ context.Context, _ codersdk.PostWorkspaceAppHealthsRequest) error {
return nil
}
func (*client) PostWorkspaceAgentVersion(_ context.Context, _ string) error {
return nil
}

View File

@ -23,16 +23,8 @@ type PostWorkspaceAgentAppHealth func(context.Context, codersdk.PostWorkspaceApp
type WorkspaceAppHealthReporter func(ctx context.Context)
// NewWorkspaceAppHealthReporter creates a WorkspaceAppHealthReporter that reports app health to coderd.
func NewWorkspaceAppHealthReporter(logger slog.Logger, workspaceAgentApps WorkspaceAgentApps, postWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth) WorkspaceAppHealthReporter {
func NewWorkspaceAppHealthReporter(logger slog.Logger, apps []codersdk.WorkspaceApp, postWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth) WorkspaceAppHealthReporter {
runHealthcheckLoop := func(ctx context.Context) error {
apps, err := workspaceAgentApps(ctx)
if err != nil {
if xerrors.Is(err, context.Canceled) {
return nil
}
return xerrors.Errorf("getting workspace apps: %w", err)
}
// no need to run this loop if no apps for this workspace.
if len(apps) == 0 {
return nil

View File

@ -199,7 +199,7 @@ func setupAppReporter(ctx context.Context, t *testing.T, apps []codersdk.Workspa
return nil
}
go agent.NewWorkspaceAppHealthReporter(slogtest.Make(t, nil).Leveled(slog.LevelDebug), workspaceAgentApps, postWorkspaceAgentAppHealth)(ctx)
go agent.NewWorkspaceAppHealthReporter(slogtest.Make(t, nil).Leveled(slog.LevelDebug), apps, postWorkspaceAgentAppHealth)(ctx)
return workspaceAgentApps, func() {
for _, closeFn := range closers {

View File

@ -1,12 +1,9 @@
package agent
import (
"context"
"io"
"net"
"sync/atomic"
"cdr.dev/slog"
"github.com/coder/coder/codersdk"
)
@ -59,10 +56,3 @@ func (s *Stats) wrapConn(conn net.Conn) net.Conn {
return cs
}
// StatsReporter periodically accept and records agent stats.
type StatsReporter func(
ctx context.Context,
log slog.Logger,
stats func() *codersdk.AgentStats,
) (io.Closer, error)

View File

@ -23,7 +23,6 @@ import (
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/codersdk"
"github.com/coder/retry"
)
func workspaceAgent() *cobra.Command {
@ -80,6 +79,8 @@ func workspaceAgent() *cobra.Command {
slog.F("version", version),
)
client := codersdk.New(coderURL)
// Set a reasonable timeout so requests can't hang forever!
client.HTTPClient.Timeout = 10 * time.Second
if pprofEnabled {
srvClose := serveHandler(cmd.Context(), logger, nil, pprofAddress, "pprof")
@ -143,43 +144,6 @@ func workspaceAgent() *cobra.Command {
}
}
if exchangeToken != nil {
logger.Info(cmd.Context(), "exchanging identity token")
// Agent's can start before resources are returned from the provisioner
// daemon. If there are many resources being provisioned, this time
// could be significant. This is arbitrarily set at an hour to prevent
// tons of idle agents from pinging coderd.
ctx, cancelFunc := context.WithTimeout(cmd.Context(), time.Hour)
defer cancelFunc()
for retry.New(100*time.Millisecond, 5*time.Second).Wait(ctx) {
var response codersdk.WorkspaceAgentAuthenticateResponse
response, err = exchangeToken(ctx)
if err != nil {
logger.Warn(ctx, "authenticate workspace", slog.F("method", auth), slog.Error(err))
continue
}
client.SessionToken = response.SessionToken
logger.Info(ctx, "authenticated", slog.F("method", auth))
break
}
if err != nil {
return xerrors.Errorf("agent failed to authenticate in time: %w", err)
}
}
retryCtx, cancelRetry := context.WithTimeout(cmd.Context(), time.Hour)
defer cancelRetry()
for retrier := retry.New(100*time.Millisecond, 5*time.Second); retrier.Wait(retryCtx); {
err := client.PostWorkspaceAgentVersion(retryCtx, version)
if err != nil {
logger.Warn(retryCtx, "post agent version: %w", slog.Error(err), slog.F("version", version))
continue
}
logger.Info(retryCtx, "updated agent version", slog.F("version", version))
break
}
executablePath, err := os.Executable()
if err != nil {
return xerrors.Errorf("getting os executable: %w", err)
@ -190,17 +154,24 @@ func workspaceAgent() *cobra.Command {
}
closer := agent.New(agent.Options{
FetchMetadata: client.WorkspaceAgentMetadata,
Logger: logger,
Client: client,
Logger: logger,
ExchangeToken: func(ctx context.Context) error {
if exchangeToken == nil {
return nil
}
resp, err := exchangeToken(ctx)
if err != nil {
return err
}
client.SessionToken = resp.SessionToken
return nil
},
EnvironmentVariables: map[string]string{
// Override the "CODER_AGENT_TOKEN" variable in all
// shells so "gitssh" works!
"CODER_AGENT_TOKEN": client.SessionToken,
},
CoordinatorDialer: client.ListenWorkspaceAgentTailnet,
StatsReporter: client.AgentReportStats,
WorkspaceAgentApps: client.WorkspaceAgentApps,
PostWorkspaceAgentAppHealth: client.PostWorkspaceAgentAppHealth,
})
<-cmd.Context().Done()
return closer.Close()

View File

@ -106,9 +106,8 @@ func TestConfigSSH(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
defer func() {
_ = agentCloser.Close()

View File

@ -24,9 +24,8 @@ func TestSpeedtest(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = agentToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
defer agentCloser.Close()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)

View File

@ -89,9 +89,8 @@ func TestSSH(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = agentToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
defer func() {
_ = agentCloser.Close()
@ -110,9 +109,8 @@ func TestSSH(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = agentToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
<-ctx.Done()
_ = agentCloser.Close()
@ -178,9 +176,8 @@ func TestSSH(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = agentToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
defer agentCloser.Close()

View File

@ -471,7 +471,6 @@ func New(options *Options) *API {
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
r.Route("/me", func(r chi.Router) {
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
r.Get("/apps", api.workspaceAgentApps)
r.Get("/metadata", api.workspaceAgentMetadata)
r.Post("/version", api.postWorkspaceAgentVersion)
r.Post("/app-health", api.postWorkspaceAppHealth)

View File

@ -57,7 +57,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/apps": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},

View File

@ -603,10 +603,8 @@ func TestTemplateMetrics(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
Logger: slogtest.Make(t, nil),
StatsReporter: agentClient.AgentReportStats,
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil),
Client: agentClient,
})
defer func() {
_ = agentCloser.Close()

View File

@ -61,20 +61,6 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, apiAgent)
}
func (api *API) workspaceAgentApps(rw http.ResponseWriter, r *http.Request) {
workspaceAgent := httpmw.WorkspaceAgent(r)
dbApps, err := api.Database.GetWorkspaceAppsByAgentID(r.Context(), workspaceAgent.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace agent applications.",
Detail: err.Error(),
})
return
}
httpapi.Write(r.Context(), rw, http.StatusOK, convertApps(dbApps))
}
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
@ -86,8 +72,17 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
})
return
}
dbApps, err := api.Database.GetWorkspaceAppsByAgentID(r.Context(), workspaceAgent.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace agent applications.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.WorkspaceAgentMetadata{
Apps: convertApps(dbApps),
DERPMap: api.DERPMap,
EnvironmentVariables: apiAgent.EnvironmentVariables,
StartupScript: apiAgent.StartupScript,

View File

@ -111,9 +111,8 @@ func TestWorkspaceAgentListen(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
_ = agentCloser.Close()
@ -204,7 +203,7 @@ func TestWorkspaceAgentListen(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
_, err = agentClient.ListenWorkspaceAgentTailnet(ctx)
_, err = agentClient.ListenWorkspaceAgent(ctx)
require.Error(t, err)
require.ErrorContains(t, err, "build is outdated")
})
@ -244,9 +243,8 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer agentCloser.Close()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
@ -311,9 +309,8 @@ func TestWorkspaceAgentPTY(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
_ = agentCloser.Close()
@ -409,9 +406,8 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
t.Cleanup(func() {
_ = agentCloser.Close()
@ -671,10 +667,10 @@ func TestWorkspaceAgentAppHealth(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
apiApps, err := agentClient.WorkspaceAgentApps(ctx)
metadata, err := agentClient.WorkspaceAgentMetadata(ctx)
require.NoError(t, err)
require.EqualValues(t, codersdk.WorkspaceAppHealthDisabled, apiApps[0].Health)
require.EqualValues(t, codersdk.WorkspaceAppHealthInitializing, apiApps[1].Health)
require.EqualValues(t, codersdk.WorkspaceAppHealthDisabled, metadata.Apps[0].Health)
require.EqualValues(t, codersdk.WorkspaceAppHealthInitializing, metadata.Apps[1].Health)
err = agentClient.PostWorkspaceAgentAppHealth(ctx, codersdk.PostWorkspaceAppHealthsRequest{})
require.Error(t, err)
// empty
@ -708,9 +704,9 @@ func TestWorkspaceAgentAppHealth(t *testing.T) {
},
})
require.NoError(t, err)
apiApps, err = agentClient.WorkspaceAgentApps(ctx)
metadata, err = agentClient.WorkspaceAgentMetadata(ctx)
require.NoError(t, err)
require.EqualValues(t, codersdk.WorkspaceAppHealthHealthy, apiApps[1].Health)
require.EqualValues(t, codersdk.WorkspaceAppHealthHealthy, metadata.Apps[1].Health)
// update to unhealthy
err = agentClient.PostWorkspaceAgentAppHealth(ctx, codersdk.PostWorkspaceAppHealthsRequest{
Healths: map[string]codersdk.WorkspaceAppHealth{
@ -718,7 +714,7 @@ func TestWorkspaceAgentAppHealth(t *testing.T) {
},
})
require.NoError(t, err)
apiApps, err = agentClient.WorkspaceAgentApps(ctx)
metadata, err = agentClient.WorkspaceAgentMetadata(ctx)
require.NoError(t, err)
require.EqualValues(t, codersdk.WorkspaceAppHealthUnhealthy, apiApps[1].Health)
require.EqualValues(t, codersdk.WorkspaceAppHealthUnhealthy, metadata.Apps[1].Health)
}

View File

@ -195,10 +195,8 @@ func createWorkspaceWithApps(t *testing.T, client *codersdk.Client, orgID uuid.U
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
StatsReporter: agentClient.AgentReportStats,
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
t.Cleanup(func() {
_ = agentCloser.Close()

View File

@ -3,12 +3,14 @@ package wsconncache_test
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"strings"
"sync"
"testing"
"time"
@ -148,17 +150,11 @@ func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeo
coordinator := tailnet.NewCoordinator()
agentID := uuid.New()
closer := agent.New(agent.Options{
FetchMetadata: func(ctx context.Context) (codersdk.WorkspaceAgentMetadata, error) {
return metadata, nil
},
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = serverConn.Close()
_ = clientConn.Close()
})
go coordinator.ServeAgent(serverConn, agentID)
return clientConn, nil
Client: &client{
t: t,
agentID: agentID,
metadata: metadata,
coordinator: coordinator,
},
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelInfo),
ReconnectingPTYTimeout: ptyTimeout,
@ -187,3 +183,41 @@ func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeo
Conn: conn,
}
}
type client struct {
t *testing.T
agentID uuid.UUID
metadata codersdk.WorkspaceAgentMetadata
coordinator tailnet.Coordinator
}
func (c *client) WorkspaceAgentMetadata(_ context.Context) (codersdk.WorkspaceAgentMetadata, error) {
return c.metadata, nil
}
func (c *client) ListenWorkspaceAgent(_ context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
closed := make(chan struct{})
c.t.Cleanup(func() {
_ = serverConn.Close()
_ = clientConn.Close()
<-closed
})
go func() {
_ = c.coordinator.ServeAgent(serverConn, c.agentID)
close(closed)
}()
return clientConn, nil
}
func (*client) AgentReportStats(_ context.Context, _ slog.Logger, _ func() *codersdk.AgentStats) (io.Closer, error) {
return io.NopCloser(strings.NewReader("")), nil
}
func (*client) PostWorkspaceAgentAppHealth(_ context.Context, _ codersdk.PostWorkspaceAppHealthsRequest) error {
return nil
}
func (*client) PostWorkspaceAgentVersion(_ context.Context, _ string) error {
return nil
}

View File

@ -105,6 +105,9 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
// readBodyAsError reads the response as an .Message, and
// wraps it in a codersdk.Error type for easy marshaling.
func readBodyAsError(res *http.Response) error {
if res == nil {
return xerrors.Errorf("no body returned")
}
defer res.Body.Close()
contentType := res.Header.Get("Content-Type")

View File

@ -118,6 +118,7 @@ type PostWorkspaceAgentVersionRequest struct {
// @typescript-ignore WorkspaceAgentMetadata
type WorkspaceAgentMetadata struct {
Apps []WorkspaceApp `json:"apps"`
DERPMap *tailcfg.DERPMap `json:"derpmap"`
EnvironmentVariables map[string]string `json:"environment_variables"`
StartupScript string `json:"startup_script"`
@ -301,7 +302,7 @@ func (c *Client) WorkspaceAgentMetadata(ctx context.Context) (WorkspaceAgentMeta
return agentMetadata, nil
}
func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) {
func (c *Client) ListenWorkspaceAgent(ctx context.Context) (net.Conn, error) {
coordinateURL, err := c.URL.Parse("/api/v2/workspaceagents/me/coordinate")
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
@ -460,20 +461,6 @@ func (c *Client) WorkspaceAgent(ctx context.Context, id uuid.UUID) (WorkspaceAge
return workspaceAgent, json.NewDecoder(res.Body).Decode(&workspaceAgent)
}
// MyWorkspaceAgent returns the requesting agent.
func (c *Client) WorkspaceAgentApps(ctx context.Context) ([]WorkspaceApp, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/apps", nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, readBodyAsError(res)
}
var workspaceApps []WorkspaceApp
return workspaceApps, json.NewDecoder(res.Body).Decode(&workspaceApps)
}
// PostWorkspaceAgentAppHealth updates the workspace agent app health status.
func (c *Client) PostWorkspaceAgentAppHealth(ctx context.Context, req PostWorkspaceAppHealthsRequest) error {
res, err := c.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/app-health", req)
@ -580,7 +567,8 @@ func (c *Client) AgentReportStats(
}})
httpClient := &http.Client{
Jar: jar,
Jar: jar,
Transport: c.HTTPClient.Transport,
}
doneCh := make(chan struct{})

View File

@ -119,9 +119,8 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr
}
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
t.Cleanup(func() {
_ = agentCloser.Close()