diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index c59e50387f..68a2d6a378 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -113,11 +113,9 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { ) api.Logger.Debug(ctx, "accepting agent details", slog.F("agent", workspaceAgent)) - defer conn.Close(websocket.StatusNormalClosure, "") - closeCtx, closeCtxCancel := context.WithCancel(ctx) defer closeCtxCancel() - monitor := api.startAgentWebsocketMonitor(closeCtx, workspaceAgent, build, conn) + monitor := api.startAgentYamuxMonitor(closeCtx, workspaceAgent, build, mux) defer monitor.close() agentAPI := agentapi.New(agentapi.Options{ @@ -214,8 +212,8 @@ func checkBuildIsLatest(ctx context.Context, db database.Store, build database.W func (api *API) startAgentWebsocketMonitor(ctx context.Context, workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild, conn *websocket.Conn, -) *agentWebsocketMonitor { - monitor := &agentWebsocketMonitor{ +) *agentConnectionMonitor { + monitor := &agentConnectionMonitor{ apiCtx: api.ctx, workspaceAgent: workspaceAgent, workspaceBuild: workspaceBuild, @@ -236,6 +234,53 @@ func (api *API) startAgentWebsocketMonitor(ctx context.Context, return monitor } +type yamuxPingerCloser struct { + mux *yamux.Session +} + +func (y *yamuxPingerCloser) Close(websocket.StatusCode, string) error { + return y.mux.Close() +} + +func (y *yamuxPingerCloser) Ping(ctx context.Context) error { + errCh := make(chan error, 1) + go func() { + _, err := y.mux.Ping() + errCh <- err + }() + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + } +} + +func (api *API) startAgentYamuxMonitor(ctx context.Context, + workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild, + mux *yamux.Session, +) *agentConnectionMonitor { + monitor := &agentConnectionMonitor{ + apiCtx: api.ctx, + workspaceAgent: workspaceAgent, + workspaceBuild: workspaceBuild, + conn: &yamuxPingerCloser{mux: mux}, + pingPeriod: api.AgentConnectionUpdateFrequency, + db: api.Database, + replicaID: api.ID, + updater: api, + disconnectTimeout: api.AgentInactiveDisconnectTimeout, + logger: api.Logger.With( + slog.F("workspace_id", workspaceBuild.WorkspaceID), + slog.F("agent_id", workspaceAgent.ID), + ), + } + monitor.init() + monitor.start(ctx) + + return monitor +} + type workspaceUpdater interface { publishWorkspaceUpdate(ctx context.Context, workspaceID uuid.UUID) } @@ -245,7 +290,7 @@ type pingerCloser interface { Close(code websocket.StatusCode, reason string) error } -type agentWebsocketMonitor struct { +type agentConnectionMonitor struct { apiCtx context.Context cancel context.CancelFunc wg sync.WaitGroup @@ -272,7 +317,7 @@ type agentWebsocketMonitor struct { // // We use a custom heartbeat routine here instead of `httpapi.Heartbeat` // because we want to log the agent's last ping time. -func (m *agentWebsocketMonitor) sendPings(ctx context.Context) { +func (m *agentConnectionMonitor) sendPings(ctx context.Context) { t := time.NewTicker(m.pingPeriod) defer t.Stop() @@ -295,7 +340,7 @@ func (m *agentWebsocketMonitor) sendPings(ctx context.Context) { } } -func (m *agentWebsocketMonitor) updateConnectionTimes(ctx context.Context) error { +func (m *agentConnectionMonitor) updateConnectionTimes(ctx context.Context) error { //nolint:gocritic // We only update the agent we are minding. err := m.db.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{ ID: m.workspaceAgent.ID, @@ -314,7 +359,7 @@ func (m *agentWebsocketMonitor) updateConnectionTimes(ctx context.Context) error return nil } -func (m *agentWebsocketMonitor) init() { +func (m *agentConnectionMonitor) init() { now := dbtime.Now() m.firstConnectedAt = m.workspaceAgent.FirstConnectedAt if !m.firstConnectedAt.Valid { @@ -331,7 +376,7 @@ func (m *agentWebsocketMonitor) init() { m.lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive. } -func (m *agentWebsocketMonitor) start(ctx context.Context) { +func (m *agentConnectionMonitor) start(ctx context.Context) { ctx, m.cancel = context.WithCancel(ctx) m.wg.Add(2) go pprof.Do(ctx, pprof.Labels("agent", m.workspaceAgent.ID.String()), @@ -346,7 +391,7 @@ func (m *agentWebsocketMonitor) start(ctx context.Context) { }) } -func (m *agentWebsocketMonitor) monitor(ctx context.Context) { +func (m *agentConnectionMonitor) monitor(ctx context.Context) { defer func() { // If connection closed then context will be canceled, try to // ensure our final update is sent. By waiting at most the agent @@ -384,7 +429,7 @@ func (m *agentWebsocketMonitor) monitor(ctx context.Context) { }() reason := "disconnect" defer func() { - m.logger.Debug(ctx, "agent websocket monitor is closing connection", + m.logger.Debug(ctx, "agent connection monitor is closing connection", slog.F("reason", reason)) _ = m.conn.Close(websocket.StatusGoingAway, reason) }() @@ -409,6 +454,7 @@ func (m *agentWebsocketMonitor) monitor(ctx context.Context) { lastPing := *m.lastPing.Load() if time.Since(lastPing) > m.disconnectTimeout { reason = "ping timeout" + m.logger.Warn(ctx, "connection to agent timed out") return } connectionStatusChanged := m.disconnectedAt.Valid @@ -421,6 +467,7 @@ func (m *agentWebsocketMonitor) monitor(ctx context.Context) { err = m.updateConnectionTimes(ctx) if err != nil { reason = err.Error() + m.logger.Error(ctx, "failed to update agent connection times", slog.Error(err)) return } if connectionStatusChanged { @@ -429,12 +476,13 @@ func (m *agentWebsocketMonitor) monitor(ctx context.Context) { err = checkBuildIsLatest(ctx, m.db, m.workspaceBuild) if err != nil { reason = err.Error() + m.logger.Info(ctx, "disconnected possibly outdated agent", slog.Error(err)) return } } } -func (m *agentWebsocketMonitor) close() { +func (m *agentConnectionMonitor) close() { m.cancel() m.wg.Wait() } diff --git a/coderd/workspaceagentsrpc_internal_test.go b/coderd/workspaceagentsrpc_internal_test.go index 834de4807d..dbae11a218 100644 --- a/coderd/workspaceagentsrpc_internal_test.go +++ b/coderd/workspaceagentsrpc_internal_test.go @@ -23,7 +23,7 @@ import ( "github.com/coder/coder/v2/testutil" ) -func TestAgentWebsocketMonitor_ContextCancel(t *testing.T) { +func TestAgentConnectionMonitor_ContextCancel(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) now := dbtime.Now() @@ -45,7 +45,7 @@ func TestAgentWebsocketMonitor_ContextCancel(t *testing.T) { } replicaID := uuid.New() - uut := &agentWebsocketMonitor{ + uut := &agentConnectionMonitor{ apiCtx: ctx, workspaceAgent: agent, workspaceBuild: build, @@ -97,7 +97,7 @@ func TestAgentWebsocketMonitor_ContextCancel(t *testing.T) { require.Greater(t, m, n) } -func TestAgentWebsocketMonitor_PingTimeout(t *testing.T) { +func TestAgentConnectionMonitor_PingTimeout(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) now := dbtime.Now() @@ -119,7 +119,7 @@ func TestAgentWebsocketMonitor_PingTimeout(t *testing.T) { } replicaID := uuid.New() - uut := &agentWebsocketMonitor{ + uut := &agentConnectionMonitor{ apiCtx: ctx, workspaceAgent: agent, workspaceBuild: build, @@ -157,7 +157,7 @@ func TestAgentWebsocketMonitor_PingTimeout(t *testing.T) { fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID) } -func TestAgentWebsocketMonitor_BuildOutdated(t *testing.T) { +func TestAgentConnectionMonitor_BuildOutdated(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) now := dbtime.Now() @@ -179,7 +179,7 @@ func TestAgentWebsocketMonitor_BuildOutdated(t *testing.T) { } replicaID := uuid.New() - uut := &agentWebsocketMonitor{ + uut := &agentConnectionMonitor{ apiCtx: ctx, workspaceAgent: agent, workspaceBuild: build, @@ -217,12 +217,12 @@ func TestAgentWebsocketMonitor_BuildOutdated(t *testing.T) { fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID) } -func TestAgentWebsocketMonitor_SendPings(t *testing.T) { +func TestAgentConnectionMonitor_SendPings(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) t.Cleanup(cancel) fConn := &fakePingerCloser{} - uut := &agentWebsocketMonitor{ + uut := &agentConnectionMonitor{ pingPeriod: testutil.IntervalFast, conn: fConn, } @@ -238,7 +238,7 @@ func TestAgentWebsocketMonitor_SendPings(t *testing.T) { require.NotNil(t, lastPing) } -func TestAgentWebsocketMonitor_StartClose(t *testing.T) { +func TestAgentConnectionMonitor_StartClose(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) fConn := &fakePingerCloser{} @@ -259,7 +259,7 @@ func TestAgentWebsocketMonitor_StartClose(t *testing.T) { WorkspaceID: uuid.New(), } replicaID := uuid.New() - uut := &agentWebsocketMonitor{ + uut := &agentConnectionMonitor{ apiCtx: ctx, workspaceAgent: agent, workspaceBuild: build,