feat: change agent to use v2 API for reporting stats (#12024)

Modifies the agent to use the v2 API to report its statistics, using the `statsReporter` subcomponent.
This commit is contained in:
Spike Curtis 2024-02-07 15:26:41 +04:00 committed by GitHub
parent 70ad833b02
commit 1cf4b62867
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 137 additions and 303 deletions

View File

@ -89,7 +89,6 @@ type Options struct {
type Client interface { type Client interface {
ConnectRPC(ctx context.Context) (drpc.Conn, error) ConnectRPC(ctx context.Context) (drpc.Conn, error)
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
@ -158,7 +157,6 @@ func New(options Options) Agent {
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}}, lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
ignorePorts: options.IgnorePorts, ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration, portCacheDuration: options.PortCacheDuration,
connStatsChan: make(chan *agentsdk.Stats, 1),
reportMetadataInterval: options.ReportMetadataInterval, reportMetadataInterval: options.ReportMetadataInterval,
serviceBannerRefreshInterval: options.ServiceBannerRefreshInterval, serviceBannerRefreshInterval: options.ServiceBannerRefreshInterval,
sshMaxTimeout: options.SSHMaxTimeout, sshMaxTimeout: options.SSHMaxTimeout,
@ -216,8 +214,7 @@ type agent struct {
network *tailnet.Conn network *tailnet.Conn
addresses []netip.Prefix addresses []netip.Prefix
connStatsChan chan *agentsdk.Stats statsReporter *statsReporter
latestStat atomic.Pointer[agentsdk.Stats]
connCountReconnectingPTY atomic.Int64 connCountReconnectingPTY atomic.Int64
@ -822,14 +819,13 @@ func (a *agent) run(ctx context.Context) error {
closed := a.isClosed() closed := a.isClosed()
if !closed { if !closed {
a.network = network a.network = network
a.statsReporter = newStatsReporter(a.logger, network, a)
} }
a.closeMutex.Unlock() a.closeMutex.Unlock()
if closed { if closed {
_ = network.Close() _ = network.Close()
return xerrors.New("agent is closed") return xerrors.New("agent is closed")
} }
a.startReportingConnectionStats(ctx)
} else { } else {
// Update the wireguard IPs if the agent ID changed. // Update the wireguard IPs if the agent ID changed.
err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID))
@ -871,6 +867,15 @@ func (a *agent) run(ctx context.Context) error {
return nil return nil
}) })
eg.Go(func() error {
a.logger.Debug(egCtx, "running stats report loop")
err := a.statsReporter.reportLoop(egCtx, aAPI)
if err != nil {
return xerrors.Errorf("report stats loop: %w", err)
}
return nil
})
return eg.Wait() return eg.Wait()
} }
@ -1218,115 +1223,83 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger) return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
} }
// startReportingConnectionStats runs the connection stats reporting goroutine. // Collect collects additional stats from the agent
func (a *agent) startReportingConnectionStats(ctx context.Context) { func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
reportStats := func(networkStats map[netlogtype.Connection]netlogtype.Counts) { a.logger.Debug(context.Background(), "computing stats report")
a.logger.Debug(ctx, "computing stats report") stats := &proto.Stats{
stats := &agentsdk.Stats{ ConnectionCount: int64(len(networkStats)),
ConnectionCount: int64(len(networkStats)), ConnectionsByProto: map[string]int64{},
ConnectionsByProto: map[string]int64{}, }
} for conn, counts := range networkStats {
for conn, counts := range networkStats { stats.ConnectionsByProto[conn.Proto.String()]++
stats.ConnectionsByProto[conn.Proto.String()]++ stats.RxBytes += int64(counts.RxBytes)
stats.RxBytes += int64(counts.RxBytes) stats.RxPackets += int64(counts.RxPackets)
stats.RxPackets += int64(counts.RxPackets) stats.TxBytes += int64(counts.TxBytes)
stats.TxBytes += int64(counts.TxBytes) stats.TxPackets += int64(counts.TxPackets)
stats.TxPackets += int64(counts.TxPackets)
}
// The count of active sessions.
sshStats := a.sshServer.ConnStats()
stats.SessionCountSSH = sshStats.Sessions
stats.SessionCountVSCode = sshStats.VSCode
stats.SessionCountJetBrains = sshStats.JetBrains
stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load()
// Compute the median connection latency!
a.logger.Debug(ctx, "starting peer latency measurement for stats")
var wg sync.WaitGroup
var mu sync.Mutex
status := a.network.Status()
durations := []float64{}
pingCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for nodeID, peer := range status.Peer {
if !peer.Active {
continue
}
addresses, found := a.network.NodeAddresses(nodeID)
if !found {
continue
}
if len(addresses) == 0 {
continue
}
wg.Add(1)
go func() {
defer wg.Done()
duration, _, _, err := a.network.Ping(pingCtx, addresses[0].Addr())
if err != nil {
return
}
mu.Lock()
durations = append(durations, float64(duration.Microseconds()))
mu.Unlock()
}()
}
wg.Wait()
sort.Float64s(durations)
durationsLength := len(durations)
if durationsLength == 0 {
stats.ConnectionMedianLatencyMS = -1
} else if durationsLength%2 == 0 {
stats.ConnectionMedianLatencyMS = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2
} else {
stats.ConnectionMedianLatencyMS = durations[durationsLength/2]
}
// Convert from microseconds to milliseconds.
stats.ConnectionMedianLatencyMS /= 1000
// Collect agent metrics.
// Agent metrics are changing all the time, so there is no need to perform
// reflect.DeepEqual to see if stats should be transferred.
metricsCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
a.logger.Debug(ctx, "collecting agent metrics for stats")
stats.Metrics = a.collectMetrics(metricsCtx)
a.latestStat.Store(stats)
a.logger.Debug(ctx, "about to send stats")
select {
case a.connStatsChan <- stats:
a.logger.Debug(ctx, "successfully sent stats")
case <-a.closed:
a.logger.Debug(ctx, "didn't send stats because we are closed")
}
} }
// Report statistics from the created network. // The count of active sessions.
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, func(d time.Duration) { sshStats := a.sshServer.ConnStats()
a.network.SetConnStatsCallback(d, 2048, stats.SessionCountSsh = sshStats.Sessions
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) { stats.SessionCountVscode = sshStats.VSCode
reportStats(virtual) stats.SessionCountJetbrains = sshStats.JetBrains
},
) stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
})
if err != nil { // Compute the median connection latency!
a.logger.Error(ctx, "agent failed to report stats", slog.Error(err)) a.logger.Debug(ctx, "starting peer latency measurement for stats")
var wg sync.WaitGroup
var mu sync.Mutex
status := a.network.Status()
durations := []float64{}
pingCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for nodeID, peer := range status.Peer {
if !peer.Active {
continue
}
addresses, found := a.network.NodeAddresses(nodeID)
if !found {
continue
}
if len(addresses) == 0 {
continue
}
wg.Add(1)
go func() {
defer wg.Done()
duration, _, _, err := a.network.Ping(pingCtx, addresses[0].Addr())
if err != nil {
return
}
mu.Lock()
defer mu.Unlock()
durations = append(durations, float64(duration.Microseconds()))
}()
}
wg.Wait()
sort.Float64s(durations)
durationsLength := len(durations)
if durationsLength == 0 {
stats.ConnectionMedianLatencyMs = -1
} else if durationsLength%2 == 0 {
stats.ConnectionMedianLatencyMs = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2
} else { } else {
if err = a.trackConnGoroutine(func() { stats.ConnectionMedianLatencyMs = durations[durationsLength/2]
// This is OK because the agent never re-creates the tailnet
// and the only shutdown indicator is agent.Close().
<-a.closed
_ = cl.Close()
}); err != nil {
a.logger.Debug(ctx, "report stats goroutine", slog.Error(err))
_ = cl.Close()
}
} }
// Convert from microseconds to milliseconds.
stats.ConnectionMedianLatencyMs /= 1000
// Collect agent metrics.
// Agent metrics are changing all the time, so there is no need to perform
// reflect.DeepEqual to see if stats should be transferred.
metricsCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
a.logger.Debug(ctx, "collecting agent metrics for stats")
stats.Metrics = a.collectMetrics(metricsCtx)
return stats
} }
var prioritizedProcs = []string{"coder agent"} var prioritizedProcs = []string{"coder agent"}

View File

@ -52,6 +52,7 @@ import (
"github.com/coder/coder/v2/agent/agentproc/agentproctest" "github.com/coder/coder/v2/agent/agentproc/agentproctest"
"github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/pty/ptytest"
@ -85,11 +86,11 @@ func TestAgent_Stats_SSH(t *testing.T) {
err = session.Shell() err = session.Shell()
require.NoError(t, err) require.NoError(t, err)
var s *agentsdk.Stats var s *proto.Stats
require.Eventuallyf(t, func() bool { require.Eventuallyf(t, func() bool {
var ok bool var ok bool
s, ok = <-stats s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSSH == 1 return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
}, testutil.WaitLong, testutil.IntervalFast, }, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s, "never saw stats: %+v", s,
) )
@ -118,11 +119,11 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
_, err = ptyConn.Write(data) _, err = ptyConn.Write(data)
require.NoError(t, err) require.NoError(t, err)
var s *agentsdk.Stats var s *proto.Stats
require.Eventuallyf(t, func() bool { require.Eventuallyf(t, func() bool {
var ok bool var ok bool
s, ok = <-stats s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPTY == 1 return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPty == 1
}, testutil.WaitLong, testutil.IntervalFast, }, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s, "never saw stats: %+v", s,
) )
@ -177,14 +178,14 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool { require.Eventuallyf(t, func() bool {
s, ok := <-stats s, ok := <-stats
t.Logf("got stats: ok=%t, ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountVSCode=%d, ConnectionMedianLatencyMS=%f", t.Logf("got stats: ok=%t, ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountVSCode=%d, ConnectionMedianLatencyMS=%f",
ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVSCode, s.ConnectionMedianLatencyMS) ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVscode, s.ConnectionMedianLatencyMs)
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 &&
// Ensure that the connection didn't count as a "normal" SSH session. // Ensure that the connection didn't count as a "normal" SSH session.
// This was a special one, so it should be labeled specially in the stats! // This was a special one, so it should be labeled specially in the stats!
s.SessionCountVSCode == 1 && s.SessionCountVscode == 1 &&
// Ensure that connection latency is being counted! // Ensure that connection latency is being counted!
// If it isn't, it's set to -1. // If it isn't, it's set to -1.
s.ConnectionMedianLatencyMS >= 0 s.ConnectionMedianLatencyMs >= 0
}, testutil.WaitLong, testutil.IntervalFast, }, testutil.WaitLong, testutil.IntervalFast,
"never saw stats", "never saw stats",
) )
@ -243,9 +244,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool { require.Eventuallyf(t, func() bool {
s, ok := <-stats s, ok := <-stats
t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d", t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
ok, s.ConnectionCount, s.SessionCountJetBrains) ok, s.ConnectionCount, s.SessionCountJetbrains)
return ok && s.ConnectionCount > 0 && return ok && s.ConnectionCount > 0 &&
s.SessionCountJetBrains == 1 s.SessionCountJetbrains == 1
}, testutil.WaitLong, testutil.IntervalFast, }, testutil.WaitLong, testutil.IntervalFast,
"never saw stats with conn open", "never saw stats with conn open",
) )
@ -258,9 +259,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool { require.Eventuallyf(t, func() bool {
s, ok := <-stats s, ok := <-stats
t.Logf("got stats after disconnect %t, %d", t.Logf("got stats after disconnect %t, %d",
ok, s.SessionCountJetBrains) ok, s.SessionCountJetbrains)
return ok && return ok &&
s.SessionCountJetBrains == 0 s.SessionCountJetbrains == 0
}, testutil.WaitLong, testutil.IntervalFast, }, testutil.WaitLong, testutil.IntervalFast,
"never saw stats after conn closes", "never saw stats after conn closes",
) )
@ -1346,7 +1347,7 @@ func TestAgent_Lifecycle(t *testing.T) {
RunOnStop: true, RunOnStop: true,
}}, }},
}, },
make(chan *agentsdk.Stats, 50), make(chan *proto.Stats, 50),
tailnet.NewCoordinator(logger), tailnet.NewCoordinator(logger),
) )
defer client.Close() defer client.Close()
@ -1667,7 +1668,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
_ = coordinator.Close() _ = coordinator.Close()
}) })
agentID := uuid.New() agentID := uuid.New()
statsCh := make(chan *agentsdk.Stats, 50) statsCh := make(chan *proto.Stats, 50)
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
client := agenttest.NewClient(t, client := agenttest.NewClient(t,
logger.Named("agent"), logger.Named("agent"),
@ -1816,7 +1817,7 @@ func TestAgent_Reconnect(t *testing.T) {
defer coordinator.Close() defer coordinator.Close()
agentID := uuid.New() agentID := uuid.New()
statsCh := make(chan *agentsdk.Stats, 50) statsCh := make(chan *proto.Stats, 50)
derpMap, _ := tailnettest.RunDERPAndSTUN(t) derpMap, _ := tailnettest.RunDERPAndSTUN(t)
client := agenttest.NewClient(t, client := agenttest.NewClient(t,
logger, logger,
@ -1861,7 +1862,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
GitAuthConfigs: 1, GitAuthConfigs: 1,
DERPMap: &tailcfg.DERPMap{}, DERPMap: &tailcfg.DERPMap{},
}, },
make(chan *agentsdk.Stats, 50), make(chan *proto.Stats, 50),
coordinator, coordinator,
) )
defer client.Close() defer client.Close()
@ -2018,7 +2019,7 @@ func setupSSHSession(
func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
*codersdk.WorkspaceAgentConn, *codersdk.WorkspaceAgentConn,
*agenttest.Client, *agenttest.Client,
<-chan *agentsdk.Stats, <-chan *proto.Stats,
afero.Fs, afero.Fs,
agent.Agent, agent.Agent,
) { ) {
@ -2046,7 +2047,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
t.Cleanup(func() { t.Cleanup(func() {
_ = coordinator.Close() _ = coordinator.Close()
}) })
statsCh := make(chan *agentsdk.Stats, 50) statsCh := make(chan *proto.Stats, 50)
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator) c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator)
t.Cleanup(c.Close) t.Cleanup(c.Close)

View File

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"storj.io/drpc" "storj.io/drpc"
"storj.io/drpc/drpcmux" "storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver" "storj.io/drpc/drpcserver"
@ -27,11 +28,13 @@ import (
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
const statsInterval = 500 * time.Millisecond
func NewClient(t testing.TB, func NewClient(t testing.TB,
logger slog.Logger, logger slog.Logger,
agentID uuid.UUID, agentID uuid.UUID,
manifest agentsdk.Manifest, manifest agentsdk.Manifest,
statsChan chan *agentsdk.Stats, statsChan chan *agentproto.Stats,
coordinator tailnet.Coordinator, coordinator tailnet.Coordinator,
) *Client { ) *Client {
if manifest.AgentID == uuid.Nil { if manifest.AgentID == uuid.Nil {
@ -51,7 +54,7 @@ func NewClient(t testing.TB,
require.NoError(t, err) require.NoError(t, err)
mp, err := agentsdk.ProtoFromManifest(manifest) mp, err := agentsdk.ProtoFromManifest(manifest)
require.NoError(t, err) require.NoError(t, err)
fakeAAPI := NewFakeAgentAPI(t, logger, mp) fakeAAPI := NewFakeAgentAPI(t, logger, mp, statsChan)
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI) err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
require.NoError(t, err) require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{ server := drpcserver.NewWithOptions(mux, drpcserver.Options{
@ -66,7 +69,6 @@ func NewClient(t testing.TB,
t: t, t: t,
logger: logger.Named("client"), logger: logger.Named("client"),
agentID: agentID, agentID: agentID,
statsChan: statsChan,
coordinator: coordinator, coordinator: coordinator,
server: server, server: server,
fakeAgentAPI: fakeAAPI, fakeAgentAPI: fakeAAPI,
@ -79,7 +81,6 @@ type Client struct {
logger slog.Logger logger slog.Logger
agentID uuid.UUID agentID uuid.UUID
metadata map[string]agentsdk.Metadata metadata map[string]agentsdk.Metadata
statsChan chan *agentsdk.Stats
coordinator tailnet.Coordinator coordinator tailnet.Coordinator
server *drpcserver.Server server *drpcserver.Server
fakeAgentAPI *FakeAgentAPI fakeAgentAPI *FakeAgentAPI
@ -121,38 +122,6 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
return conn, nil return conn, nil
} }
func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
go func() {
defer close(doneCh)
setInterval(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case stat := <-statsChan:
select {
case c.statsChan <- stat:
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}
}()
return closeFunc(func() error {
cancel()
<-doneCh
close(c.statsChan)
return nil
}), nil
}
func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle { func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -223,12 +192,6 @@ func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
return nil return nil
} }
type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
type FakeAgentAPI struct { type FakeAgentAPI struct {
sync.Mutex sync.Mutex
t testing.TB t testing.TB
@ -236,6 +199,7 @@ type FakeAgentAPI struct {
manifest *agentproto.Manifest manifest *agentproto.Manifest
startupCh chan *agentproto.Startup startupCh chan *agentproto.Startup
statsCh chan *agentproto.Stats
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
} }
@ -264,9 +228,13 @@ func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceB
return agentsdk.ProtoFromServiceBanner(sb), nil return agentsdk.ProtoFromServiceBanner(sb), nil
} }
func (*FakeAgentAPI) UpdateStats(context.Context, *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) { func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
// TODO implement me f.logger.Debug(ctx, "update stats called", slog.F("req", req))
panic("implement me") // empty request is sent to get the interval; but our tests don't want empty stats requests
if req.Stats != nil {
f.statsCh <- req.Stats
}
return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil
} }
func (*FakeAgentAPI) UpdateLifecycle(context.Context, *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) { func (*FakeAgentAPI) UpdateLifecycle(context.Context, *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
@ -294,11 +262,12 @@ func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLog
panic("implement me") panic("implement me")
} }
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI { func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI {
return &FakeAgentAPI{ return &FakeAgentAPI{
t: t, t: t,
logger: logger.Named("FakeAgentAPI"), logger: logger.Named("FakeAgentAPI"),
manifest: manifest, manifest: manifest,
statsCh: statsCh,
startupCh: make(chan *agentproto.Startup, 100), startupCh: make(chan *agentproto.Startup, 100),
} }
} }

View File

@ -10,8 +10,7 @@ import (
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk/agentsdk"
) )
type agentMetrics struct { type agentMetrics struct {
@ -53,8 +52,8 @@ func newAgentMetrics(registerer prometheus.Registerer) *agentMetrics {
} }
} }
func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric { func (a *agent) collectMetrics(ctx context.Context) []*proto.Stats_Metric {
var collected []agentsdk.AgentMetric var collected []*proto.Stats_Metric
// Tailscale internal metrics // Tailscale internal metrics
metrics := clientmetric.Metrics() metrics := clientmetric.Metrics()
@ -63,7 +62,7 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric {
continue continue
} }
collected = append(collected, agentsdk.AgentMetric{ collected = append(collected, &proto.Stats_Metric{
Name: m.Name(), Name: m.Name(),
Type: asMetricType(m.Type()), Type: asMetricType(m.Type()),
Value: float64(m.Value()), Value: float64(m.Value()),
@ -81,16 +80,16 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric {
labels := toAgentMetricLabels(metric.Label) labels := toAgentMetricLabels(metric.Label)
if metric.Counter != nil { if metric.Counter != nil {
collected = append(collected, agentsdk.AgentMetric{ collected = append(collected, &proto.Stats_Metric{
Name: metricFamily.GetName(), Name: metricFamily.GetName(),
Type: agentsdk.AgentMetricTypeCounter, Type: proto.Stats_Metric_COUNTER,
Value: metric.Counter.GetValue(), Value: metric.Counter.GetValue(),
Labels: labels, Labels: labels,
}) })
} else if metric.Gauge != nil { } else if metric.Gauge != nil {
collected = append(collected, agentsdk.AgentMetric{ collected = append(collected, &proto.Stats_Metric{
Name: metricFamily.GetName(), Name: metricFamily.GetName(),
Type: agentsdk.AgentMetricTypeGauge, Type: proto.Stats_Metric_GAUGE,
Value: metric.Gauge.GetValue(), Value: metric.Gauge.GetValue(),
Labels: labels, Labels: labels,
}) })
@ -102,14 +101,14 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric {
return collected return collected
} }
func toAgentMetricLabels(metricLabels []*prompb.LabelPair) []agentsdk.AgentMetricLabel { func toAgentMetricLabels(metricLabels []*prompb.LabelPair) []*proto.Stats_Metric_Label {
if len(metricLabels) == 0 { if len(metricLabels) == 0 {
return nil return nil
} }
labels := make([]agentsdk.AgentMetricLabel, 0, len(metricLabels)) labels := make([]*proto.Stats_Metric_Label, 0, len(metricLabels))
for _, metricLabel := range metricLabels { for _, metricLabel := range metricLabels {
labels = append(labels, agentsdk.AgentMetricLabel{ labels = append(labels, &proto.Stats_Metric_Label{
Name: metricLabel.GetName(), Name: metricLabel.GetName(),
Value: metricLabel.GetValue(), Value: metricLabel.GetValue(),
}) })
@ -130,12 +129,12 @@ func isIgnoredMetric(metricName string) bool {
return false return false
} }
func asMetricType(typ clientmetric.Type) agentsdk.AgentMetricType { func asMetricType(typ clientmetric.Type) proto.Stats_Metric_Type {
switch typ { switch typ {
case clientmetric.TypeGauge: case clientmetric.TypeGauge:
return agentsdk.AgentMetricTypeGauge return proto.Stats_Metric_GAUGE
case clientmetric.TypeCounter: case clientmetric.TypeCounter:
return agentsdk.AgentMetricTypeCounter return proto.Stats_Metric_COUNTER
default: default:
panic(fmt.Sprintf("unknown metric type: %d", typ)) panic(fmt.Sprintf("unknown metric type: %d", typ))
} }

View File

@ -24,6 +24,7 @@ import (
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
@ -327,7 +328,7 @@ func setupServerTailnetAgent(t *testing.T, agentNum int) ([]agentWithID, *coderd
DERPMap: derpMap, DERPMap: derpMap,
} }
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord) c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *proto.Stats, 50), coord)
t.Cleanup(c.Close) t.Cleanup(c.Close)
options := agent.Options{ options := agent.Options{

View File

@ -890,6 +890,7 @@ func TestWorkspaceAgentAppHealth(t *testing.T) {
require.EqualValues(t, codersdk.WorkspaceAppHealthUnhealthy, manifest.Apps[1].Health) require.EqualValues(t, codersdk.WorkspaceAppHealthUnhealthy, manifest.Apps[1].Health)
} }
// TestWorkspaceAgentReportStats tests the legacy (agent API v1) report stats endpoint.
func TestWorkspaceAgentReportStats(t *testing.T) { func TestWorkspaceAgentReportStats(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -24,7 +24,6 @@ import (
"github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc" drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/retry"
) )
// ExternalLogSourceID is the statically-defined ID of a log-source that // ExternalLogSourceID is the statically-defined ID of a log-source that
@ -390,61 +389,6 @@ func (c *Client) AuthAzureInstanceIdentity(ctx context.Context) (AuthenticateRes
return resp, json.NewDecoder(res.Body).Decode(&resp) return resp, json.NewDecoder(res.Body).Decode(&resp)
} }
// ReportStats begins a stat streaming connection with the Coder server.
// It is resilient to network failures and intermittent coderd issues.
func (c *Client) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *Stats, setInterval func(time.Duration)) (io.Closer, error) {
var interval time.Duration
ctx, cancel := context.WithCancel(ctx)
exited := make(chan struct{})
postStat := func(stat *Stats) {
var nextInterval time.Duration
for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); {
resp, err := c.PostStats(ctx, stat)
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "report stats", slog.Error(err))
}
continue
}
nextInterval = resp.ReportInterval
break
}
if nextInterval != 0 && interval != nextInterval {
setInterval(nextInterval)
}
interval = nextInterval
}
// Send an empty stat to get the interval.
postStat(&Stats{})
go func() {
defer close(exited)
for {
select {
case <-ctx.Done():
return
case stat, ok := <-statsChan:
if !ok {
return
}
postStat(stat)
}
}
}()
return closeFunc(func() error {
cancel()
<-exited
return nil
}), nil
}
// Stats records the Agent's network connection statistics for use in // Stats records the Agent's network connection statistics for use in
// user-facing metrics and debugging. // user-facing metrics and debugging.
type Stats struct { type Stats struct {
@ -509,6 +453,9 @@ type StatsResponse struct {
ReportInterval time.Duration `json:"report_interval"` ReportInterval time.Duration `json:"report_interval"`
} }
// PostStats sends agent stats to the coder server
//
// Deprecated: uses agent API v1 endpoint
func (c *Client) PostStats(ctx context.Context, stats *Stats) (StatsResponse, error) { func (c *Client) PostStats(ctx context.Context, stats *Stats) (StatsResponse, error) {
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/report-stats", stats) res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/report-stats", stats)
if err != nil { if err != nil {
@ -649,12 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
return authResp, json.NewDecoder(res.Body).Decode(&authResp) return authResp, json.NewDecoder(res.Body).Decode(&authResp)
} }
type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func // wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered. // is called if a read or write error is encountered.
type wsNetConn struct { type wsNetConn struct {

View File

@ -1,22 +1,13 @@
package codersdk_test package codersdk_test
import ( import (
"context"
"net/http"
"net/http/httptest"
"net/url" "net/url"
"sync/atomic"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
) )
func TestWorkspaceRewriteDERPMap(t *testing.T) { func TestWorkspaceRewriteDERPMap(t *testing.T) {
@ -46,45 +37,3 @@ func TestWorkspaceRewriteDERPMap(t *testing.T) {
require.Equal(t, "coconuts.org", node.HostName) require.Equal(t, "coconuts.org", node.HostName)
require.Equal(t, 44558, node.DERPPort) require.Equal(t, 44558, node.DERPPort)
} }
func TestAgentReportStats(t *testing.T) {
t.Parallel()
var (
numReports atomic.Int64
numIntervalCalls atomic.Int64
wantInterval = 5 * time.Millisecond
)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numReports.Add(1)
httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.StatsResponse{
ReportInterval: wantInterval,
})
}))
parsed, err := url.Parse(srv.URL)
require.NoError(t, err)
client := agentsdk.New(parsed)
assertStatInterval := func(interval time.Duration) {
numIntervalCalls.Add(1)
assert.Equal(t, wantInterval, interval)
}
chanLen := 3
statCh := make(chan *agentsdk.Stats, chanLen)
for i := 0; i < chanLen; i++ {
statCh <- &agentsdk.Stats{ConnectionsByProto: map[string]int64{}}
}
ctx := context.Background()
closeStream, err := client.ReportStats(ctx, slogtest.Make(t, nil), statCh, assertStatInterval)
require.NoError(t, err)
defer closeStream.Close()
require.Eventually(t,
func() bool { return numReports.Load() >= 3 },
testutil.WaitMedium, testutil.IntervalFast,
)
closeStream.Close()
require.Equal(t, int64(1), numIntervalCalls.Load())
}