fix: ensure agent token is from latest build in middleware (#12443)

This commit is contained in:
Garrett Delfosse 2024-03-14 12:27:32 -04:00 committed by GitHub
parent 63696d762f
commit 0723dd3abf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 242 additions and 260 deletions

View File

@ -902,7 +902,7 @@ func New(options *Options) *API {
httpmw.RequireAPIKeyOrWorkspaceProxyAuth(), httpmw.RequireAPIKeyOrWorkspaceProxyAuth(),
).Get("/connection", api.workspaceAgentConnectionGeneric) ).Get("/connection", api.workspaceAgentConnectionGeneric)
r.Route("/me", func(r chi.Router) { r.Route("/me", func(r chi.Router) {
r.Use(httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{ r.Use(httpmw.ExtractWorkspaceAgentAndLatestBuild(httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{
DB: options.Database, DB: options.Database,
Optional: false, Optional: false,
})) }))

View File

@ -1880,12 +1880,12 @@ func (q *querier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]databas
return q.db.GetUsersByIDs(ctx, ids) return q.db.GetUsersByIDs(ctx, ids)
} }
func (q *querier) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { func (q *querier) GetWorkspaceAgentAndLatestBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow, error) {
// This is a system function // This is a system function
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, err return database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}, err
} }
return q.db.GetWorkspaceAgentAndOwnerByAuthToken(ctx, authToken) return q.db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, authToken)
} }
func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {

View File

@ -2274,7 +2274,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
s.Run("GetReplicaByID", s.Subtest(func(db database.Store, check *expects) { s.Run("GetReplicaByID", s.Subtest(func(db database.Store, check *expects) {
check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead).Errors(sql.ErrNoRows) check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead).Errors(sql.ErrNoRows)
})) }))
s.Run("GetWorkspaceAgentAndOwnerByAuthToken", s.Subtest(func(db database.Store, check *expects) { s.Run("GetWorkspaceAgentAndLatestBuildByAuthToken", s.Subtest(func(db database.Store, check *expects) {
check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead).Errors(sql.ErrNoRows) check.Args(uuid.New()).Asserts(rbac.ResourceSystem, rbac.ActionRead).Errors(sql.ErrNoRows)
})) }))
s.Run("GetUserLinksByUserID", s.Subtest(func(db database.Store, check *expects) { s.Run("GetUserLinksByUserID", s.Subtest(func(db database.Store, check *expects) {

View File

@ -69,7 +69,7 @@ func New() database.Store {
templates: make([]database.TemplateTable, 0), templates: make([]database.TemplateTable, 0),
workspaceAgentStats: make([]database.WorkspaceAgentStat, 0), workspaceAgentStats: make([]database.WorkspaceAgentStat, 0),
workspaceAgentLogs: make([]database.WorkspaceAgentLog, 0), workspaceAgentLogs: make([]database.WorkspaceAgentLog, 0),
workspaceBuilds: make([]database.WorkspaceBuildTable, 0), workspaceBuilds: make([]database.WorkspaceBuild, 0),
workspaceApps: make([]database.WorkspaceApp, 0), workspaceApps: make([]database.WorkspaceApp, 0),
workspaces: make([]database.Workspace, 0), workspaces: make([]database.Workspace, 0),
licenses: make([]database.License, 0), licenses: make([]database.License, 0),
@ -171,7 +171,7 @@ type data struct {
workspaceApps []database.WorkspaceApp workspaceApps []database.WorkspaceApp
workspaceAppStatsLastInsertID int64 workspaceAppStatsLastInsertID int64
workspaceAppStats []database.WorkspaceAppStat workspaceAppStats []database.WorkspaceAppStat
workspaceBuilds []database.WorkspaceBuildTable workspaceBuilds []database.WorkspaceBuild
workspaceBuildParameters []database.WorkspaceBuildParameter workspaceBuildParameters []database.WorkspaceBuildParameter
workspaceResourceMetadata []database.WorkspaceResourceMetadatum workspaceResourceMetadata []database.WorkspaceResourceMetadatum
workspaceResources []database.WorkspaceResource workspaceResources []database.WorkspaceResource
@ -542,7 +542,7 @@ func (q *FakeQuerier) templateVersionWithUserNoLock(tpl database.TemplateVersion
return withUser return withUser
} }
func (q *FakeQuerier) workspaceBuildWithUserNoLock(tpl database.WorkspaceBuildTable) database.WorkspaceBuild { func (q *FakeQuerier) workspaceBuildWithUserNoLock(tpl database.WorkspaceBuild) database.WorkspaceBuild {
var user database.User var user database.User
for _, _user := range q.users { for _, _user := range q.users {
if _user.ID == tpl.InitiatorID { if _user.ID == tpl.InitiatorID {
@ -2801,7 +2801,7 @@ func (q *FakeQuerier) GetQuotaConsumedForUser(_ context.Context, userID uuid.UUI
continue continue
} }
var lastBuild database.WorkspaceBuildTable var lastBuild database.WorkspaceBuild
for _, build := range q.workspaceBuilds { for _, build := range q.workspaceBuilds {
if build.WorkspaceID != workspace.ID { if build.WorkspaceID != workspace.ID {
continue continue
@ -3488,7 +3488,7 @@ func (q *FakeQuerier) GetTemplateParameterInsights(ctx context.Context, arg data
defer q.mutex.RUnlock() defer q.mutex.RUnlock()
// WITH latest_workspace_builds ... // WITH latest_workspace_builds ...
latestWorkspaceBuilds := make(map[uuid.UUID]database.WorkspaceBuildTable) latestWorkspaceBuilds := make(map[uuid.UUID]database.WorkspaceBuild)
for _, wb := range q.workspaceBuilds { for _, wb := range q.workspaceBuilds {
if wb.CreatedAt.Before(arg.StartTime) || wb.CreatedAt.Equal(arg.EndTime) || wb.CreatedAt.After(arg.EndTime) { if wb.CreatedAt.Before(arg.StartTime) || wb.CreatedAt.Equal(arg.EndTime) || wb.CreatedAt.After(arg.EndTime) {
continue continue
@ -4270,20 +4270,14 @@ func (q *FakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]datab
return users, nil return users, nil
} }
func (q *FakeQuerier) GetWorkspaceAgentAndOwnerByAuthToken(_ context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { func (q *FakeQuerier) GetWorkspaceAgentAndLatestBuildByAuthToken(_ context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow, error) {
q.mutex.RLock() q.mutex.RLock()
defer q.mutex.RUnlock() defer q.mutex.RUnlock()
rows := []database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}
// map of build number -> row // We want to return the latest build number for each workspace
rows := make(map[int32]database.GetWorkspaceAgentAndOwnerByAuthTokenRow) latestBuildNumber := make(map[uuid.UUID]int32)
// We want to return the latest build number
var latestBuildNumber int32
for _, agt := range q.workspaceAgents { for _, agt := range q.workspaceAgents {
if agt.AuthToken != authToken {
continue
}
// get the related workspace and user // get the related workspace and user
for _, res := range q.workspaceResources { for _, res := range q.workspaceResources {
if agt.ResourceID != res.ID { if agt.ResourceID != res.ID {
@ -4300,47 +4294,43 @@ func (q *FakeQuerier) GetWorkspaceAgentAndOwnerByAuthToken(_ context.Context, au
if ws.Deleted { if ws.Deleted {
continue continue
} }
var row database.GetWorkspaceAgentAndOwnerByAuthTokenRow row := database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{
row.WorkspaceID = ws.ID Workspace: database.Workspace{
row.TemplateID = ws.TemplateID ID: ws.ID,
TemplateID: ws.TemplateID,
},
WorkspaceAgent: agt,
WorkspaceBuild: build,
}
usr, err := q.getUserByIDNoLock(ws.OwnerID) usr, err := q.getUserByIDNoLock(ws.OwnerID)
if err != nil { if err != nil {
return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, sql.ErrNoRows return database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}, sql.ErrNoRows
}
row.OwnerID = usr.ID
row.OwnerRoles = append(usr.RBACRoles, "member")
// We also need to get org roles for the user
row.OwnerName = usr.Username
row.WorkspaceAgent = agt
row.TemplateVersionID = build.TemplateVersionID
for _, mem := range q.organizationMembers {
if mem.UserID == usr.ID {
row.OwnerRoles = append(row.OwnerRoles, fmt.Sprintf("organization-member:%s", mem.OrganizationID.String()))
}
}
// And group memberships
for _, groupMem := range q.groupMembers {
if groupMem.UserID == usr.ID {
row.OwnerGroups = append(row.OwnerGroups, groupMem.GroupID.String())
}
} }
row.Workspace.OwnerID = usr.ID
// Keep track of the latest build number // Keep track of the latest build number
rows[build.BuildNumber] = row rows = append(rows, row)
if build.BuildNumber > latestBuildNumber { if build.BuildNumber > latestBuildNumber[ws.ID] {
latestBuildNumber = build.BuildNumber latestBuildNumber[ws.ID] = build.BuildNumber
} }
} }
} }
} }
} }
if len(rows) == 0 { for i := range rows {
return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, sql.ErrNoRows if rows[i].WorkspaceAgent.AuthToken != authToken {
continue
}
if rows[i].WorkspaceBuild.BuildNumber != latestBuildNumber[rows[i].Workspace.ID] {
continue
}
return rows[i], nil
} }
// Return the row related to the latest build return database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow{}, sql.ErrNoRows
return rows[latestBuildNumber], nil
} }
func (q *FakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { func (q *FakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
@ -6243,7 +6233,7 @@ func (q *FakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.Inser
q.mutex.Lock() q.mutex.Lock()
defer q.mutex.Unlock() defer q.mutex.Unlock()
workspaceBuild := database.WorkspaceBuildTable{ workspaceBuild := database.WorkspaceBuild{
ID: arg.ID, ID: arg.ID,
CreatedAt: arg.CreatedAt, CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt, UpdatedAt: arg.UpdatedAt,

View File

@ -1103,10 +1103,10 @@ func (m metricsStore) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]dat
return users, err return users, err
} }
func (m metricsStore) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { func (m metricsStore) GetWorkspaceAgentAndLatestBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow, error) {
start := time.Now() start := time.Now()
r0, r1 := m.s.GetWorkspaceAgentAndOwnerByAuthToken(ctx, authToken) r0, r1 := m.s.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, authToken)
m.queryLatencies.WithLabelValues("GetWorkspaceAgentAndOwnerByAuthToken").Observe(time.Since(start).Seconds()) m.queryLatencies.WithLabelValues("GetWorkspaceAgentAndLatestBuildByAuthToken").Observe(time.Since(start).Seconds())
return r0, r1 return r0, r1
} }

View File

@ -2295,19 +2295,19 @@ func (mr *MockStoreMockRecorder) GetUsersByIDs(arg0, arg1 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsersByIDs", reflect.TypeOf((*MockStore)(nil).GetUsersByIDs), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsersByIDs", reflect.TypeOf((*MockStore)(nil).GetUsersByIDs), arg0, arg1)
} }
// GetWorkspaceAgentAndOwnerByAuthToken mocks base method. // GetWorkspaceAgentAndLatestBuildByAuthToken mocks base method.
func (m *MockStore) GetWorkspaceAgentAndOwnerByAuthToken(arg0 context.Context, arg1 uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { func (m *MockStore) GetWorkspaceAgentAndLatestBuildByAuthToken(arg0 context.Context, arg1 uuid.UUID) (database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceAgentAndOwnerByAuthToken", arg0, arg1) ret := m.ctrl.Call(m, "GetWorkspaceAgentAndLatestBuildByAuthToken", arg0, arg1)
ret0, _ := ret[0].(database.GetWorkspaceAgentAndOwnerByAuthTokenRow) ret0, _ := ret[0].(database.GetWorkspaceAgentAndLatestBuildByAuthTokenRow)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetWorkspaceAgentAndOwnerByAuthToken indicates an expected call of GetWorkspaceAgentAndOwnerByAuthToken. // GetWorkspaceAgentAndLatestBuildByAuthToken indicates an expected call of GetWorkspaceAgentAndLatestBuildByAuthToken.
func (mr *MockStoreMockRecorder) GetWorkspaceAgentAndOwnerByAuthToken(arg0, arg1 any) *gomock.Call { func (mr *MockStoreMockRecorder) GetWorkspaceAgentAndLatestBuildByAuthToken(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentAndOwnerByAuthToken", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentAndOwnerByAuthToken), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentAndLatestBuildByAuthToken", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentAndLatestBuildByAuthToken), arg0, arg1)
} }
// GetWorkspaceAgentByID mocks base method. // GetWorkspaceAgentByID mocks base method.

View File

@ -230,7 +230,7 @@ type sqlcQuerier interface {
// to look up references to actions. eg. a user could build a workspace // to look up references to actions. eg. a user could build a workspace
// for another user, then be deleted... we still want them to appear! // for another user, then be deleted... we still want them to appear!
GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error)
GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (GetWorkspaceAgentAndOwnerByAuthTokenRow, error) GetWorkspaceAgentAndLatestBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (GetWorkspaceAgentAndLatestBuildByAuthTokenRow, error)
GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error)
GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentLifecycleStateByIDRow, error) GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentLifecycleStateByIDRow, error)

View File

@ -8671,80 +8671,66 @@ func (q *sqlQuerier) DeleteOldWorkspaceAgentLogs(ctx context.Context) error {
return err return err
} }
const getWorkspaceAgentAndOwnerByAuthToken = `-- name: GetWorkspaceAgentAndOwnerByAuthToken :one const getWorkspaceAgentAndLatestBuildByAuthToken = `-- name: GetWorkspaceAgentAndLatestBuildByAuthToken :one
SELECT SELECT
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, workspaces.dormant_at, workspaces.deleting_at, workspaces.automatic_updates, workspaces.favorite,
workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.expanded_directory, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, workspace_agents.display_apps, workspace_agents.api_version, workspace_agents.display_order, workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.expanded_directory, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, workspace_agents.display_apps, workspace_agents.api_version, workspace_agents.display_order,
workspaces.id AS workspace_id, workspace_build_with_user.id, workspace_build_with_user.created_at, workspace_build_with_user.updated_at, workspace_build_with_user.workspace_id, workspace_build_with_user.template_version_id, workspace_build_with_user.build_number, workspace_build_with_user.transition, workspace_build_with_user.initiator_id, workspace_build_with_user.provisioner_state, workspace_build_with_user.job_id, workspace_build_with_user.deadline, workspace_build_with_user.reason, workspace_build_with_user.daily_cost, workspace_build_with_user.max_deadline, workspace_build_with_user.initiator_by_avatar_url, workspace_build_with_user.initiator_by_username
users.id AS owner_id, FROM
users.username AS owner_name, -- Only get the latest build for each workspace
users.status AS owner_status, (
workspaces.template_id AS template_id, SELECT
workspace_builds.template_version_id AS template_version_id, workspace_id, MAX(build_number) as max_build_number
array_cat( FROM
array_append(users.rbac_roles, 'member'), workspace_build_with_user
array_append(ARRAY[]::text[], 'organization-member:' || organization_members.organization_id::text) GROUP BY
)::text[] as owner_roles, workspace_id
array_agg(COALESCE(group_members.group_id::text, ''))::text[] AS owner_groups ) as latest_builds
FROM users -- Pull the workspace_build rows for returning
INNER JOIN INNER JOIN workspace_build_with_user
workspaces ON workspace_build_with_user.workspace_id = latest_builds.workspace_id
ON AND workspace_build_with_user.build_number = latest_builds.max_build_number
workspaces.owner_id = users.id -- For each latest build, grab the resources to relate to an agent
INNER JOIN INNER JOIN workspace_resources
workspace_builds ON workspace_resources.job_id = workspace_build_with_user.job_id
ON -- Agent <-> Resource is 1:1
workspace_builds.workspace_id = workspaces.id INNER JOIN workspace_agents
INNER JOIN ON workspace_agents.resource_id = workspace_resources.id
workspace_resources -- We need the owner ID
ON INNER JOIN workspaces
workspace_resources.job_id = workspace_builds.job_id ON workspace_build_with_user.workspace_id = workspaces.id
INNER JOIN
workspace_agents
ON
workspace_agents.resource_id = workspace_resources.id
INNER JOIN -- every user is a member of some org
organization_members
ON
organization_members.user_id = users.id
LEFT JOIN -- as they may not be a member of any groups
group_members
ON
group_members.user_id = users.id
WHERE WHERE
-- TODO: we can add more conditions here, such as: -- This should only match 1 agent, so 1 returned row or 0
-- 1) The user must be active
-- 2) The workspace must be running
workspace_agents.auth_token = $1 workspace_agents.auth_token = $1
AND AND
workspaces.deleted = FALSE workspaces.deleted = FALSE
GROUP BY
workspace_agents.id,
workspaces.id,
users.id,
organization_members.organization_id,
workspace_builds.build_number,
workspace_builds.template_version_id
ORDER BY
workspace_builds.build_number DESC
LIMIT 1
` `
type GetWorkspaceAgentAndOwnerByAuthTokenRow struct { type GetWorkspaceAgentAndLatestBuildByAuthTokenRow struct {
WorkspaceAgent WorkspaceAgent `db:"workspace_agent" json:"workspace_agent"` Workspace Workspace `db:"workspace" json:"workspace"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` WorkspaceAgent WorkspaceAgent `db:"workspace_agent" json:"workspace_agent"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` WorkspaceBuild WorkspaceBuild `db:"workspace_build" json:"workspace_build"`
OwnerName string `db:"owner_name" json:"owner_name"`
OwnerStatus UserStatus `db:"owner_status" json:"owner_status"`
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"`
OwnerRoles []string `db:"owner_roles" json:"owner_roles"`
OwnerGroups []string `db:"owner_groups" json:"owner_groups"`
} }
func (q *sqlQuerier) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { func (q *sqlQuerier) GetWorkspaceAgentAndLatestBuildByAuthToken(ctx context.Context, authToken uuid.UUID) (GetWorkspaceAgentAndLatestBuildByAuthTokenRow, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceAgentAndOwnerByAuthToken, authToken) row := q.db.QueryRowContext(ctx, getWorkspaceAgentAndLatestBuildByAuthToken, authToken)
var i GetWorkspaceAgentAndOwnerByAuthTokenRow var i GetWorkspaceAgentAndLatestBuildByAuthTokenRow
err := row.Scan( err := row.Scan(
&i.Workspace.ID,
&i.Workspace.CreatedAt,
&i.Workspace.UpdatedAt,
&i.Workspace.OwnerID,
&i.Workspace.OrganizationID,
&i.Workspace.TemplateID,
&i.Workspace.Deleted,
&i.Workspace.Name,
&i.Workspace.AutostartSchedule,
&i.Workspace.Ttl,
&i.Workspace.LastUsedAt,
&i.Workspace.DormantAt,
&i.Workspace.DeletingAt,
&i.Workspace.AutomaticUpdates,
&i.Workspace.Favorite,
&i.WorkspaceAgent.ID, &i.WorkspaceAgent.ID,
&i.WorkspaceAgent.CreatedAt, &i.WorkspaceAgent.CreatedAt,
&i.WorkspaceAgent.UpdatedAt, &i.WorkspaceAgent.UpdatedAt,
@ -8776,14 +8762,22 @@ func (q *sqlQuerier) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, a
pq.Array(&i.WorkspaceAgent.DisplayApps), pq.Array(&i.WorkspaceAgent.DisplayApps),
&i.WorkspaceAgent.APIVersion, &i.WorkspaceAgent.APIVersion,
&i.WorkspaceAgent.DisplayOrder, &i.WorkspaceAgent.DisplayOrder,
&i.WorkspaceID, &i.WorkspaceBuild.ID,
&i.OwnerID, &i.WorkspaceBuild.CreatedAt,
&i.OwnerName, &i.WorkspaceBuild.UpdatedAt,
&i.OwnerStatus, &i.WorkspaceBuild.WorkspaceID,
&i.TemplateID, &i.WorkspaceBuild.TemplateVersionID,
&i.TemplateVersionID, &i.WorkspaceBuild.BuildNumber,
pq.Array(&i.OwnerRoles), &i.WorkspaceBuild.Transition,
pq.Array(&i.OwnerGroups), &i.WorkspaceBuild.InitiatorID,
&i.WorkspaceBuild.ProvisionerState,
&i.WorkspaceBuild.JobID,
&i.WorkspaceBuild.Deadline,
&i.WorkspaceBuild.Reason,
&i.WorkspaceBuild.DailyCost,
&i.WorkspaceBuild.MaxDeadline,
&i.WorkspaceBuild.InitiatorByAvatarUrl,
&i.WorkspaceBuild.InitiatorByUsername,
) )
return i, err return i, err
} }

View File

@ -214,59 +214,37 @@ WHERE
wb.workspace_id = @workspace_id :: uuid wb.workspace_id = @workspace_id :: uuid
); );
-- name: GetWorkspaceAgentAndOwnerByAuthToken :one -- name: GetWorkspaceAgentAndLatestBuildByAuthToken :one
SELECT SELECT
sqlc.embed(workspaces),
sqlc.embed(workspace_agents), sqlc.embed(workspace_agents),
workspaces.id AS workspace_id, sqlc.embed(workspace_build_with_user)
users.id AS owner_id, FROM
users.username AS owner_name, -- Only get the latest build for each workspace
users.status AS owner_status, (
workspaces.template_id AS template_id, SELECT
workspace_builds.template_version_id AS template_version_id, workspace_id, MAX(build_number) as max_build_number
array_cat( FROM
array_append(users.rbac_roles, 'member'), workspace_build_with_user
array_append(ARRAY[]::text[], 'organization-member:' || organization_members.organization_id::text) GROUP BY
)::text[] as owner_roles, workspace_id
array_agg(COALESCE(group_members.group_id::text, ''))::text[] AS owner_groups ) as latest_builds
FROM users -- Pull the workspace_build rows for returning
INNER JOIN INNER JOIN workspace_build_with_user
workspaces ON workspace_build_with_user.workspace_id = latest_builds.workspace_id
ON AND workspace_build_with_user.build_number = latest_builds.max_build_number
workspaces.owner_id = users.id -- For each latest build, grab the resources to relate to an agent
INNER JOIN INNER JOIN workspace_resources
workspace_builds ON workspace_resources.job_id = workspace_build_with_user.job_id
ON -- Agent <-> Resource is 1:1
workspace_builds.workspace_id = workspaces.id INNER JOIN workspace_agents
INNER JOIN ON workspace_agents.resource_id = workspace_resources.id
workspace_resources -- We need the owner ID
ON INNER JOIN workspaces
workspace_resources.job_id = workspace_builds.job_id ON workspace_build_with_user.workspace_id = workspaces.id
INNER JOIN
workspace_agents
ON
workspace_agents.resource_id = workspace_resources.id
INNER JOIN -- every user is a member of some org
organization_members
ON
organization_members.user_id = users.id
LEFT JOIN -- as they may not be a member of any groups
group_members
ON
group_members.user_id = users.id
WHERE WHERE
-- TODO: we can add more conditions here, such as: -- This should only match 1 agent, so 1 returned row or 0
-- 1) The user must be active
-- 2) The workspace must be running
workspace_agents.auth_token = @auth_token workspace_agents.auth_token = @auth_token
AND AND
workspaces.deleted = FALSE workspaces.deleted = FALSE
GROUP BY ;
workspace_agents.id,
workspaces.id,
users.id,
organization_members.organization_id,
workspace_builds.build_number,
workspace_builds.template_version_id
ORDER BY
workspace_builds.build_number DESC
LIMIT 1;

View File

@ -32,7 +32,23 @@ func WorkspaceAgent(r *http.Request) database.WorkspaceAgent {
return user return user
} }
type ExtractWorkspaceAgentConfig struct { type latestBuildContextKey struct{}
func latestBuildOptional(r *http.Request) (database.WorkspaceBuild, bool) {
wb, ok := r.Context().Value(latestBuildContextKey{}).(database.WorkspaceBuild)
return wb, ok
}
// LatestBuild returns the Latest Build from the ExtractLatestBuild handler.
func LatestBuild(r *http.Request) database.WorkspaceBuild {
wb, ok := latestBuildOptional(r)
if !ok {
panic("developer error: agent middleware not provided or was made optional")
}
return wb
}
type ExtractWorkspaceAgentAndLatestBuildConfig struct {
DB database.Store DB database.Store
// Optional indicates whether the middleware should be optional. If true, any // Optional indicates whether the middleware should be optional. If true, any
// requests without the a token or with an invalid token will be allowed to // requests without the a token or with an invalid token will be allowed to
@ -40,8 +56,8 @@ type ExtractWorkspaceAgentConfig struct {
Optional bool Optional bool
} }
// ExtractWorkspaceAgent requires authentication using a valid agent token. // ExtractWorkspaceAgentAndLatestBuild requires authentication using a valid agent token.
func ExtractWorkspaceAgent(opts ExtractWorkspaceAgentConfig) func(http.Handler) http.Handler { func ExtractWorkspaceAgentAndLatestBuild(opts ExtractWorkspaceAgentAndLatestBuildConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
@ -76,7 +92,7 @@ func ExtractWorkspaceAgent(opts ExtractWorkspaceAgentConfig) func(http.Handler)
} }
//nolint:gocritic // System needs to be able to get workspace agents. //nolint:gocritic // System needs to be able to get workspace agents.
row, err := opts.DB.GetWorkspaceAgentAndOwnerByAuthToken(dbauthz.AsSystemRestricted(ctx), token) row, err := opts.DB.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), token)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
optionalWrite(http.StatusUnauthorized, codersdk.Response{ optionalWrite(http.StatusUnauthorized, codersdk.Response{
@ -93,19 +109,30 @@ func ExtractWorkspaceAgent(opts ExtractWorkspaceAgentConfig) func(http.Handler)
return return
} }
//nolint:gocritic // System needs to be able to get owner roles.
roles, err := opts.DB.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), row.Workspace.OwnerID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error checking workspace agent authorization.",
Detail: err.Error(),
})
return
}
subject := rbac.Subject{ subject := rbac.Subject{
ID: row.OwnerID.String(), ID: row.Workspace.OwnerID.String(),
Roles: rbac.RoleNames(row.OwnerRoles), Roles: rbac.RoleNames(roles.Roles),
Groups: row.OwnerGroups, Groups: roles.Groups,
Scope: rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{ Scope: rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{
WorkspaceID: row.WorkspaceID, WorkspaceID: row.Workspace.ID,
OwnerID: row.OwnerID, OwnerID: row.Workspace.OwnerID,
TemplateID: row.TemplateID, TemplateID: row.Workspace.TemplateID,
VersionID: row.TemplateVersionID, VersionID: row.WorkspaceBuild.TemplateVersionID,
}), }),
}.WithCachedASTValue() }.WithCachedASTValue()
ctx = context.WithValue(ctx, workspaceAgentContextKey{}, row.WorkspaceAgent) ctx = context.WithValue(ctx, workspaceAgentContextKey{}, row.WorkspaceAgent)
ctx = context.WithValue(ctx, latestBuildContextKey{}, row.WorkspaceBuild)
// Also set the dbauthz actor for the request. // Also set the dbauthz actor for the request.
ctx = dbauthz.As(ctx, subject) ctx = dbauthz.As(ctx, subject)
next.ServeHTTP(rw, r.WithContext(ctx)) next.ServeHTTP(rw, r.WithContext(ctx))

View File

@ -23,8 +23,8 @@ func TestWorkspaceAgent(t *testing.T) {
t.Parallel() t.Parallel()
db, _ := dbtestutil.NewDB(t) db, _ := dbtestutil.NewDB(t)
req, rtr := setup(t, db, uuid.New(), httpmw.ExtractWorkspaceAgent( req, rtr, _, _ := setup(t, db, uuid.New(), httpmw.ExtractWorkspaceAgentAndLatestBuild(
httpmw.ExtractWorkspaceAgentConfig{ httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{
DB: db, DB: db,
Optional: false, Optional: false,
})) }))
@ -42,8 +42,8 @@ func TestWorkspaceAgent(t *testing.T) {
t.Parallel() t.Parallel()
db, _ := dbtestutil.NewDB(t) db, _ := dbtestutil.NewDB(t)
authToken := uuid.New() authToken := uuid.New()
req, rtr := setup(t, db, authToken, httpmw.ExtractWorkspaceAgent( req, rtr, _, _ := setup(t, db, authToken, httpmw.ExtractWorkspaceAgentAndLatestBuild(
httpmw.ExtractWorkspaceAgentConfig{ httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{
DB: db, DB: db,
Optional: false, Optional: false,
})) }))
@ -57,9 +57,47 @@ func TestWorkspaceAgent(t *testing.T) {
t.Cleanup(func() { _ = res.Body.Close() }) t.Cleanup(func() { _ = res.Body.Close() })
require.Equal(t, http.StatusOK, res.StatusCode) require.Equal(t, http.StatusOK, res.StatusCode)
}) })
t.Run("Latest", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
authToken := uuid.New()
req, rtr, ws, tpv := setup(t, db, authToken, httpmw.ExtractWorkspaceAgentAndLatestBuild(
httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{
DB: db,
Optional: false,
}),
)
// Create a newer build
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: ws.OrganizationID,
})
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: ws.ID,
JobID: job.ID,
TemplateVersionID: tpv.ID,
BuildNumber: 2,
})
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resource.ID,
})
rw := httptest.NewRecorder()
req.Header.Set(codersdk.SessionTokenHeader, authToken.String())
rtr.ServeHTTP(rw, req)
//nolint:bodyclose // Closed in `t.Cleanup`
res := rw.Result()
t.Cleanup(func() { _ = res.Body.Close() })
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
} }
func setup(t testing.TB, db database.Store, authToken uuid.UUID, mw func(http.Handler) http.Handler) (*http.Request, http.Handler) { func setup(t testing.TB, db database.Store, authToken uuid.UUID, mw func(http.Handler) http.Handler) (*http.Request, http.Handler, database.Workspace, database.TemplateVersion) {
t.Helper() t.Helper()
org := dbgen.Organization(t, db, database.Organization{}) org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{ user := dbgen.User(t, db, database.User{
@ -107,5 +145,5 @@ func setup(t testing.TB, db database.Store, authToken uuid.UUID, mw func(http.Ha
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
}) })
return req, rtr return req, rtr, workspace, templateVersion
} }

View File

@ -954,13 +954,9 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
api.WebsocketWaitGroup.Add(1) api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock() api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done() defer api.WebsocketWaitGroup.Done()
// The middleware only accept agents for resources on the latest build.
workspaceAgent := httpmw.WorkspaceAgent(r) workspaceAgent := httpmw.WorkspaceAgent(r)
// Ensure the resource is still valid! build := httpmw.LatestBuild(r)
// We only accept agents for resources on the latest build.
build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent)
if !ok {
return
}
workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID) workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
if err != nil { if err != nil {

View File

@ -404,7 +404,7 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) {
require.Error(t, err) require.Error(t, err)
var sdkErr *codersdk.Error var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr) require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusForbidden, sdkErr.StatusCode()) require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
}) })
t.Run("FailDeleted", func(t *testing.T) { t.Run("FailDeleted", func(t *testing.T) {
@ -488,7 +488,7 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
agentToken, err := uuid.Parse(r.AgentToken) agentToken, err := uuid.Parse(r.AgentToken)
require.NoError(t, err) require.NoError(t, err)
//nolint: gocritic // testing //nolint: gocritic // testing
ao, err := db.GetWorkspaceAgentAndOwnerByAuthToken(dbauthz.AsSystemRestricted(ctx), agentToken) ao, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentToken)
require.NoError(t, err) require.NoError(t, err)
//nolint: bodyclose // closed by ReadBodyAsError //nolint: bodyclose // closed by ReadBodyAsError

View File

@ -61,11 +61,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
api.WebsocketWaitMutex.Unlock() api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done() defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r) workspaceAgent := httpmw.WorkspaceAgent(r)
build := httpmw.LatestBuild(r)
build, ok := ensureLatestBuild(ctx, api.Database, logger, rw, workspaceAgent)
if !ok {
return
}
workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID) workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
if err != nil { if err != nil {
@ -167,54 +163,6 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
} }
} }
func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logger, rw http.ResponseWriter, workspaceAgent database.WorkspaceAgent) (database.WorkspaceBuild, bool) {
resource, err := db.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error fetching workspace agent resource.",
Detail: err.Error(),
})
return database.WorkspaceBuild{}, false
}
build, err := db.GetWorkspaceBuildByJobID(ctx, resource.JobID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error fetching workspace build job.",
Detail: err.Error(),
})
return database.WorkspaceBuild{}, false
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
err = checkBuildIsLatest(ctx, db, build)
if err != nil {
logger.Debug(ctx, "agent tried to connect from non-latest build",
slog.F("resource", resource),
slog.F("agent", workspaceAgent),
)
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "Agent trying to connect from non-latest build.",
Detail: err.Error(),
})
return database.WorkspaceBuild{}, false
}
return build, true
}
func checkBuildIsLatest(ctx context.Context, db database.Store, build database.WorkspaceBuild) error {
latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}
func (api *API) startAgentWebsocketMonitor(ctx context.Context, func (api *API) startAgentWebsocketMonitor(ctx context.Context,
workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild, workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild,
conn *websocket.Conn, conn *websocket.Conn,
@ -494,3 +442,14 @@ func (m *agentConnectionMonitor) close() {
m.cancel() m.cancel()
m.wg.Wait() m.wg.Wait()
} }
func checkBuildIsLatest(ctx context.Context, db database.Store, build database.WorkspaceBuild) error {
latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}

View File

@ -336,7 +336,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use( r.Use(
apiKeyMiddlewareOptional, apiKeyMiddlewareOptional,
httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{ httpmw.ExtractWorkspaceAgentAndLatestBuild(httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{
DB: options.Database, DB: options.Database,
Optional: true, Optional: true,
}), }),