feat: change agent to use v2 API for reporting stats (#12024)

Modifies the agent to report its statistics over the v2 API, via the new `statsReporter` subcomponent.
Spike Curtis 2024-02-07 15:26:41 +04:00 committed by GitHub
parent 70ad833b02
commit 1cf4b62867
8 changed files with 137 additions and 303 deletions
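
The `statsReporter` subcomponent itself is added in a file outside this diff. Based only on the calls visible below (`newStatsReporter(a.logger, network, a)`, `reportLoop(egCtx, aAPI)`, and the new `Collect` method), a minimal sketch of the contract it likely implements could look like the following — the interface names and loop structure are assumptions, not the committed code:

// Hypothetical sketch; the real statsReporter may differ.
package agent

import (
	"context"
	"time"

	"golang.org/x/xerrors"
	"tailscale.com/types/netlogtype"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/agent/proto"
)

// statsCollector is satisfied by *agent via the Collect method added below.
type statsCollector interface {
	Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats
}

// networkStatsSource abstracts the tailnet connection's stats callback,
// matching the SetConnStatsCallback usage in the removed code (assumed).
type networkStatsSource interface {
	SetConnStatsCallback(maxPeriod time.Duration, maxLogLen int,
		dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts))
}

type statsReporter struct {
	logger    slog.Logger
	network   networkStatsSource
	collector statsCollector
	statsCh   chan map[netlogtype.Connection]netlogtype.Counts
}

func newStatsReporter(logger slog.Logger, network networkStatsSource, collector statsCollector) *statsReporter {
	return &statsReporter{
		logger:    logger,
		network:   network,
		collector: collector,
		statsCh:   make(chan map[netlogtype.Connection]netlogtype.Counts, 1),
	}
}

// reportLoop fetches the report interval with an empty update (a behavior
// mirrored by FakeAgentAPI.UpdateStats below), registers for network stat
// dumps, then collects and sends a report for each dump until canceled.
func (s *statsReporter) reportLoop(ctx context.Context, aAPI proto.DRPCAgentClient) error {
	resp, err := aAPI.UpdateStats(ctx, &proto.UpdateStatsRequest{})
	if err != nil {
		return xerrors.Errorf("initial stats update: %w", err)
	}
	s.network.SetConnStatsCallback(resp.ReportInterval.AsDuration(), 2048,
		func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
			select {
			case s.statsCh <- virtual:
			default: // drop rather than block the tailnet callback
			}
		},
	)
	for {
		select {
		case <-ctx.Done():
			return nil
		case networkStats := <-s.statsCh:
			stats := s.collector.Collect(ctx, networkStats)
			if _, err := aAPI.UpdateStats(ctx, &proto.UpdateStatsRequest{Stats: stats}); err != nil {
				return xerrors.Errorf("update stats: %w", err)
			}
		}
	}
}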

View File

@ -89,7 +89,6 @@ type Options struct {
type Client interface {
ConnectRPC(ctx context.Context) (drpc.Conn, error)
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
@ -158,7 +157,6 @@ func New(options Options) Agent {
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration,
connStatsChan: make(chan *agentsdk.Stats, 1),
reportMetadataInterval: options.ReportMetadataInterval,
serviceBannerRefreshInterval: options.ServiceBannerRefreshInterval,
sshMaxTimeout: options.SSHMaxTimeout,
@ -216,8 +214,7 @@ type agent struct {
network *tailnet.Conn
addresses []netip.Prefix
connStatsChan chan *agentsdk.Stats
latestStat atomic.Pointer[agentsdk.Stats]
statsReporter *statsReporter
connCountReconnectingPTY atomic.Int64
@ -822,14 +819,13 @@ func (a *agent) run(ctx context.Context) error {
closed := a.isClosed()
if !closed {
a.network = network
a.statsReporter = newStatsReporter(a.logger, network, a)
}
a.closeMutex.Unlock()
if closed {
_ = network.Close()
return xerrors.New("agent is closed")
}
a.startReportingConnectionStats(ctx)
} else {
// Update the wireguard IPs if the agent ID changed.
err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID))
@ -871,6 +867,15 @@ func (a *agent) run(ctx context.Context) error {
return nil
})
eg.Go(func() error {
a.logger.Debug(egCtx, "running stats report loop")
err := a.statsReporter.reportLoop(egCtx, aAPI)
if err != nil {
return xerrors.Errorf("report stats loop: %w", err)
}
return nil
})
return eg.Wait()
}
@ -1218,115 +1223,83 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
}
// startReportingConnectionStats runs the connection stats reporting goroutine.
func (a *agent) startReportingConnectionStats(ctx context.Context) {
reportStats := func(networkStats map[netlogtype.Connection]netlogtype.Counts) {
a.logger.Debug(ctx, "computing stats report")
stats := &agentsdk.Stats{
ConnectionCount: int64(len(networkStats)),
ConnectionsByProto: map[string]int64{},
}
for conn, counts := range networkStats {
stats.ConnectionsByProto[conn.Proto.String()]++
stats.RxBytes += int64(counts.RxBytes)
stats.RxPackets += int64(counts.RxPackets)
stats.TxBytes += int64(counts.TxBytes)
stats.TxPackets += int64(counts.TxPackets)
}
// The count of active sessions.
sshStats := a.sshServer.ConnStats()
stats.SessionCountSSH = sshStats.Sessions
stats.SessionCountVSCode = sshStats.VSCode
stats.SessionCountJetBrains = sshStats.JetBrains
stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load()
// Compute the median connection latency!
a.logger.Debug(ctx, "starting peer latency measurement for stats")
var wg sync.WaitGroup
var mu sync.Mutex
status := a.network.Status()
durations := []float64{}
pingCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for nodeID, peer := range status.Peer {
if !peer.Active {
continue
}
addresses, found := a.network.NodeAddresses(nodeID)
if !found {
continue
}
if len(addresses) == 0 {
continue
}
wg.Add(1)
go func() {
defer wg.Done()
duration, _, _, err := a.network.Ping(pingCtx, addresses[0].Addr())
if err != nil {
return
}
mu.Lock()
durations = append(durations, float64(duration.Microseconds()))
mu.Unlock()
}()
}
wg.Wait()
sort.Float64s(durations)
durationsLength := len(durations)
if durationsLength == 0 {
stats.ConnectionMedianLatencyMS = -1
} else if durationsLength%2 == 0 {
stats.ConnectionMedianLatencyMS = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2
} else {
stats.ConnectionMedianLatencyMS = durations[durationsLength/2]
}
// Convert from microseconds to milliseconds.
stats.ConnectionMedianLatencyMS /= 1000
// Collect agent metrics.
// Agent metrics are changing all the time, so there is no need to perform
// reflect.DeepEqual to see if stats should be transferred.
metricsCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
a.logger.Debug(ctx, "collecting agent metrics for stats")
stats.Metrics = a.collectMetrics(metricsCtx)
a.latestStat.Store(stats)
a.logger.Debug(ctx, "about to send stats")
select {
case a.connStatsChan <- stats:
a.logger.Debug(ctx, "successfully sent stats")
case <-a.closed:
a.logger.Debug(ctx, "didn't send stats because we are closed")
}
// Collect collects additional stats from the agent
func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
a.logger.Debug(context.Background(), "computing stats report")
stats := &proto.Stats{
ConnectionCount: int64(len(networkStats)),
ConnectionsByProto: map[string]int64{},
}
for conn, counts := range networkStats {
stats.ConnectionsByProto[conn.Proto.String()]++
stats.RxBytes += int64(counts.RxBytes)
stats.RxPackets += int64(counts.RxPackets)
stats.TxBytes += int64(counts.TxBytes)
stats.TxPackets += int64(counts.TxPackets)
}
// Report statistics from the created network.
cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, func(d time.Duration) {
a.network.SetConnStatsCallback(d, 2048,
func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
reportStats(virtual)
},
)
})
if err != nil {
a.logger.Error(ctx, "agent failed to report stats", slog.Error(err))
// The count of active sessions.
sshStats := a.sshServer.ConnStats()
stats.SessionCountSsh = sshStats.Sessions
stats.SessionCountVscode = sshStats.VSCode
stats.SessionCountJetbrains = sshStats.JetBrains
stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
// Compute the median connection latency!
a.logger.Debug(ctx, "starting peer latency measurement for stats")
var wg sync.WaitGroup
var mu sync.Mutex
status := a.network.Status()
durations := []float64{}
pingCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for nodeID, peer := range status.Peer {
if !peer.Active {
continue
}
addresses, found := a.network.NodeAddresses(nodeID)
if !found {
continue
}
if len(addresses) == 0 {
continue
}
wg.Add(1)
go func() {
defer wg.Done()
duration, _, _, err := a.network.Ping(pingCtx, addresses[0].Addr())
if err != nil {
return
}
mu.Lock()
defer mu.Unlock()
durations = append(durations, float64(duration.Microseconds()))
}()
}
wg.Wait()
sort.Float64s(durations)
durationsLength := len(durations)
if durationsLength == 0 {
stats.ConnectionMedianLatencyMs = -1
} else if durationsLength%2 == 0 {
stats.ConnectionMedianLatencyMs = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2
} else {
if err = a.trackConnGoroutine(func() {
// This is OK because the agent never re-creates the tailnet
// and the only shutdown indicator is agent.Close().
<-a.closed
_ = cl.Close()
}); err != nil {
a.logger.Debug(ctx, "report stats goroutine", slog.Error(err))
_ = cl.Close()
}
stats.ConnectionMedianLatencyMs = durations[durationsLength/2]
}
// Convert from microseconds to milliseconds.
stats.ConnectionMedianLatencyMs /= 1000
// Collect agent metrics.
// Agent metrics are changing all the time, so there is no need to perform
// reflect.DeepEqual to see if stats should be transferred.
metricsCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
a.logger.Debug(ctx, "collecting agent metrics for stats")
stats.Metrics = a.collectMetrics(metricsCtx)
return stats
}
var prioritizedProcs = []string{"coder agent"}
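
The median logic in Collect sorts the per-peer ping durations, averages the two middle samples when the count is even, and finally converts microseconds to milliseconds: samples of 800, 1200, 2000, and 10000 µs yield (1200+2000)/2 = 1600 µs, reported as 1.6 ms. A simplified standalone restatement, using a hypothetical helper name that is not in the commit:

package main

import (
	"fmt"
	"sort"
)

// medianLatencyMS mirrors the even/odd median and unit conversion in
// Collect; -1 signals that no peers responded to a ping.
func medianLatencyMS(durationsUS []float64) float64 {
	sort.Float64s(durationsUS)
	n := len(durationsUS)
	if n == 0 {
		return -1
	}
	var medianUS float64
	if n%2 == 0 {
		medianUS = (durationsUS[n/2-1] + durationsUS[n/2]) / 2
	} else {
		medianUS = durationsUS[n/2]
	}
	return medianUS / 1000 // microseconds -> milliseconds
}

func main() {
	fmt.Println(medianLatencyMS([]float64{800, 1200, 2000, 10000})) // 1.6
	fmt.Println(medianLatencyMS(nil))                               // -1
}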

View File

@ -52,6 +52,7 @@ import (
"github.com/coder/coder/v2/agent/agentproc/agentproctest"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty/ptytest"
@ -85,11 +86,11 @@ func TestAgent_Stats_SSH(t *testing.T) {
err = session.Shell()
require.NoError(t, err)
var s *agentsdk.Stats
var s *proto.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSSH == 1
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
@ -118,11 +119,11 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
_, err = ptyConn.Write(data)
require.NoError(t, err)
var s *agentsdk.Stats
var s *proto.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPTY == 1
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPty == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
@ -177,14 +178,14 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool {
s, ok := <-stats
t.Logf("got stats: ok=%t, ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountVSCode=%d, ConnectionMedianLatencyMS=%f",
ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVSCode, s.ConnectionMedianLatencyMS)
ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVscode, s.ConnectionMedianLatencyMs)
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 &&
// Ensure that the connection didn't count as a "normal" SSH session.
// This was a special one, so it should be labeled specially in the stats!
s.SessionCountVSCode == 1 &&
s.SessionCountVscode == 1 &&
// Ensure that connection latency is being counted!
// If it isn't, it's set to -1.
s.ConnectionMedianLatencyMS >= 0
s.ConnectionMedianLatencyMs >= 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats",
)
@ -243,9 +244,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool {
s, ok := <-stats
t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
ok, s.ConnectionCount, s.SessionCountJetBrains)
ok, s.ConnectionCount, s.SessionCountJetbrains)
return ok && s.ConnectionCount > 0 &&
s.SessionCountJetBrains == 1
s.SessionCountJetbrains == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats with conn open",
)
@ -258,9 +259,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool {
s, ok := <-stats
t.Logf("got stats after disconnect %t, %d",
ok, s.SessionCountJetBrains)
ok, s.SessionCountJetbrains)
return ok &&
s.SessionCountJetBrains == 0
s.SessionCountJetbrains == 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats after conn closes",
)
@ -1346,7 +1347,7 @@ func TestAgent_Lifecycle(t *testing.T) {
RunOnStop: true,
}},
},
make(chan *agentsdk.Stats, 50),
make(chan *proto.Stats, 50),
tailnet.NewCoordinator(logger),
)
defer client.Close()
@ -1667,7 +1668,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
_ = coordinator.Close()
})
agentID := uuid.New()
statsCh := make(chan *agentsdk.Stats, 50)
statsCh := make(chan *proto.Stats, 50)
fs := afero.NewMemMapFs()
client := agenttest.NewClient(t,
logger.Named("agent"),
@ -1816,7 +1817,7 @@ func TestAgent_Reconnect(t *testing.T) {
defer coordinator.Close()
agentID := uuid.New()
statsCh := make(chan *agentsdk.Stats, 50)
statsCh := make(chan *proto.Stats, 50)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
client := agenttest.NewClient(t,
logger,
@ -1861,7 +1862,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
GitAuthConfigs: 1,
DERPMap: &tailcfg.DERPMap{},
},
make(chan *agentsdk.Stats, 50),
make(chan *proto.Stats, 50),
coordinator,
)
defer client.Close()
@ -2018,7 +2019,7 @@ func setupSSHSession(
func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
*codersdk.WorkspaceAgentConn,
*agenttest.Client,
<-chan *agentsdk.Stats,
<-chan *proto.Stats,
afero.Fs,
agent.Agent,
) {
@ -2046,7 +2047,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
t.Cleanup(func() {
_ = coordinator.Close()
})
statsCh := make(chan *agentsdk.Stats, 50)
statsCh := make(chan *proto.Stats, 50)
fs := afero.NewMemMapFs()
c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator)
t.Cleanup(c.Close)

View File

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
@ -27,11 +28,13 @@ import (
"github.com/coder/coder/v2/testutil"
)
const statsInterval = 500 * time.Millisecond
func NewClient(t testing.TB,
logger slog.Logger,
agentID uuid.UUID,
manifest agentsdk.Manifest,
statsChan chan *agentsdk.Stats,
statsChan chan *agentproto.Stats,
coordinator tailnet.Coordinator,
) *Client {
if manifest.AgentID == uuid.Nil {
@ -51,7 +54,7 @@ func NewClient(t testing.TB,
require.NoError(t, err)
mp, err := agentsdk.ProtoFromManifest(manifest)
require.NoError(t, err)
fakeAAPI := NewFakeAgentAPI(t, logger, mp)
fakeAAPI := NewFakeAgentAPI(t, logger, mp, statsChan)
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
@ -66,7 +69,6 @@ func NewClient(t testing.TB,
t: t,
logger: logger.Named("client"),
agentID: agentID,
statsChan: statsChan,
coordinator: coordinator,
server: server,
fakeAgentAPI: fakeAAPI,
@ -79,7 +81,6 @@ type Client struct {
logger slog.Logger
agentID uuid.UUID
metadata map[string]agentsdk.Metadata
statsChan chan *agentsdk.Stats
coordinator tailnet.Coordinator
server *drpcserver.Server
fakeAgentAPI *FakeAgentAPI
@ -121,38 +122,6 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
return conn, nil
}
func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
go func() {
defer close(doneCh)
setInterval(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case stat := <-statsChan:
select {
case c.statsChan <- stat:
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}
}()
return closeFunc(func() error {
cancel()
<-doneCh
close(c.statsChan)
return nil
}), nil
}
func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
c.mu.Lock()
defer c.mu.Unlock()
@ -223,12 +192,6 @@ func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
return nil
}
type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
type FakeAgentAPI struct {
sync.Mutex
t testing.TB
@ -236,6 +199,7 @@ type FakeAgentAPI struct {
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup
statsCh chan *agentproto.Stats
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
}
@ -264,9 +228,13 @@ func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceB
return agentsdk.ProtoFromServiceBanner(sb), nil
}
func (*FakeAgentAPI) UpdateStats(context.Context, *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
// TODO implement me
panic("implement me")
func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
f.logger.Debug(ctx, "update stats called", slog.F("req", req))
// An empty request is only sent to fetch the report interval; tests don't
// want empty stats, so don't forward them to the channel.
if req.Stats != nil {
f.statsCh <- req.Stats
}
return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil
}
func (*FakeAgentAPI) UpdateLifecycle(context.Context, *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
@ -294,11 +262,12 @@ func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLog
panic("implement me")
}
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI {
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI {
return &FakeAgentAPI{
t: t,
logger: logger.Named("FakeAgentAPI"),
manifest: manifest,
statsCh: statsCh,
startupCh: make(chan *agentproto.Startup, 100),
}
}
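
Tying the pieces together: tests now receive `*proto.Stats` from the channel they pass to `agenttest.NewClient`, fed by `FakeAgentAPI.UpdateStats` rather than the old HTTP stub. A condensed sketch, reusing the updated setupAgent helper from agent_test.go above and trimming the session setup the real tests perform:

package agent_test

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/coder/coder/v2/agent/proto"
	"github.com/coder/coder/v2/codersdk/agentsdk"
	"github.com/coder/coder/v2/testutil"
)

func TestAgent_Stats_Sketch(t *testing.T) {
	t.Parallel()
	// setupAgent hands back a <-chan *proto.Stats wired through the fake
	// dRPC agent API.
	//nolint:dogsled
	conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
	defer conn.Close()

	var s *proto.Stats
	require.Eventuallyf(t, func() bool {
		var ok bool
		s, ok = <-stats
		return ok && s.ConnectionCount >= 0
	}, testutil.WaitLong, testutil.IntervalFast, "never saw stats: %+v", s)
}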

View File

@ -10,8 +10,7 @@ import (
"tailscale.com/util/clientmetric"
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/agent/proto"
)
type agentMetrics struct {
@ -53,8 +52,8 @@ func newAgentMetrics(registerer prometheus.Registerer) *agentMetrics {
}
}
func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric {
var collected []agentsdk.AgentMetric
func (a *agent) collectMetrics(ctx context.Context) []*proto.Stats_Metric {
var collected []*proto.Stats_Metric
// Tailscale internal metrics
metrics := clientmetric.Metrics()
@ -63,7 +62,7 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric {
continue
}
collected = append(collected, agentsdk.AgentMetric{
collected = append(collected, &proto.Stats_Metric{
Name: m.Name(),
Type: asMetricType(m.Type()),
Value: float64(m.Value()),
@ -81,16 +80,16 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric {
labels := toAgentMetricLabels(metric.Label)
if metric.Counter != nil {
collected = append(collected, agentsdk.AgentMetric{
collected = append(collected, &proto.Stats_Metric{
Name: metricFamily.GetName(),
Type: agentsdk.AgentMetricTypeCounter,
Type: proto.Stats_Metric_COUNTER,
Value: metric.Counter.GetValue(),
Labels: labels,
})
} else if metric.Gauge != nil {
collected = append(collected, agentsdk.AgentMetric{
collected = append(collected, &proto.Stats_Metric{
Name: metricFamily.GetName(),
Type: agentsdk.AgentMetricTypeGauge,
Type: proto.Stats_Metric_GAUGE,
Value: metric.Gauge.GetValue(),
Labels: labels,
})
@ -102,14 +101,14 @@ func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric {
return collected
}
func toAgentMetricLabels(metricLabels []*prompb.LabelPair) []agentsdk.AgentMetricLabel {
func toAgentMetricLabels(metricLabels []*prompb.LabelPair) []*proto.Stats_Metric_Label {
if len(metricLabels) == 0 {
return nil
}
labels := make([]agentsdk.AgentMetricLabel, 0, len(metricLabels))
labels := make([]*proto.Stats_Metric_Label, 0, len(metricLabels))
for _, metricLabel := range metricLabels {
labels = append(labels, agentsdk.AgentMetricLabel{
labels = append(labels, &proto.Stats_Metric_Label{
Name: metricLabel.GetName(),
Value: metricLabel.GetValue(),
})
@ -130,12 +129,12 @@ func isIgnoredMetric(metricName string) bool {
return false
}
func asMetricType(typ clientmetric.Type) agentsdk.AgentMetricType {
func asMetricType(typ clientmetric.Type) proto.Stats_Metric_Type {
switch typ {
case clientmetric.TypeGauge:
return agentsdk.AgentMetricTypeGauge
return proto.Stats_Metric_GAUGE
case clientmetric.TypeCounter:
return agentsdk.AgentMetricTypeCounter
return proto.Stats_Metric_COUNTER
default:
panic(fmt.Sprintf("unknown metric type: %d", typ))
}

View File

@ -24,6 +24,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
@ -327,7 +328,7 @@ func setupServerTailnetAgent(t *testing.T, agentNum int) ([]agentWithID, *coderd
DERPMap: derpMap,
}
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *proto.Stats, 50), coord)
t.Cleanup(c.Close)
options := agent.Options{

View File

@ -890,6 +890,7 @@ func TestWorkspaceAgentAppHealth(t *testing.T) {
require.EqualValues(t, codersdk.WorkspaceAppHealthUnhealthy, manifest.Apps[1].Health)
}
// TestWorkspaceAgentReportStats tests the legacy (agent API v1) report stats endpoint.
func TestWorkspaceAgentReportStats(t *testing.T) {
t.Parallel()

View File

@ -24,7 +24,6 @@ import (
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/retry"
)
// ExternalLogSourceID is the statically-defined ID of a log-source that
@ -390,61 +389,6 @@ func (c *Client) AuthAzureInstanceIdentity(ctx context.Context) (AuthenticateRes
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// ReportStats begins a stat streaming connection with the Coder server.
// It is resilient to network failures and intermittent coderd issues.
func (c *Client) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *Stats, setInterval func(time.Duration)) (io.Closer, error) {
var interval time.Duration
ctx, cancel := context.WithCancel(ctx)
exited := make(chan struct{})
postStat := func(stat *Stats) {
var nextInterval time.Duration
for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); {
resp, err := c.PostStats(ctx, stat)
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "report stats", slog.Error(err))
}
continue
}
nextInterval = resp.ReportInterval
break
}
if nextInterval != 0 && interval != nextInterval {
setInterval(nextInterval)
}
interval = nextInterval
}
// Send an empty stat to get the interval.
postStat(&Stats{})
go func() {
defer close(exited)
for {
select {
case <-ctx.Done():
return
case stat, ok := <-statsChan:
if !ok {
return
}
postStat(stat)
}
}
}()
return closeFunc(func() error {
cancel()
<-exited
return nil
}), nil
}
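
For contrast with the removed loop above: under API v1 each report was an HTTP POST (PostStats, kept below but deprecated), whereas under v2 the same payload rides the agent's dRPC connection. A hypothetical side-by-side, with client and RPC setup elided:

// reportOnce is an illustration only; it is not part of the commit.
func reportOnce(ctx context.Context, client *agentsdk.Client, aAPI proto.DRPCAgentClient) error {
	// v1 (deprecated): one HTTP POST per report, retried by the caller.
	if _, err := client.PostStats(ctx, &agentsdk.Stats{ConnectionsByProto: map[string]int64{}}); err != nil {
		return err
	}
	// v2: the statsReporter sends the proto form over dRPC instead.
	_, err := aAPI.UpdateStats(ctx, &proto.UpdateStatsRequest{Stats: &proto.Stats{}})
	return err
}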
// Stats records the Agent's network connection statistics for use in
// user-facing metrics and debugging.
type Stats struct {
@ -509,6 +453,9 @@ type StatsResponse struct {
ReportInterval time.Duration `json:"report_interval"`
}
// PostStats sends agent stats to the coder server
//
// Deprecated: uses agent API v1 endpoint
func (c *Client) PostStats(ctx context.Context, stats *Stats) (StatsResponse, error) {
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/report-stats", stats)
if err != nil {
@ -649,12 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
return authResp, json.NewDecoder(res.Body).Decode(&authResp)
}
type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {

View File

@ -1,22 +1,13 @@
package codersdk_test
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
)
func TestWorkspaceRewriteDERPMap(t *testing.T) {
@ -46,45 +37,3 @@ func TestWorkspaceRewriteDERPMap(t *testing.T) {
require.Equal(t, "coconuts.org", node.HostName)
require.Equal(t, 44558, node.DERPPort)
}
func TestAgentReportStats(t *testing.T) {
t.Parallel()
var (
numReports atomic.Int64
numIntervalCalls atomic.Int64
wantInterval = 5 * time.Millisecond
)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numReports.Add(1)
httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.StatsResponse{
ReportInterval: wantInterval,
})
}))
parsed, err := url.Parse(srv.URL)
require.NoError(t, err)
client := agentsdk.New(parsed)
assertStatInterval := func(interval time.Duration) {
numIntervalCalls.Add(1)
assert.Equal(t, wantInterval, interval)
}
chanLen := 3
statCh := make(chan *agentsdk.Stats, chanLen)
for i := 0; i < chanLen; i++ {
statCh <- &agentsdk.Stats{ConnectionsByProto: map[string]int64{}}
}
ctx := context.Background()
closeStream, err := client.ReportStats(ctx, slogtest.Make(t, nil), statCh, assertStatInterval)
require.NoError(t, err)
defer closeStream.Close()
require.Eventually(t,
func() bool { return numReports.Load() >= 3 },
testutil.WaitMedium, testutil.IntervalFast,
)
closeStream.Close()
require.Equal(t, int64(1), numIntervalCalls.Load())
}