coder/coderd/workspaceagentsrpc.go

package coderd

import (
	"context"
	"database/sql"
	"fmt"
	"io"
	"net/http"
	"runtime/pprof"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/hashicorp/yamux"
	"golang.org/x/xerrors"
	"nhooyr.io/websocket"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/agent/proto"
	"github.com/coder/coder/v2/coderd/agentapi"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/coder/v2/coderd/database/dbtime"
	"github.com/coder/coder/v2/coderd/httpapi"
	"github.com/coder/coder/v2/coderd/httpmw"
	"github.com/coder/coder/v2/coderd/util/ptr"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/tailnet"
)

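// workspaceAgentRPC negotiates the Agent API version, upgrades the HTTP
// request to a websocket, multiplexes it with yamux, and serves the agent
// RPC API over that session until the agent disconnects.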
// @Summary Workspace agent RPC API
// @ID workspace-agent-rpc-api
// @Security CoderSessionToken
// @Tags Agents
// @Success 101
// @Router /workspaceagents/me/rpc [get]
// @x-apidocgen {"skip": true}
func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	logger := api.Logger.Named("agentrpc")

	version := r.URL.Query().Get("version")
	if version == "" {
		// The initial version on this HTTP endpoint was 2.0, so assume this
		// version if unspecified. Coder v2.7.1 (not to be confused with the
		// Agent API version) calls this endpoint without a version parameter
		// and wants Agent API version 2.0.
		version = "2.0"
	}
	if err := proto.CurrentVersion.Validate(version); err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Unknown or unsupported API version",
			Validations: []codersdk.ValidationError{
				{Field: "version", Detail: err.Error()},
			},
		})
		return
	}

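	// Track this connection so a graceful coderd shutdown can wait for
	// in-flight agent websockets to drain.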
	api.WebsocketWaitMutex.Lock()
	api.WebsocketWaitGroup.Add(1)
	api.WebsocketWaitMutex.Unlock()
	defer api.WebsocketWaitGroup.Done()

	workspaceAgent := httpmw.WorkspaceAgent(r)
	build := httpmw.LatestBuild(r)
	workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching workspace.",
			Detail:  err.Error(),
		})
		return
	}
	owner, err := api.Database.GetUserByID(ctx, workspace.OwnerID)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching user.",
			Detail:  err.Error(),
		})
		return
	}

	logger = logger.With(
		slog.F("owner", owner.Username),
		slog.F("workspace_name", workspace.Name),
		slog.F("agent_name", workspaceAgent.Name),
	)

	conn, err := websocket.Accept(rw, r, nil)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Failed to accept websocket.",
			Detail:  err.Error(),
		})
		return
	}
	ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
	defer wsNetConn.Close()

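	// Layer yamux on top of the websocket so many concurrent RPC streams can
	// share the single connection.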
	ycfg := yamux.DefaultConfig()
	ycfg.LogOutput = nil
	ycfg.Logger = slog.Stdlib(ctx, logger.Named("yamux"), slog.LevelInfo)
	mux, err := yamux.Server(wsNetConn, ycfg)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Failed to start yamux over websocket.",
			Detail:  err.Error(),
		})
		return
	}
	defer mux.Close()

	logger.Debug(ctx, "accepting agent RPC connection", slog.F("agent", workspaceAgent))

	closeCtx, closeCtxCancel := context.WithCancel(ctx)
	defer closeCtxCancel()
	monitor := api.startAgentYamuxMonitor(closeCtx, workspaceAgent, build, mux)
	defer monitor.close()

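	// Wire up the Agent API with everything it needs to serve this agent:
	// database access, pubsub, DERP configuration, stats batching, and log
	// publishing.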
	agentAPI := agentapi.New(agentapi.Options{
		AgentID: workspaceAgent.ID,

		Ctx:                               api.ctx,
		Log:                               logger,
		Database:                          api.Database,
		Pubsub:                            api.Pubsub,
		DerpMapFn:                         api.DERPMap,
		TailnetCoordinator:                &api.TailnetCoordinator,
		TemplateScheduleStore:             api.TemplateScheduleStore,
		AppearanceFetcher:                 &api.AppearanceFetcher,
		StatsBatcher:                      api.statsBatcher,
		PublishWorkspaceUpdateFn:          api.publishWorkspaceUpdate,
		PublishWorkspaceAgentLogsUpdateFn: api.publishWorkspaceAgentLogsUpdate,

		AccessURL:                 api.AccessURL,
		AppHostname:               api.AppHostname,
		AgentStatsRefreshInterval: api.AgentStatsRefreshInterval,
		DisableDirectConnections:  api.DeploymentValues.DERP.Config.BlockDirect.Value(),
		DerpForceWebSockets:       api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
		DerpMapUpdateFrequency:    api.Options.DERPMapUpdateFrequency,
		ExternalAuthConfigs:       api.ExternalAuthConfigs,

		// Optional:
		WorkspaceID:          build.WorkspaceID, // saves the extra lookup later
		UpdateAgentMetricsFn: api.UpdateAgentMetrics,
	})

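	// Identify this stream to the tailnet coordinator. AgentCoordinateeAuth
	// limits the agent to coordinating only as itself.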
	streamID := tailnet.StreamID{
		Name: fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
		ID:   workspaceAgent.ID,
		Auth: tailnet.AgentCoordinateeAuth{ID: workspaceAgent.ID},
	}
	ctx = tailnet.WithStreamID(ctx, streamID)
	ctx = agentapi.WithAPIVersion(ctx, version)
	err = agentAPI.Serve(ctx, mux)
	if err != nil && !xerrors.Is(err, yamux.ErrSessionShutdown) && !xerrors.Is(err, io.EOF) {
		logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err))
		_ = conn.Close(websocket.StatusInternalError, err.Error())
		return
	}
}

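// startAgentWebsocketMonitor starts an agentConnectionMonitor that pings the
// agent directly over its websocket connection.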
func (api *API) startAgentWebsocketMonitor(ctx context.Context,
	workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild,
	conn *websocket.Conn,
) *agentConnectionMonitor {
	monitor := &agentConnectionMonitor{
		apiCtx:            api.ctx,
		workspaceAgent:    workspaceAgent,
		workspaceBuild:    workspaceBuild,
		conn:              conn,
		pingPeriod:        api.AgentConnectionUpdateFrequency,
		db:                api.Database,
		replicaID:         api.ID,
		updater:           api,
		disconnectTimeout: api.AgentInactiveDisconnectTimeout,
		logger: api.Logger.With(
			slog.F("workspace_id", workspaceBuild.WorkspaceID),
			slog.F("agent_id", workspaceAgent.ID),
		),
	}
	monitor.init()
	monitor.start(ctx)
	return monitor
}

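// yamuxPingerCloser adapts a yamux.Session to the pingerCloser interface so
// the connection monitor can treat websocket and yamux transports uniformly.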
type yamuxPingerCloser struct {
	mux *yamux.Session
}

func (y *yamuxPingerCloser) Close(websocket.StatusCode, string) error {
	return y.mux.Close()
}

func (y *yamuxPingerCloser) Ping(ctx context.Context) error {
	errCh := make(chan error, 1)
	go func() {
		_, err := y.mux.Ping()
		errCh <- err
	}()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-errCh:
		return err
	}
}

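// startAgentYamuxMonitor starts an agentConnectionMonitor that pings the
// agent through the yamux session rather than the websocket beneath it.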
func (api *API) startAgentYamuxMonitor(ctx context.Context,
	workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild,
	mux *yamux.Session,
) *agentConnectionMonitor {
	monitor := &agentConnectionMonitor{
		apiCtx:            api.ctx,
		workspaceAgent:    workspaceAgent,
		workspaceBuild:    workspaceBuild,
		conn:              &yamuxPingerCloser{mux: mux},
		pingPeriod:        api.AgentConnectionUpdateFrequency,
		db:                api.Database,
		replicaID:         api.ID,
		updater:           api,
		disconnectTimeout: api.AgentInactiveDisconnectTimeout,
		logger: api.Logger.With(
			slog.F("workspace_id", workspaceBuild.WorkspaceID),
			slog.F("agent_id", workspaceAgent.ID),
		),
	}
	monitor.init()
	monitor.start(ctx)
	return monitor
}

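// workspaceUpdater is the subset of *API the monitor needs in order to
// publish workspace updates.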
type workspaceUpdater interface {
	publishWorkspaceUpdate(ctx context.Context, workspaceID uuid.UUID)
}

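// pingerCloser abstracts the transport the monitor keeps alive: either the
// websocket itself or a yamux session running over it.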
type pingerCloser interface {
	Ping(ctx context.Context) error
	Close(code websocket.StatusCode, reason string) error
}

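// agentConnectionMonitor tracks the liveness of a single agent connection. It
// keeps the agent's connection timestamps up to date in the database and
// closes the connection when pings time out or the workspace build becomes
// outdated.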
type agentConnectionMonitor struct {
	apiCtx         context.Context
	cancel         context.CancelFunc
	wg             sync.WaitGroup
	workspaceAgent database.WorkspaceAgent
	workspaceBuild database.WorkspaceBuild
	conn           pingerCloser
	db             database.Store
	replicaID      uuid.UUID
	updater        workspaceUpdater
	logger         slog.Logger
	pingPeriod     time.Duration

	// state manipulated by both sendPings() and monitor() goroutines: needs to be threadsafe
	lastPing atomic.Pointer[time.Time]

	// state manipulated only by monitor() goroutine: does not need to be threadsafe
	firstConnectedAt  sql.NullTime
	lastConnectedAt   sql.NullTime
	disconnectedAt    sql.NullTime
	disconnectTimeout time.Duration
}

// sendPings sends websocket pings.
//
// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
// because we want to log the agent's last ping time.
func (m *agentConnectionMonitor) sendPings(ctx context.Context) {
	t := time.NewTicker(m.pingPeriod)
	defer t.Stop()

	for {
		select {
		case <-t.C:
		case <-ctx.Done():
			return
		}

		// We don't need a context that times out here because the ping will
		// eventually go through. If the context times out, then other
		// websocket read operations will receive an error, obfuscating the
		// actual problem.
		err := m.conn.Ping(ctx)
		if err != nil {
			return
		}
		m.lastPing.Store(ptr.Ref(time.Now()))
	}
}

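// updateConnectionTimes persists the monitor's current view of the agent's
// first/last connected and disconnected timestamps, along with the replica
// that handled the connection.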
func (m *agentConnectionMonitor) updateConnectionTimes(ctx context.Context) error {
	//nolint:gocritic // We only update the agent we are minding.
	err := m.db.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{
		ID:               m.workspaceAgent.ID,
		FirstConnectedAt: m.firstConnectedAt,
		LastConnectedAt:  m.lastConnectedAt,
		DisconnectedAt:   m.disconnectedAt,
		UpdatedAt:        dbtime.Now(),
		LastConnectedReplicaID: uuid.NullUUID{
			UUID:  m.replicaID,
			Valid: true,
		},
	})
	if err != nil {
		return xerrors.Errorf("failed to update workspace agent connection times: %w", err)
	}
	return nil
}

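// init seeds the connection timestamps from the agent's database row and
// records an initial ping, since the agent just initiated this request.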
func (m *agentConnectionMonitor) init() {
	now := dbtime.Now()
	m.firstConnectedAt = m.workspaceAgent.FirstConnectedAt
	if !m.firstConnectedAt.Valid {
		m.firstConnectedAt = sql.NullTime{
			Time:  now,
			Valid: true,
		}
	}
	m.lastConnectedAt = sql.NullTime{
		Time:  now,
		Valid: true,
	}
	m.disconnectedAt = m.workspaceAgent.DisconnectedAt
	m.lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive.
}

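// start launches the ping and monitor goroutines, labeled with the agent ID
// for pprof.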
func (m *agentConnectionMonitor) start(ctx context.Context) {
	ctx, m.cancel = context.WithCancel(ctx)
	m.wg.Add(2)
	go pprof.Do(ctx, pprof.Labels("agent", m.workspaceAgent.ID.String()),
		func(ctx context.Context) {
			defer m.wg.Done()
			m.sendPings(ctx)
		})
	go pprof.Do(ctx, pprof.Labels("agent", m.workspaceAgent.ID.String()),
		func(ctx context.Context) {
			defer m.wg.Done()
			m.monitor(ctx)
		})
}

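// monitor periodically marks the connection as alive, detects ping timeouts,
// and disconnects agents whose workspace build is no longer the latest. On
// exit it records the disconnect time and publishes a final workspace update.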
func (m *agentConnectionMonitor) monitor(ctx context.Context) {
	defer func() {
		// If the connection closed, the context will be canceled; still try
		// to ensure our final update is sent. By waiting at most the agent
		// inactive disconnect timeout, we ensure that we don't block but also
		// guarantee that the agent will be considered disconnected by the
		// normal status check.
		//
		// Use a system context as the agent has disconnected and that token
		// may no longer be valid.
		//nolint:gocritic
		finalCtx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(m.apiCtx), m.disconnectTimeout)
		defer cancel()

		// Only update the timestamp if the disconnect is new.
		if !m.disconnectedAt.Valid {
			m.disconnectedAt = sql.NullTime{
				Time:  dbtime.Now(),
				Valid: true,
			}
		}
		err := m.updateConnectionTimes(finalCtx)
		if err != nil {
			// This is a bug with unit tests that cancel the app context and
			// cause this error log to be generated. We should fix the unit
			// tests as this is a valid log.
			//
			// The pq error occurs when the server is shutting down.
			if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
				m.logger.Error(finalCtx, "failed to update agent disconnect time",
					slog.Error(err),
				)
			}
		}
		m.updater.publishWorkspaceUpdate(finalCtx, m.workspaceBuild.WorkspaceID)
	}()
	reason := "disconnect"
	defer func() {
		m.logger.Debug(ctx, "agent connection monitor is closing connection",
			slog.F("reason", reason))
		_ = m.conn.Close(websocket.StatusGoingAway, reason)
	}()

	err := m.updateConnectionTimes(ctx)
	if err != nil {
		reason = err.Error()
		return
	}
	m.updater.publishWorkspaceUpdate(ctx, m.workspaceBuild.WorkspaceID)

	ticker := time.NewTicker(m.pingPeriod)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			reason = "canceled"
			return
		case <-ticker.C:
		}
		lastPing := *m.lastPing.Load()
		if time.Since(lastPing) > m.disconnectTimeout {
			reason = "ping timeout"
			m.logger.Warn(ctx, "connection to agent timed out")
			return
		}
		connectionStatusChanged := m.disconnectedAt.Valid
		m.disconnectedAt = sql.NullTime{}
		m.lastConnectedAt = sql.NullTime{
			Time:  dbtime.Now(),
			Valid: true,
		}
		err = m.updateConnectionTimes(ctx)
		if err != nil {
			reason = err.Error()
			if !database.IsQueryCanceledError(err) {
				m.logger.Error(ctx, "failed to update agent connection times", slog.Error(err))
			}
			return
		}
		if connectionStatusChanged {
			m.updater.publishWorkspaceUpdate(ctx, m.workspaceBuild.WorkspaceID)
		}
		err = checkBuildIsLatest(ctx, m.db, m.workspaceBuild)
		if err != nil {
			reason = err.Error()
			m.logger.Info(ctx, "disconnected possibly outdated agent", slog.Error(err))
			return
		}
	}
}

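// close stops the monitor goroutines and waits for them to exit.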
func (m *agentConnectionMonitor) close() {
	m.cancel()
	m.wg.Wait()
}

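// checkBuildIsLatest returns an error if the given build is no longer the
// latest build of its workspace.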
func checkBuildIsLatest(ctx context.Context, db database.Store, build database.WorkspaceBuild) error {
	latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
	if err != nil {
		return err
	}
	if build.ID != latestBuild.ID {
		return xerrors.New("build is outdated")
	}
	return nil
}