chore: refactor agent stats streaming (#5112)

This commit is contained in:
Colin Adler 2022-11-18 16:46:53 -06:00 committed by GitHub
parent 13a4cfa670
commit ae38bbeab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 534 additions and 310 deletions

View File

@ -32,6 +32,7 @@ import (
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"tailscale.com/types/netlogtype"
"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
@ -98,7 +99,6 @@ func New(options Options) io.Closer {
exchangeToken: options.ExchangeToken,
filesystem: options.Filesystem,
tempDir: options.TempDir,
stats: &Stats{},
}
server.init(ctx)
return server
@ -126,7 +126,6 @@ type agent struct {
sshServer *ssh.Server
network *tailnet.Conn
stats *Stats
}
// runLoop attempts to start the agent in a retry loop.
@ -238,22 +237,16 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
return nil, xerrors.New("closed")
}
network, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)},
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)},
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
EnableTrafficStats: true,
})
if err != nil {
a.closeMutex.Unlock()
return nil, xerrors.Errorf("create tailnet: %w", err)
}
a.network = network
network.SetForwardTCPCallback(func(conn net.Conn, listenerExists bool) net.Conn {
if listenerExists {
// If a listener already exists, we would double-wrap the conn.
return conn
}
return a.stats.wrapConn(conn)
})
a.connCloseWait.Add(4)
a.closeMutex.Unlock()
@ -268,7 +261,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
if err != nil {
return
}
go a.sshServer.HandleConn(a.stats.wrapConn(conn))
go a.sshServer.HandleConn(conn)
}
}()
@ -284,7 +277,6 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
a.logger.Debug(ctx, "accept pty failed", slog.Error(err))
return
}
conn = a.stats.wrapConn(conn)
// This cannot use a JSON decoder, since that can
// buffer additional data that is required for the PTY.
rawLen := make([]byte, 2)
@ -523,7 +515,13 @@ func (a *agent) init(ctx context.Context) {
go a.runLoop(ctx)
cl, err := a.client.AgentReportStats(ctx, a.logger, func() *codersdk.AgentStats {
return a.stats.Copy()
stats := map[netlogtype.Connection]netlogtype.Counts{}
a.closeMutex.Lock()
if a.network != nil {
stats = a.network.ExtractTrafficStats()
}
a.closeMutex.Unlock()
return convertAgentStats(stats)
})
if err != nil {
a.logger.Error(ctx, "report stats", slog.Error(err))
@ -537,6 +535,23 @@ func (a *agent) init(ctx context.Context) {
}()
}
func convertAgentStats(counts map[netlogtype.Connection]netlogtype.Counts) *codersdk.AgentStats {
stats := &codersdk.AgentStats{
ConnsByProto: map[string]int64{},
NumConns: int64(len(counts)),
}
for conn, count := range counts {
stats.ConnsByProto[conn.Proto.String()]++
stats.RxPackets += int64(count.RxPackets)
stats.RxBytes += int64(count.RxBytes)
stats.TxPackets += int64(count.TxPackets)
stats.TxBytes += int64(count.TxBytes)
}
return stats
}
// createCommand processes raw command input with OpenSSH-like behavior.
// If the rawCommand provided is empty, it will default to the users shell.
// This injects environment variables specified by the user at launch too.

View File

@ -69,10 +69,16 @@ func TestAgent(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
require.NoError(t, session.Run("echo test"))
assert.EqualValues(t, 1, (<-stats).NumConns)
assert.Greater(t, (<-stats).RxBytes, int64(0))
assert.Greater(t, (<-stats).TxBytes, int64(0))
var s *codersdk.AgentStats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
})
t.Run("ReconnectingPTY", func(t *testing.T) {
@ -97,7 +103,7 @@ func TestAgent(t *testing.T) {
var s *codersdk.AgentStats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = (<-stats)
s, ok = <-stats
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
@ -675,7 +681,7 @@ func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeo
}
coordinator := tailnet.NewCoordinator()
agentID := uuid.New()
statsCh := make(chan *codersdk.AgentStats)
statsCh := make(chan *codersdk.AgentStats, 50)
fs := afero.NewMemMapFs()
closer := agent.New(agent.Options{
Client: &client{
@ -693,9 +699,10 @@ func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeo
_ = closer.Close()
})
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
EnableTrafficStats: true,
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
@ -781,7 +788,7 @@ func (c *client) AgentReportStats(ctx context.Context, _ slog.Logger, stats func
go func() {
defer close(doneCh)
t := time.NewTicker(time.Millisecond * 100)
t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
for {
select {

View File

@ -1,58 +0,0 @@
package agent
import (
"net"
"sync/atomic"
"github.com/coder/coder/codersdk"
)
// statsConn wraps a net.Conn with statistics.
type statsConn struct {
*Stats
net.Conn `json:"-"`
}
var _ net.Conn = new(statsConn)
func (c *statsConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
atomic.AddInt64(&c.RxBytes, int64(n))
return n, err
}
func (c *statsConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
atomic.AddInt64(&c.TxBytes, int64(n))
return n, err
}
var _ net.Conn = new(statsConn)
// Stats records the Agent's network connection statistics for use in
// user-facing metrics and debugging.
// Each member value must be written and read with atomic.
type Stats struct {
NumConns int64 `json:"num_comms"`
RxBytes int64 `json:"rx_bytes"`
TxBytes int64 `json:"tx_bytes"`
}
func (s *Stats) Copy() *codersdk.AgentStats {
return &codersdk.AgentStats{
NumConns: atomic.LoadInt64(&s.NumConns),
RxBytes: atomic.LoadInt64(&s.RxBytes),
TxBytes: atomic.LoadInt64(&s.TxBytes),
}
}
// wrapConn returns a new connection that records statistics.
func (s *Stats) wrapConn(conn net.Conn) net.Conn {
atomic.AddInt64(&s.NumConns, 1)
cs := &statsConn{
Stats: s,
Conn: conn,
}
return cs
}

View File

@ -6,6 +6,7 @@ import (
"errors"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
@ -14,14 +15,14 @@ import (
// activityBumpWorkspace automatically bumps the workspace's auto-off timer
// if it is set to expire soon.
func activityBumpWorkspace(log slog.Logger, db database.Store, workspace database.Workspace) {
func activityBumpWorkspace(log slog.Logger, db database.Store, workspaceID uuid.UUID) {
// We set a short timeout so if the app is under load, these
// low priority operations fail first.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
err := db.InTx(func(s database.Store) error {
build, err := s.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
build, err := s.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID)
if errors.Is(err, sql.ErrNoRows) {
return nil
} else if err != nil {
@ -65,15 +66,13 @@ func activityBumpWorkspace(log slog.Logger, db database.Store, workspace databas
return nil
}, nil)
if err != nil {
log.Error(
ctx, "bump failed",
slog.Error(err),
slog.F("workspace_id", workspace.ID),
)
} else {
log.Debug(
ctx, "bumped deadline from activity",
slog.F("workspace_id", workspace.ID),
log.Error(ctx, "bump failed", slog.Error(err),
slog.F("workspace_id", workspaceID),
)
return
}
log.Debug(ctx, "bumped deadline from activity",
slog.F("workspace_id", workspaceID),
)
}

View File

@ -8,7 +8,6 @@ import (
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"

View File

@ -132,7 +132,7 @@ func New(options *Options) *API {
options.APIRateLimit = 512
}
if options.AgentStatsRefreshInterval == 0 {
options.AgentStatsRefreshInterval = 10 * time.Minute
options.AgentStatsRefreshInterval = 5 * time.Minute
}
if options.MetricsCacheRefreshInterval == 0 {
options.MetricsCacheRefreshInterval = time.Hour
@ -493,7 +493,10 @@ func New(options *Options) *API {
r.Get("/gitauth", api.workspaceAgentsGitAuth)
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/coordinate", api.workspaceAgentCoordinate)
r.Get("/report-stats", api.workspaceAgentReportStats)
r.Post("/report-stats", api.workspaceAgentReportStats)
// DEPRECATED in favor of the POST endpoint above.
// TODO: remove in January 2023
r.Get("/report-stats", api.workspaceAgentReportStatsWebsocket)
})
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(

View File

@ -64,6 +64,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/app-health": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!
"GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)},

View File

@ -30,31 +30,31 @@ func New() database.Store {
return &fakeQuerier{
mutex: &sync.RWMutex{},
data: &data{
apiKeys: make([]database.APIKey, 0),
agentStats: make([]database.AgentStat, 0),
organizationMembers: make([]database.OrganizationMember, 0),
organizations: make([]database.Organization, 0),
users: make([]database.User, 0),
gitAuthLinks: make([]database.GitAuthLink, 0),
groups: make([]database.Group, 0),
groupMembers: make([]database.GroupMember, 0),
auditLogs: make([]database.AuditLog, 0),
files: make([]database.File, 0),
gitSSHKey: make([]database.GitSSHKey, 0),
parameterSchemas: make([]database.ParameterSchema, 0),
parameterValues: make([]database.ParameterValue, 0),
provisionerDaemons: make([]database.ProvisionerDaemon, 0),
provisionerJobAgents: make([]database.WorkspaceAgent, 0),
provisionerJobLogs: make([]database.ProvisionerJobLog, 0),
provisionerJobResources: make([]database.WorkspaceResource, 0),
provisionerJobResourceMetadata: make([]database.WorkspaceResourceMetadatum, 0),
provisionerJobs: make([]database.ProvisionerJob, 0),
templateVersions: make([]database.TemplateVersion, 0),
templates: make([]database.Template, 0),
workspaceBuilds: make([]database.WorkspaceBuild, 0),
workspaceApps: make([]database.WorkspaceApp, 0),
workspaces: make([]database.Workspace, 0),
licenses: make([]database.License, 0),
apiKeys: make([]database.APIKey, 0),
agentStats: make([]database.AgentStat, 0),
organizationMembers: make([]database.OrganizationMember, 0),
organizations: make([]database.Organization, 0),
users: make([]database.User, 0),
gitAuthLinks: make([]database.GitAuthLink, 0),
groups: make([]database.Group, 0),
groupMembers: make([]database.GroupMember, 0),
auditLogs: make([]database.AuditLog, 0),
files: make([]database.File, 0),
gitSSHKey: make([]database.GitSSHKey, 0),
parameterSchemas: make([]database.ParameterSchema, 0),
parameterValues: make([]database.ParameterValue, 0),
provisionerDaemons: make([]database.ProvisionerDaemon, 0),
workspaceAgents: make([]database.WorkspaceAgent, 0),
provisionerJobLogs: make([]database.ProvisionerJobLog, 0),
workspaceResources: make([]database.WorkspaceResource, 0),
workspaceResourceMetadata: make([]database.WorkspaceResourceMetadatum, 0),
provisionerJobs: make([]database.ProvisionerJob, 0),
templateVersions: make([]database.TemplateVersion, 0),
templates: make([]database.Template, 0),
workspaceBuilds: make([]database.WorkspaceBuild, 0),
workspaceApps: make([]database.WorkspaceApp, 0),
workspaces: make([]database.Workspace, 0),
licenses: make([]database.License, 0),
},
}
}
@ -89,28 +89,28 @@ type data struct {
userLinks []database.UserLink
// New tables
agentStats []database.AgentStat
auditLogs []database.AuditLog
files []database.File
gitAuthLinks []database.GitAuthLink
gitSSHKey []database.GitSSHKey
groups []database.Group
groupMembers []database.GroupMember
parameterSchemas []database.ParameterSchema
parameterValues []database.ParameterValue
provisionerDaemons []database.ProvisionerDaemon
provisionerJobAgents []database.WorkspaceAgent
provisionerJobLogs []database.ProvisionerJobLog
provisionerJobResources []database.WorkspaceResource
provisionerJobResourceMetadata []database.WorkspaceResourceMetadatum
provisionerJobs []database.ProvisionerJob
templateVersions []database.TemplateVersion
templates []database.Template
workspaceBuilds []database.WorkspaceBuild
workspaceApps []database.WorkspaceApp
workspaces []database.Workspace
licenses []database.License
replicas []database.Replica
agentStats []database.AgentStat
auditLogs []database.AuditLog
files []database.File
gitAuthLinks []database.GitAuthLink
gitSSHKey []database.GitSSHKey
groupMembers []database.GroupMember
groups []database.Group
licenses []database.License
parameterSchemas []database.ParameterSchema
parameterValues []database.ParameterValue
provisionerDaemons []database.ProvisionerDaemon
provisionerJobLogs []database.ProvisionerJobLog
provisionerJobs []database.ProvisionerJob
replicas []database.Replica
templateVersions []database.TemplateVersion
templates []database.Template
workspaceAgents []database.WorkspaceAgent
workspaceApps []database.WorkspaceApp
workspaceBuilds []database.WorkspaceBuild
workspaceResourceMetadata []database.WorkspaceResourceMetadatum
workspaceResources []database.WorkspaceResource
workspaces []database.Workspace
deploymentID string
derpMeshKey string
@ -942,6 +942,52 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas
return database.Workspace{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceByAgentID(_ context.Context, agentID uuid.UUID) (database.Workspace, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var agent database.WorkspaceAgent
for _, _agent := range q.workspaceAgents {
if _agent.ID == agentID {
agent = _agent
break
}
}
if agent.ID == uuid.Nil {
return database.Workspace{}, sql.ErrNoRows
}
var resource database.WorkspaceResource
for _, _resource := range q.workspaceResources {
if _resource.ID == agent.ResourceID {
resource = _resource
break
}
}
if resource.ID == uuid.Nil {
return database.Workspace{}, sql.ErrNoRows
}
var build database.WorkspaceBuild
for _, _build := range q.workspaceBuilds {
if _build.JobID == resource.JobID {
build = _build
break
}
}
if build.ID == uuid.Nil {
return database.Workspace{}, sql.ErrNoRows
}
for _, workspace := range q.workspaces {
if workspace.ID == build.WorkspaceID {
return workspace, nil
}
}
return database.Workspace{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -1801,8 +1847,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken
defer q.mutex.RUnlock()
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- {
agent := q.provisionerJobAgents[i]
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
agent := q.workspaceAgents[i]
if agent.AuthToken == authToken {
return agent, nil
}
@ -1815,8 +1861,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (da
defer q.mutex.RUnlock()
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- {
agent := q.provisionerJobAgents[i]
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
agent := q.workspaceAgents[i]
if agent.ID == id {
return agent, nil
}
@ -1829,8 +1875,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI
defer q.mutex.RUnlock()
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- {
agent := q.provisionerJobAgents[i]
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
agent := q.workspaceAgents[i]
if agent.AuthInstanceID.Valid && agent.AuthInstanceID.String == instanceID {
return agent, nil
}
@ -1843,7 +1889,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc
defer q.mutex.RUnlock()
workspaceAgents := make([]database.WorkspaceAgent, 0)
for _, agent := range q.provisionerJobAgents {
for _, agent := range q.workspaceAgents {
for _, resourceID := range resourceIDs {
if agent.ResourceID != resourceID {
continue
@ -1859,7 +1905,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after ti
defer q.mutex.RUnlock()
workspaceAgents := make([]database.WorkspaceAgent, 0)
for _, agent := range q.provisionerJobAgents {
for _, agent := range q.workspaceAgents {
if agent.CreatedAt.After(after) {
workspaceAgents = append(workspaceAgents, agent)
}
@ -1913,7 +1959,7 @@ func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID)
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, resource := range q.provisionerJobResources {
for _, resource := range q.workspaceResources {
if resource.ID == id {
return resource, nil
}
@ -1926,7 +1972,7 @@ func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid
defer q.mutex.RUnlock()
resources := make([]database.WorkspaceResource, 0)
for _, resource := range q.provisionerJobResources {
for _, resource := range q.workspaceResources {
if resource.JobID != jobID {
continue
}
@ -1940,7 +1986,7 @@ func (q *fakeQuerier) GetWorkspaceResourcesByJobIDs(_ context.Context, jobIDs []
defer q.mutex.RUnlock()
resources := make([]database.WorkspaceResource, 0)
for _, resource := range q.provisionerJobResources {
for _, resource := range q.workspaceResources {
for _, jobID := range jobIDs {
if resource.JobID != jobID {
continue
@ -1956,7 +2002,7 @@ func (q *fakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after
defer q.mutex.RUnlock()
resources := make([]database.WorkspaceResource, 0)
for _, resource := range q.provisionerJobResources {
for _, resource := range q.workspaceResources {
if resource.CreatedAt.After(after) {
resources = append(resources, resource)
}
@ -1978,7 +2024,7 @@ func (q *fakeQuerier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Conte
defer q.mutex.RUnlock()
metadata := make([]database.WorkspaceResourceMetadatum, 0)
for _, m := range q.provisionerJobResourceMetadata {
for _, m := range q.workspaceResourceMetadata {
_, ok := resourceIDs[m.WorkspaceResourceID]
if !ok {
continue
@ -1993,7 +2039,7 @@ func (q *fakeQuerier) GetWorkspaceResourceMetadataByResourceID(_ context.Context
defer q.mutex.RUnlock()
metadata := make([]database.WorkspaceResourceMetadatum, 0)
for _, metadatum := range q.provisionerJobResourceMetadata {
for _, metadatum := range q.workspaceResourceMetadata {
if metadatum.WorkspaceResourceID == id {
metadata = append(metadata, metadatum)
}
@ -2006,7 +2052,7 @@ func (q *fakeQuerier) GetWorkspaceResourceMetadataByResourceIDs(_ context.Contex
defer q.mutex.RUnlock()
metadata := make([]database.WorkspaceResourceMetadatum, 0)
for _, metadatum := range q.provisionerJobResourceMetadata {
for _, metadatum := range q.workspaceResourceMetadata {
for _, id := range ids {
if metadatum.WorkspaceResourceID == id {
metadata = append(metadata, metadatum)
@ -2319,7 +2365,7 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser
TroubleshootingURL: arg.TroubleshootingURL,
}
q.provisionerJobAgents = append(q.provisionerJobAgents, agent)
q.workspaceAgents = append(q.workspaceAgents, agent)
return agent, nil
}
@ -2339,7 +2385,7 @@ func (q *fakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.In
Icon: arg.Icon,
DailyCost: arg.DailyCost,
}
q.provisionerJobResources = append(q.provisionerJobResources, resource)
q.workspaceResources = append(q.workspaceResources, resource)
return resource, nil
}
@ -2354,7 +2400,7 @@ func (q *fakeQuerier) InsertWorkspaceResourceMetadata(_ context.Context, arg dat
Value: arg.Value,
Sensitive: arg.Sensitive,
}
q.provisionerJobResourceMetadata = append(q.provisionerJobResourceMetadata, metadatum)
q.workspaceResourceMetadata = append(q.workspaceResourceMetadata, metadatum)
return metadatum, nil
}
@ -2681,7 +2727,7 @@ func (q *fakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg
q.mutex.Lock()
defer q.mutex.Unlock()
for index, agent := range q.provisionerJobAgents {
for index, agent := range q.workspaceAgents {
if agent.ID != arg.ID {
continue
}
@ -2689,7 +2735,7 @@ func (q *fakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg
agent.LastConnectedAt = arg.LastConnectedAt
agent.DisconnectedAt = arg.DisconnectedAt
agent.UpdatedAt = arg.UpdatedAt
q.provisionerJobAgents[index] = agent
q.workspaceAgents[index] = agent
return nil
}
return sql.ErrNoRows
@ -2699,13 +2745,13 @@ func (q *fakeQuerier) UpdateWorkspaceAgentVersionByID(_ context.Context, arg dat
q.mutex.Lock()
defer q.mutex.Unlock()
for index, agent := range q.provisionerJobAgents {
for index, agent := range q.workspaceAgents {
if agent.ID != arg.ID {
continue
}
agent.Version = arg.Version
q.provisionerJobAgents[index] = agent
q.workspaceAgents[index] = agent
return nil
}
return sql.ErrNoRows

View File

@ -113,6 +113,7 @@ type sqlcQuerier interface {
GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Context, arg GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (WorkspaceBuild, error)
GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg GetWorkspaceBuildsByWorkspaceIDParams) ([]WorkspaceBuild, error)
GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error)
GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (Workspace, error)
GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Workspace, error)
GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWorkspaceByOwnerIDAndNameParams) (Workspace, error)
GetWorkspaceCountByUserID(ctx context.Context, ownerID uuid.UUID) (int64, error)

View File

@ -17,7 +17,7 @@ import (
)
const deleteOldAgentStats = `-- name: DeleteOldAgentStats :exec
DELETE FROM AGENT_STATS WHERE created_at < now() - interval '30 days'
DELETE FROM agent_stats WHERE created_at < NOW() - INTERVAL '30 days'
`
func (q *sqlQuerier) DeleteOldAgentStats(ctx context.Context) error {
@ -45,16 +45,17 @@ func (q *sqlQuerier) GetLatestAgentStat(ctx context.Context, agentID uuid.UUID)
}
const getTemplateDAUs = `-- name: GetTemplateDAUs :many
select
SELECT
(created_at at TIME ZONE 'UTC')::date as date,
user_id
from
FROM
agent_stats
where template_id = $1
group by
WHERE
template_id = $1
GROUP BY
date, user_id
order by
date asc
ORDER BY
date ASC
`
type GetTemplateDAUsRow struct {
@ -6145,6 +6146,55 @@ func (q *sqlQuerier) InsertWorkspaceResourceMetadata(ctx context.Context, arg In
return i, err
}
const getWorkspaceByAgentID = `-- name: GetWorkspaceByAgentID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at
FROM
workspaces
WHERE
workspaces.id = (
SELECT
workspace_id
FROM
workspace_builds
WHERE
workspace_builds.job_id = (
SELECT
job_id
FROM
workspace_resources
WHERE
workspace_resources.id = (
SELECT
resource_id
FROM
workspace_agents
WHERE
workspace_agents.id = $1
)
)
)
`
func (q *sqlQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (Workspace, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceByAgentID, agentID)
var i Workspace
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OwnerID,
&i.OrganizationID,
&i.TemplateID,
&i.Deleted,
&i.Name,
&i.AutostartSchedule,
&i.Ttl,
&i.LastUsedAt,
)
return i, err
}
const getWorkspaceByID = `-- name: GetWorkspaceByID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at

View File

@ -16,16 +16,17 @@ VALUES
SELECT * FROM agent_stats WHERE agent_id = $1 ORDER BY created_at DESC LIMIT 1;
-- name: GetTemplateDAUs :many
select
SELECT
(created_at at TIME ZONE 'UTC')::date as date,
user_id
from
FROM
agent_stats
where template_id = $1
group by
WHERE
template_id = $1
GROUP BY
date, user_id
order by
date asc;
ORDER BY
date ASC;
-- name: DeleteOldAgentStats :exec
DELETE FROM AGENT_STATS WHERE created_at < now() - interval '30 days';
DELETE FROM agent_stats WHERE created_at < NOW() - INTERVAL '30 days';

View File

@ -8,6 +8,35 @@ WHERE
LIMIT
1;
-- name: GetWorkspaceByAgentID :one
SELECT
*
FROM
workspaces
WHERE
workspaces.id = (
SELECT
workspace_id
FROM
workspace_builds
WHERE
workspace_builds.job_id = (
SELECT
job_id
FROM
workspace_resources
WHERE
workspace_resources.id = (
SELECT
resource_id
FROM
workspace_agents
WHERE
workspace_agents.id = @agent_id
)
)
);
-- name: GetWorkspaces :many
SELECT
workspaces.*, COUNT(*) OVER () as count

View File

@ -75,41 +75,41 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
})
return
}
dbApps, err := api.Database.GetWorkspaceAppsByAgentID(r.Context(), workspaceAgent.ID)
dbApps, err := api.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace agent applications.",
Detail: err.Error(),
})
return
}
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace resource.",
Detail: err.Error(),
})
return
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace build.",
Detail: err.Error(),
})
return
}
workspace, err := api.Database.GetWorkspaceByID(r.Context(), build.WorkspaceID)
workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace.",
Detail: err.Error(),
})
return
}
owner, err := api.Database.GetUserByID(r.Context(), workspace.OwnerID)
owner, err := api.Database.GetUserByID(ctx, workspace.OwnerID)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace owner.",
Detail: err.Error(),
})
@ -755,31 +755,69 @@ func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordin
func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace.",
Detail: err.Error(),
})
return
}
var req codersdk.AgentStats
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if req.RxBytes == 0 && req.TxBytes == 0 {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AgentStatsResponse{
ReportInterval: api.AgentStatsRefreshInterval,
})
return
}
activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace.ID)
now := database.Now()
_, err = api.Database.InsertAgentStat(ctx, database.InsertAgentStatParams{
ID: uuid.New(),
CreatedAt: now,
AgentID: workspaceAgent.ID,
WorkspaceID: workspace.ID,
UserID: workspace.OwnerID,
TemplateID: workspace.TemplateID,
Payload: json.RawMessage("{}"),
})
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
err = api.Database.UpdateWorkspaceLastUsedAt(ctx, database.UpdateWorkspaceLastUsedAtParams{
ID: workspace.ID,
LastUsedAt: now,
})
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AgentStatsResponse{
ReportInterval: api.AgentStatsRefreshInterval,
})
}
func (api *API) workspaceAgentReportStatsWebsocket(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace resource.",
Detail: err.Error(),
})
return
}
build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to get build.",
Detail: err.Error(),
})
return
}
workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace.",
@ -861,14 +899,13 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
api.Logger.Debug(ctx, "read stats report",
slog.F("interval", api.AgentStatsRefreshInterval),
slog.F("agent", workspaceAgent.ID),
slog.F("resource", resource.ID),
slog.F("workspace", workspace.ID),
slog.F("update_db", updateDB),
slog.F("payload", rep),
)
if updateDB {
go activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace)
go activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace.ID)
lastReport = rep
@ -876,7 +913,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
ID: uuid.New(),
CreatedAt: database.Now(),
AgentID: workspaceAgent.ID,
WorkspaceID: build.WorkspaceID,
WorkspaceID: workspace.ID,
UserID: workspace.OwnerID,
TemplateID: workspace.TemplateID,
Payload: json.RawMessage(repJSON),
@ -888,7 +925,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
}
err = api.Database.UpdateWorkspaceLastUsedAt(ctx, database.UpdateWorkspaceLastUsedAtParams{
ID: build.WorkspaceID,
ID: workspace.ID,
LastUsedAt: database.Now(),
})
if err != nil {
@ -901,22 +938,23 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
}
func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
var req codersdk.PostWorkspaceAppHealthsRequest
if !httpapi.Read(r.Context(), rw, r, &req) {
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if req.Healths == nil || len(req.Healths) == 0 {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Health field is empty",
})
return
}
apps, err := api.Database.GetWorkspaceAppsByAgentID(r.Context(), workspaceAgent.ID)
apps, err := api.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Error getting agent apps",
Detail: err.Error(),
})
@ -935,7 +973,7 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request)
return nil
}()
if old == nil {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Error setting workspace app health",
Detail: xerrors.Errorf("workspace app name %s not found", id).Error(),
})
@ -943,7 +981,7 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request)
}
if old.HealthcheckUrl == "" {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Error setting workspace app health",
Detail: xerrors.Errorf("health checking is disabled for workspace app %s", id).Error(),
})
@ -955,7 +993,7 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request)
case codersdk.WorkspaceAppHealthHealthy:
case codersdk.WorkspaceAppHealthUnhealthy:
default:
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Error setting workspace app health",
Detail: xerrors.Errorf("workspace app health %s is not a valid value", newHealth).Error(),
})
@ -972,12 +1010,12 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request)
}
for _, app := range newApps {
err = api.Database.UpdateWorkspaceAppHealthByID(r.Context(), database.UpdateWorkspaceAppHealthByIDParams{
err = api.Database.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{
ID: app.ID,
Health: app.Health,
})
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Error setting workspace app health",
Detail: err.Error(),
})
@ -985,33 +1023,33 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request)
}
}
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace resource.",
Detail: err.Error(),
})
return
}
job, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
job, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace build.",
Detail: err.Error(),
})
return
}
workspace, err := api.Database.GetWorkspaceByID(r.Context(), job.WorkspaceID)
workspace, err := api.Database.GetWorkspaceByID(ctx, job.WorkspaceID)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace.",
Detail: err.Error(),
})
return
}
api.publishWorkspaceUpdate(r.Context(), workspace.ID)
api.publishWorkspaceUpdate(ctx, workspace.ID)
httpapi.Write(r.Context(), rw, http.StatusOK, nil)
httpapi.Write(ctx, rw, http.StatusOK, nil)
}
// postWorkspaceAgentsGitAuth returns a username and password for use
@ -1101,7 +1139,7 @@ func (api *API) workspaceAgentsGitAuth(rw http.ResponseWriter, r *http.Request)
defer ticker.Stop()
for {
select {
case <-r.Context().Done():
case <-ctx.Done():
return
case <-ticker.C:
case <-authChan:

View File

@ -1065,6 +1065,65 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
})
}
func TestWorkspaceAgentReportStats(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: echo.ProvisionComplete,
ProvisionApply: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SetSessionToken(authToken)
_, err := agentClient.PostAgentStats(context.Background(), &codersdk.AgentStats{
ConnsByProto: map[string]int64{"TCP": 1},
NumConns: 1,
RxPackets: 1,
RxBytes: 1,
TxPackets: 1,
TxBytes: 1,
})
require.NoError(t, err)
newWorkspace, err := client.Workspace(context.Background(), workspace.ID)
require.NoError(t, err)
assert.True(t,
newWorkspace.LastUsedAt.After(workspace.LastUsedAt),
"%s is not after %s", newWorkspace.LastUsedAt, workspace.LastUsedAt,
)
})
}
func gitAuthCallback(t *testing.T, id string, client *codersdk.Client) *http.Response {
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse

View File

@ -232,6 +232,6 @@ type AgentStatsReportResponse struct {
NumConns int64 `json:"num_comms"`
// RxBytes is the number of received bytes.
RxBytes int64 `json:"rx_bytes"`
// TxBytes is the number of received bytes.
// TxBytes is the number of transmitted bytes.
TxBytes int64 `json:"tx_bytes"`
}

View File

@ -18,7 +18,6 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"tailscale.com/tailcfg"
"cdr.dev/slog"
@ -553,12 +552,46 @@ func (c *Client) WorkspaceAgentListeningPorts(ctx context.Context, agentID uuid.
// Stats records the Agent's network connection statistics for use in
// user-facing metrics and debugging.
// Each member value must be written and read with atomic.
// @typescript-ignore AgentStats
type AgentStats struct {
// ConnsByProto is a count of connections by protocol.
ConnsByProto map[string]int64 `json:"conns_by_proto"`
// NumConns is the number of connections received by an agent.
NumConns int64 `json:"num_comms"`
RxBytes int64 `json:"rx_bytes"`
TxBytes int64 `json:"tx_bytes"`
// RxPackets is the number of received packets.
RxPackets int64 `json:"rx_packets"`
// RxBytes is the number of received bytes.
RxBytes int64 `json:"rx_bytes"`
// TxPackets is the number of transmitted bytes.
TxPackets int64 `json:"tx_packets"`
// TxBytes is the number of transmitted bytes.
TxBytes int64 `json:"tx_bytes"`
}
// @typescript-ignore AgentStatsResponse
type AgentStatsResponse struct {
// ReportInterval is the duration after which the agent should send stats
// again.
ReportInterval time.Duration `json:"report_interval"`
}
func (c *Client) PostAgentStats(ctx context.Context, stats *AgentStats) (AgentStatsResponse, error) {
res, err := c.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/report-stats", stats)
if err != nil {
return AgentStatsResponse{}, xerrors.Errorf("send request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return AgentStatsResponse{}, readBodyAsError(res)
}
var interval AgentStatsResponse
err = json.NewDecoder(res.Body).Decode(&interval)
if err != nil {
return AgentStatsResponse{}, xerrors.Errorf("decode stats response: %w", err)
}
return interval, nil
}
// AgentReportStats begins a stat streaming connection with the Coder server.
@ -566,84 +599,41 @@ type AgentStats struct {
func (c *Client) AgentReportStats(
ctx context.Context,
log slog.Logger,
stats func() *AgentStats,
getStats func() *AgentStats,
) (io.Closer, error) {
serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/report-stats")
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)
}
jar.SetCookies(serverURL, []*http.Cookie{{
Name: SessionTokenKey,
Value: c.SessionToken(),
}})
httpClient := &http.Client{
Jar: jar,
Transport: c.HTTPClient.Transport,
}
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
go func() {
defer close(doneCh)
// Immediately trigger a stats push to get the correct interval.
timer := time.NewTimer(time.Nanosecond)
defer timer.Stop()
// If the agent connection succeeds for a while, then fails, then succeeds
// for a while (etc.) the retry may hit the maximum. This is a normal
// case for long-running agents that experience coderd upgrades, so
// we use a short maximum retry limit.
for r := retry.New(time.Second, time.Minute); r.Wait(ctx); {
err = func() error {
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
if res == nil {
return err
}
return readBodyAsError(res)
}
for {
var req AgentStatsReportRequest
err := wsjson.Read(ctx, conn, &req)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
return err
}
s := stats()
resp := AgentStatsReportResponse{
NumConns: s.NumConns,
RxBytes: s.RxBytes,
TxBytes: s.TxBytes,
}
err = wsjson.Write(ctx, conn, resp)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
return err
}
}
}()
if err != nil && ctx.Err() == nil {
log.Error(ctx, "report stats", slog.Error(err))
for {
select {
case <-ctx.Done():
return
case <-timer.C:
}
var nextInterval time.Duration
for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); {
resp, err := c.PostAgentStats(ctx, getStats())
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "report stats", slog.Error(err))
}
continue
}
nextInterval = resp.ReportInterval
break
}
timer.Reset(nextInterval)
}
}()
return closeFunc(func() error {
cancel()
<-doneCh
return nil
}), nil
}

View File

@ -6,13 +6,17 @@ import (
"net/http/httptest"
"net/url"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
func TestWorkspaceAgentMetadata(t *testing.T) {
@ -47,3 +51,30 @@ func TestWorkspaceAgentMetadata(t *testing.T) {
require.Equal(t, parsed.Hostname(), node.HostName)
require.Equal(t, parsed.Port(), strconv.Itoa(node.DERPPort))
}
func TestAgentReportStats(t *testing.T) {
t.Parallel()
var numReports atomic.Int64
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numReports.Add(1)
httpapi.Write(context.Background(), w, http.StatusOK, codersdk.AgentStatsResponse{
ReportInterval: 5 * time.Millisecond,
})
}))
parsed, err := url.Parse(srv.URL)
require.NoError(t, err)
client := codersdk.New(parsed)
ctx := context.Background()
closeStream, err := client.AgentReportStats(ctx, slogtest.Make(t, nil), func() *codersdk.AgentStats {
return &codersdk.AgentStats{}
})
require.NoError(t, err)
defer closeStream.Close()
require.Eventually(t,
func() bool { return numReports.Load() >= 3 },
testutil.WaitMedium, testutil.IntervalFast,
)
}

2
go.mod
View File

@ -40,7 +40,7 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0
// There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here:
// https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main
replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc
replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221117204504-2d6503f027c3
// Switch to our fork that imports fixes from http://github.com/tailscale/ssh.
// See: https://github.com/coder/coder/issues/3371

4
go.sum
View File

@ -355,8 +355,8 @@ github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY=
github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY=
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko=
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc h1:qozpteSLz0ifMasetJ+/Qac5Ud/NRNIlgTubGf6TAaQ=
github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc/go.mod h1:lkCb74eSJwxeNq8YwyILoHD5vtHktiZnTOxBxo3tbNc=
github.com/coder/tailscale v1.1.1-0.20221117204504-2d6503f027c3 h1:lq8GmpE5bn8A36uxq1h+TWnaQKPugtRkxKrYZA78O9c=
github.com/coder/tailscale v1.1.1-0.20221117204504-2d6503f027c3/go.mod h1:lkCb74eSJwxeNq8YwyILoHD5vtHktiZnTOxBxo3tbNc=
github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE=
github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU=
github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU=

View File

@ -26,6 +26,7 @@ import (
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
tslogger "tailscale.com/types/logger"
"tailscale.com/types/netlogtype"
"tailscale.com/types/netmap"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
@ -35,15 +36,14 @@ import (
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/wgcfg/nmcfg"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/cryptorand"
"cdr.dev/slog"
)
func init() {
// Globally disable network namespacing.
// All networking happens in userspace.
// Globally disable network namespacing. All networking happens in
// userspace.
netns.SetEnabled(false)
}
@ -55,6 +55,11 @@ type Options struct {
// If so, only DERPs can establish connections.
BlockEndpoints bool
Logger slog.Logger
// EnableTrafficStats enables per-connection traffic statistics.
// ExtractTrafficStats must be called to reset the counters and be
// periodically called while enabled to avoid unbounded memory use.
EnableTrafficStats bool
}
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
@ -143,8 +148,9 @@ func NewConn(options *Options) (*Conn, error) {
}
tunDevice, magicConn, dnsManager, ok := wireguardInternals.GetInternals()
if !ok {
return nil, xerrors.New("failed to get wireguard internals")
return nil, xerrors.New("get wireguard internals")
}
tunDevice.SetStatisticsEnabled(options.EnableTrafficStats)
// Update the keys for the magic connection!
err = magicConn.SetPrivateKey(nodePrivateKey)
@ -649,6 +655,13 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
c.logger.Debug(c.dialContext, "forwarded connection closed", slog.F("local_addr", dialAddrStr))
}
// ExtractTrafficStats extracts and resets the counters for all active
// connections. It must be called periodically otherwise the memory used is
// unbounded. EnableTrafficStats must be true when calling NewConn.
func (c *Conn) ExtractTrafficStats() map[netlogtype.Connection]netlogtype.Counts {
return c.tunDevice.ExtractStatistics()
}
type listenKey struct {
network string
host string