// Source: mirror of https://github.com/coder/coder.git (coderd package).

package coderd
import (
	"context"
	"database/sql"
	"net/http"
	"runtime/pprof"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/hashicorp/yamux"
	"golang.org/x/xerrors"
	"nhooyr.io/websocket"

	"cdr.dev/slog"

	"github.com/coder/coder/v2/coderd/agentapi"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/coder/v2/coderd/database/dbtime"
	"github.com/coder/coder/v2/coderd/httpapi"
	"github.com/coder/coder/v2/coderd/httpmw"
	"github.com/coder/coder/v2/coderd/util/ptr"
	"github.com/coder/coder/v2/codersdk"
)
// @Summary Workspace agent RPC API
|
|
// @ID workspace-agent-rpc-api
|
|
// @Security CoderSessionToken
|
|
// @Tags Agents
|
|
// @Success 101
|
|
// @Router /workspaceagents/me/rpc [get]
|
|
// @x-apidocgen {"skip": true}
|
|
func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
api.WebsocketWaitMutex.Lock()
|
|
api.WebsocketWaitGroup.Add(1)
|
|
api.WebsocketWaitMutex.Unlock()
|
|
defer api.WebsocketWaitGroup.Done()
|
|
workspaceAgent := httpmw.WorkspaceAgent(r)
|
|
|
|
ensureLatestBuildFn, build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Internal error fetching workspace.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
owner, err := api.Database.GetUserByID(ctx, workspace.OwnerID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Internal error fetching user.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
conn, err := websocket.Accept(rw, r, nil)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Failed to accept websocket.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
|
defer wsNetConn.Close()
|
|
|
|
mux, err := yamux.Server(wsNetConn, nil)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Failed to start yamux over websocket.",
|
|
Detail: err.Error(),
|
|
})
|
|
return
|
|
}
|
|
defer mux.Close()
|
|
|
|
api.Logger.Debug(ctx, "accepting agent RPC connection",
|
|
slog.F("owner", owner.Username),
|
|
slog.F("workspace", workspace.Name),
|
|
slog.F("name", workspaceAgent.Name),
|
|
)
|
|
api.Logger.Debug(ctx, "accepting agent details", slog.F("agent", workspaceAgent))
|
|
|
|
defer conn.Close(websocket.StatusNormalClosure, "")
|
|
|
|
pingFn, ok := api.agentConnectionUpdate(ctx, workspaceAgent, build.WorkspaceID, conn)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
agentAPI := agentapi.New(agentapi.Options{
|
|
AgentID: workspaceAgent.ID,
|
|
|
|
Ctx: api.ctx,
|
|
Log: api.Logger,
|
|
Database: api.Database,
|
|
Pubsub: api.Pubsub,
|
|
DerpMapFn: api.DERPMap,
|
|
TailnetCoordinator: &api.TailnetCoordinator,
|
|
TemplateScheduleStore: api.TemplateScheduleStore,
|
|
StatsBatcher: api.statsBatcher,
|
|
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
|
|
PublishWorkspaceAgentLogsUpdateFn: api.publishWorkspaceAgentLogsUpdate,
|
|
|
|
AccessURL: api.AccessURL,
|
|
AppHostname: api.AppHostname,
|
|
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
|
|
AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
|
|
AgentStatsRefreshInterval: api.AgentStatsRefreshInterval,
|
|
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
|
|
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
|
|
DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
|
|
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
|
|
|
// Optional:
|
|
WorkspaceID: build.WorkspaceID, // saves the extra lookup later
|
|
UpdateAgentMetricsFn: api.UpdateAgentMetrics,
|
|
})
|
|
|
|
closeCtx, closeCtxCancel := context.WithCancel(ctx)
|
|
go func() {
|
|
defer closeCtxCancel()
|
|
err := agentAPI.Serve(ctx, mux)
|
|
if err != nil {
|
|
api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err))
|
|
_ = conn.Close(websocket.StatusInternalError, err.Error())
|
|
return
|
|
}
|
|
}()
|
|
|
|
pingFn(closeCtx, ensureLatestBuildFn)
|
|
}
func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logger, rw http.ResponseWriter, workspaceAgent database.WorkspaceAgent) (func() error, database.WorkspaceBuild, bool) {
|
|
resource, err := db.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Internal error fetching workspace agent resource.",
|
|
Detail: err.Error(),
|
|
})
|
|
return nil, database.WorkspaceBuild{}, false
|
|
}
|
|
|
|
build, err := db.GetWorkspaceBuildByJobID(ctx, resource.JobID)
|
|
if err != nil {
|
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
|
Message: "Internal error fetching workspace build job.",
|
|
Detail: err.Error(),
|
|
})
|
|
return nil, database.WorkspaceBuild{}, false
|
|
}
|
|
|
|
// Ensure the resource is still valid!
|
|
// We only accept agents for resources on the latest build.
|
|
ensureLatestBuild := func() error {
|
|
latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if build.ID != latestBuild.ID {
|
|
return xerrors.New("build is outdated")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
err = ensureLatestBuild()
|
|
if err != nil {
|
|
logger.Debug(ctx, "agent tried to connect from non-latest build",
|
|
slog.F("resource", resource),
|
|
slog.F("agent", workspaceAgent),
|
|
)
|
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
|
Message: "Agent trying to connect from non-latest build.",
|
|
Detail: err.Error(),
|
|
})
|
|
return nil, database.WorkspaceBuild{}, false
|
|
}
|
|
|
|
return ensureLatestBuild, build, true
|
|
}
// agentConnectionUpdate records the agent's connection timestamps in the
// database, starts a background ping routine, and returns a blocking
// monitor func that periodically persists connection state, publishes
// workspace updates on status changes, and disconnects the agent when it
// goes inactive or its build becomes outdated. Returns ok=false (after
// closing conn) if the initial timestamp update fails.
func (api *API) agentConnectionUpdate(ctx context.Context, workspaceAgent database.WorkspaceAgent, workspaceID uuid.UUID, conn *websocket.Conn) (func(closeCtx context.Context, ensureLatestBuildFn func() error), bool) {
	// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
	// because we want to log the agent's last ping time.
	var lastPing atomic.Pointer[time.Time]
	lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive.

	// Background pinger: refreshes lastPing on each successful ping and
	// exits when the context is canceled or a ping fails. pprof labels
	// make the goroutine attributable to this agent in profiles.
	go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) {
		// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
		t := time.NewTicker(api.AgentConnectionUpdateFrequency)
		defer t.Stop()

		for {
			select {
			case <-t.C:
			case <-ctx.Done():
				return
			}

			// We don't need a context that times out here because the ping will
			// eventually go through. If the context times out, then other
			// websocket read operations will receive an error, obfuscating the
			// actual problem.
			err := conn.Ping(ctx)
			if err != nil {
				return
			}
			lastPing.Store(ptr.Ref(time.Now()))
		}
	})

	// Preserve an existing FirstConnectedAt; only set it on the agent's
	// first ever connection.
	firstConnectedAt := workspaceAgent.FirstConnectedAt
	if !firstConnectedAt.Valid {
		firstConnectedAt = sql.NullTime{
			Time:  dbtime.Now(),
			Valid: true,
		}
	}
	lastConnectedAt := sql.NullTime{
		Time:  dbtime.Now(),
		Valid: true,
	}
	disconnectedAt := workspaceAgent.DisconnectedAt
	// updateConnectionTimes persists the captured timestamp variables.
	// NOTE: firstConnectedAt/lastConnectedAt/disconnectedAt are shared by
	// this closure, the defer below, and the returned monitor func.
	updateConnectionTimes := func(ctx context.Context) error {
		//nolint:gocritic // We only update ourself.
		err := api.Database.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{
			ID:               workspaceAgent.ID,
			FirstConnectedAt: firstConnectedAt,
			LastConnectedAt:  lastConnectedAt,
			DisconnectedAt:   disconnectedAt,
			UpdatedAt:        dbtime.Now(),
			LastConnectedReplicaID: uuid.NullUUID{
				UUID:  api.ID,
				Valid: true,
			},
		})
		if err != nil {
			return err
		}
		return nil
	}

	// NOTE(review): this defer runs when agentConnectionUpdate returns —
	// i.e. before the caller invokes the returned monitor func — so the
	// "final" disconnect update is written at setup time, not when the
	// connection actually closes. Confirm this is intended; it looks like
	// it may belong inside the returned func.
	defer func() {
		// If connection closed then context will be canceled, try to
		// ensure our final update is sent. By waiting at most the agent
		// inactive disconnect timeout we ensure that we don't block but
		// also guarantee that the agent will be considered disconnected
		// by normal status check.
		//
		// Use a system context as the agent has disconnected and that token
		// may no longer be valid.
		//nolint:gocritic
		ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout)
		defer cancel()

		// Only update timestamp if the disconnect is new.
		if !disconnectedAt.Valid {
			disconnectedAt = sql.NullTime{
				Time:  dbtime.Now(),
				Valid: true,
			}
		}
		err := updateConnectionTimes(ctx)
		if err != nil {
			// This is a bug with unit tests that cancel the app context and
			// cause this error log to be generated. We should fix the unit tests
			// as this is a valid log.
			//
			// The pq error occurs when the server is shutting down.
			if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
				api.Logger.Error(ctx, "failed to update agent disconnect time",
					slog.Error(err),
					slog.F("workspace_id", workspaceID),
				)
			}
		}
		api.publishWorkspaceUpdate(ctx, workspaceID)
	}()

	// Persist the initial connected state; bail out (closing the socket)
	// if the database write fails.
	err := updateConnectionTimes(ctx)
	if err != nil {
		_ = conn.Close(websocket.StatusGoingAway, err.Error())
		return nil, false
	}
	api.publishWorkspaceUpdate(ctx, workspaceID)

	// The returned monitor loop: on each tick, derive connected/disconnected
	// state from the last successful ping, persist it, and publish an update
	// only when the status actually flipped.
	return func(closeCtx context.Context, ensureLatestBuildFn func() error) {
		ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
		defer ticker.Stop()
		for {
			select {
			case <-closeCtx.Done():
				return
			case <-ticker.C:
			}

			// Shadow the atomic pointer with the dereferenced timestamp.
			lastPing := *lastPing.Load()

			var connectionStatusChanged bool
			if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout {
				// No ping within the inactivity window: mark disconnected
				// once (avoid repeated status-change notifications).
				if !disconnectedAt.Valid {
					connectionStatusChanged = true
					disconnectedAt = sql.NullTime{
						Time:  dbtime.Now(),
						Valid: true,
					}
				}
			} else {
				// Agent is responsive; a prior disconnected state counts
				// as a status change (reconnect).
				connectionStatusChanged = disconnectedAt.Valid
				// TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it?
				disconnectedAt = sql.NullTime{}
				lastConnectedAt = sql.NullTime{
					Time:  dbtime.Now(),
					Valid: true,
				}
			}
			// NOTE(review): `err =` assigns the err captured from the
			// enclosing scope, while `err :=` below declares a loop-local
			// one — confirm the mix is intentional.
			err = updateConnectionTimes(ctx)
			if err != nil {
				_ = conn.Close(websocket.StatusGoingAway, err.Error())
				return
			}
			if connectionStatusChanged {
				api.publishWorkspaceUpdate(ctx, workspaceID)
			}
			err := ensureLatestBuildFn()
			if err != nil {
				// Disconnect agents that are no longer valid.
				_ = conn.Close(websocket.StatusGoingAway, "")
				return
			}
		}
	}, true
}