refactor(coderd): fetch owner information when authorizing workspace agent (#9123)

* Refactors the existing httpmw tests to use dbtestutil so that we can test them against a real database if desired,
* Modifies the GetWorkspaceAgentByAuthToken to return the owner and associated roles, removing the need for additional queries
This commit is contained in:
Cian Johnston 2023-08-21 15:49:26 +01:00 committed by GitHub
parent 6d939b726c
commit 5d4a17717f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 313 additions and 170 deletions

View File

@ -1493,13 +1493,12 @@ func (q *querier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]databas
return q.db.GetUsersByIDs(ctx, ids)
}
// GetWorkspaceAgentByAuthToken is used in http middleware to get the workspace agent.
// This should only be used by a system user in that middleware.
func (q *querier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) {
func (q *querier) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) {
// This is a system function
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return database.WorkspaceAgent{}, err
return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, err
}
return q.db.GetWorkspaceAgentByAuthToken(ctx, authToken)
return q.db.GetWorkspaceAgentAndOwnerByAuthToken(ctx, authToken)
}
func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {

View File

@ -1319,10 +1319,6 @@ func (s *MethodTestSuite) TestSystemFunctions() {
dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead)
}))
s.Run("GetWorkspaceAgentByAuthToken", s.Subtest(func(db database.Store, check *expects) {
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{})
check.Args(agt.AuthToken).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(agt)
}))
s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(int64(0))
}))

View File

@ -2962,18 +2962,72 @@ func (q *FakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]datab
return users, nil
}
func (q *FakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) {
func (q *FakeQuerier) GetWorkspaceAgentAndOwnerByAuthToken(_ context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
agent := q.workspaceAgents[i]
if agent.AuthToken == authToken {
return agent, nil
// map of build number -> row
rows := make(map[int32]database.GetWorkspaceAgentAndOwnerByAuthTokenRow)
// We want to return the latest build number
var latestBuildNumber int32
for _, agt := range q.workspaceAgents {
if agt.AuthToken != authToken {
continue
}
// get the related workspace and user
for _, res := range q.workspaceResources {
if agt.ResourceID != res.ID {
continue
}
for _, build := range q.workspaceBuilds {
if build.JobID != res.JobID {
continue
}
for _, ws := range q.workspaces {
if build.WorkspaceID != ws.ID {
continue
}
var row database.GetWorkspaceAgentAndOwnerByAuthTokenRow
row.WorkspaceID = ws.ID
usr, err := q.getUserByIDNoLock(ws.OwnerID)
if err != nil {
return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, sql.ErrNoRows
}
row.OwnerID = usr.ID
row.OwnerRoles = append(usr.RBACRoles, "member")
// We also need to get org roles for the user
row.OwnerName = usr.Username
row.WorkspaceAgent = agt
for _, mem := range q.organizationMembers {
if mem.UserID == usr.ID {
row.OwnerRoles = append(row.OwnerRoles, fmt.Sprintf("organization-member:%s", mem.OrganizationID.String()))
}
}
// And group memberships
for _, groupMem := range q.groupMembers {
if groupMem.UserID == usr.ID {
row.OwnerGroups = append(row.OwnerGroups, groupMem.GroupID.String())
}
}
// Keep track of the latest build number
rows[build.BuildNumber] = row
if build.BuildNumber > latestBuildNumber {
latestBuildNumber = build.BuildNumber
}
}
}
}
}
return database.WorkspaceAgent{}, sql.ErrNoRows
if len(rows) == 0 {
return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, sql.ErrNoRows
}
// Return the row related to the latest build
return rows[latestBuildNumber], nil
}
func (q *FakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {

View File

@ -788,11 +788,11 @@ func (m metricsStore) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]dat
return users, err
}
func (m metricsStore) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) {
func (m metricsStore) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) {
start := time.Now()
agent, err := m.s.GetWorkspaceAgentByAuthToken(ctx, authToken)
m.queryLatencies.WithLabelValues("GetWorkspaceAgentByAuthToken").Observe(time.Since(start).Seconds())
return agent, err
r0, r1 := m.s.GetWorkspaceAgentAndOwnerByAuthToken(ctx, authToken)
m.queryLatencies.WithLabelValues("GetWorkspaceAgentAndOwnerByAuthToken").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {

View File

@ -1631,19 +1631,19 @@ func (mr *MockStoreMockRecorder) GetUsersByIDs(arg0, arg1 interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsersByIDs", reflect.TypeOf((*MockStore)(nil).GetUsersByIDs), arg0, arg1)
}
// GetWorkspaceAgentByAuthToken mocks base method.
func (m *MockStore) GetWorkspaceAgentByAuthToken(arg0 context.Context, arg1 uuid.UUID) (database.WorkspaceAgent, error) {
// GetWorkspaceAgentAndOwnerByAuthToken mocks base method.
func (m *MockStore) GetWorkspaceAgentAndOwnerByAuthToken(arg0 context.Context, arg1 uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceAgentByAuthToken", arg0, arg1)
ret0, _ := ret[0].(database.WorkspaceAgent)
ret := m.ctrl.Call(m, "GetWorkspaceAgentAndOwnerByAuthToken", arg0, arg1)
ret0, _ := ret[0].(database.GetWorkspaceAgentAndOwnerByAuthTokenRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetWorkspaceAgentByAuthToken indicates an expected call of GetWorkspaceAgentByAuthToken.
func (mr *MockStoreMockRecorder) GetWorkspaceAgentByAuthToken(arg0, arg1 interface{}) *gomock.Call {
// GetWorkspaceAgentAndOwnerByAuthToken indicates an expected call of GetWorkspaceAgentAndOwnerByAuthToken.
func (mr *MockStoreMockRecorder) GetWorkspaceAgentAndOwnerByAuthToken(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentByAuthToken", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentByAuthToken), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentAndOwnerByAuthToken", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentAndOwnerByAuthToken), arg0, arg1)
}
// GetWorkspaceAgentByID mocks base method.

View File

@ -156,7 +156,7 @@ type sqlcQuerier interface {
// to look up references to actions. eg. a user could build a workspace
// for another user, then be deleted... we still want them to appear!
GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error)
GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (GetWorkspaceAgentAndOwnerByAuthTokenRow, error)
GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error)
GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentLifecycleStateByIDRow, error)

View File

@ -6382,54 +6382,113 @@ func (q *sqlQuerier) DeleteOldWorkspaceAgentLogs(ctx context.Context) error {
return err
}
const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one
const getWorkspaceAgentAndOwnerByAuthToken = `-- name: GetWorkspaceAgentAndOwnerByAuthToken :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, startup_script_timeout_seconds, expanded_directory, shutdown_script, shutdown_script_timeout_seconds, logs_length, logs_overflowed, startup_script_behavior, started_at, ready_at, subsystems
FROM
workspace_agents
workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.startup_script, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.startup_script_timeout_seconds, workspace_agents.expanded_directory, workspace_agents.shutdown_script, workspace_agents.shutdown_script_timeout_seconds, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.startup_script_behavior, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems,
workspaces.id AS workspace_id,
users.id AS owner_id,
users.username AS owner_name,
users.status AS owner_status,
array_cat(
array_append(users.rbac_roles, 'member'),
array_append(ARRAY[]::text[], 'organization-member:' || organization_members.organization_id::text)
)::text[] as owner_roles,
array_agg(COALESCE(group_members.group_id::text, ''))::text[] AS owner_groups
FROM users
INNER JOIN
workspaces
ON
workspaces.owner_id = users.id
INNER JOIN
workspace_builds
ON
workspace_builds.workspace_id = workspaces.id
INNER JOIN
workspace_resources
ON
workspace_resources.job_id = workspace_builds.job_id
INNER JOIN
workspace_agents
ON
workspace_agents.resource_id = workspace_resources.id
INNER JOIN -- every user is a member of some org
organization_members
ON
organization_members.user_id = users.id
LEFT JOIN -- as they may not be a member of any groups
group_members
ON
group_members.user_id = users.id
WHERE
auth_token = $1
-- TODO: we can add more conditions here, such as:
-- 1) The user must be active
-- 2) The user must not be deleted
-- 3) The workspace must be running
workspace_agents.auth_token = $1
GROUP BY
workspace_agents.id,
workspaces.id,
users.id,
organization_members.organization_id,
workspace_builds.build_number
ORDER BY
created_at DESC
workspace_builds.build_number DESC
LIMIT 1
`
func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceAgentByAuthToken, authToken)
var i WorkspaceAgent
type GetWorkspaceAgentAndOwnerByAuthTokenRow struct {
WorkspaceAgent WorkspaceAgent `db:"workspace_agent" json:"workspace_agent"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
OwnerName string `db:"owner_name" json:"owner_name"`
OwnerStatus UserStatus `db:"owner_status" json:"owner_status"`
OwnerRoles []string `db:"owner_roles" json:"owner_roles"`
OwnerGroups []string `db:"owner_groups" json:"owner_groups"`
}
func (q *sqlQuerier) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (GetWorkspaceAgentAndOwnerByAuthTokenRow, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceAgentAndOwnerByAuthToken, authToken)
var i GetWorkspaceAgentAndOwnerByAuthTokenRow
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.StartupScript,
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.Version,
&i.LastConnectedReplicaID,
&i.ConnectionTimeoutSeconds,
&i.TroubleshootingURL,
&i.MOTDFile,
&i.LifecycleState,
&i.StartupScriptTimeoutSeconds,
&i.ExpandedDirectory,
&i.ShutdownScript,
&i.ShutdownScriptTimeoutSeconds,
&i.LogsLength,
&i.LogsOverflowed,
&i.StartupScriptBehavior,
&i.StartedAt,
&i.ReadyAt,
pq.Array(&i.Subsystems),
&i.WorkspaceAgent.ID,
&i.WorkspaceAgent.CreatedAt,
&i.WorkspaceAgent.UpdatedAt,
&i.WorkspaceAgent.Name,
&i.WorkspaceAgent.FirstConnectedAt,
&i.WorkspaceAgent.LastConnectedAt,
&i.WorkspaceAgent.DisconnectedAt,
&i.WorkspaceAgent.ResourceID,
&i.WorkspaceAgent.AuthToken,
&i.WorkspaceAgent.AuthInstanceID,
&i.WorkspaceAgent.Architecture,
&i.WorkspaceAgent.EnvironmentVariables,
&i.WorkspaceAgent.OperatingSystem,
&i.WorkspaceAgent.StartupScript,
&i.WorkspaceAgent.InstanceMetadata,
&i.WorkspaceAgent.ResourceMetadata,
&i.WorkspaceAgent.Directory,
&i.WorkspaceAgent.Version,
&i.WorkspaceAgent.LastConnectedReplicaID,
&i.WorkspaceAgent.ConnectionTimeoutSeconds,
&i.WorkspaceAgent.TroubleshootingURL,
&i.WorkspaceAgent.MOTDFile,
&i.WorkspaceAgent.LifecycleState,
&i.WorkspaceAgent.StartupScriptTimeoutSeconds,
&i.WorkspaceAgent.ExpandedDirectory,
&i.WorkspaceAgent.ShutdownScript,
&i.WorkspaceAgent.ShutdownScriptTimeoutSeconds,
&i.WorkspaceAgent.LogsLength,
&i.WorkspaceAgent.LogsOverflowed,
&i.WorkspaceAgent.StartupScriptBehavior,
&i.WorkspaceAgent.StartedAt,
&i.WorkspaceAgent.ReadyAt,
pq.Array(&i.WorkspaceAgent.Subsystems),
&i.WorkspaceID,
&i.OwnerID,
&i.OwnerName,
&i.OwnerStatus,
pq.Array(&i.OwnerRoles),
pq.Array(&i.OwnerGroups),
)
return i, err
}

View File

@ -1,13 +1,3 @@
-- name: GetWorkspaceAgentByAuthToken :one
SELECT
*
FROM
workspace_agents
WHERE
auth_token = $1
ORDER BY
created_at DESC;
-- name: GetWorkspaceAgentByID :one
SELECT
*
@ -200,3 +190,56 @@ WHERE
WHERE
wb.workspace_id = @workspace_id :: uuid
);
-- name: GetWorkspaceAgentAndOwnerByAuthToken :one
SELECT
sqlc.embed(workspace_agents),
workspaces.id AS workspace_id,
users.id AS owner_id,
users.username AS owner_name,
users.status AS owner_status,
array_cat(
array_append(users.rbac_roles, 'member'),
array_append(ARRAY[]::text[], 'organization-member:' || organization_members.organization_id::text)
)::text[] as owner_roles,
array_agg(COALESCE(group_members.group_id::text, ''))::text[] AS owner_groups
FROM users
INNER JOIN
workspaces
ON
workspaces.owner_id = users.id
INNER JOIN
workspace_builds
ON
workspace_builds.workspace_id = workspaces.id
INNER JOIN
workspace_resources
ON
workspace_resources.job_id = workspace_builds.job_id
INNER JOIN
workspace_agents
ON
workspace_agents.resource_id = workspace_resources.id
INNER JOIN -- every user is a member of some org
organization_members
ON
organization_members.user_id = users.id
LEFT JOIN -- as they may not be a member of any groups
group_members
ON
group_members.user_id = users.id
WHERE
-- TODO: we can add more conditions here, such as:
-- 1) The user must be active
-- 2) The user must not be deleted
-- 3) The workspace must be running
workspace_agents.auth_token = @auth_token
GROUP BY
workspace_agents.id,
workspaces.id,
users.id,
organization_members.organization_id,
workspace_builds.build_number
ORDER BY
workspace_builds.build_number DESC
LIMIT 1;

View File

@ -74,8 +74,9 @@ func ExtractWorkspaceAgent(opts ExtractWorkspaceAgentConfig) func(http.Handler)
})
return
}
//nolint:gocritic // System needs to be able to get workspace agents.
agent, err := opts.DB.GetWorkspaceAgentByAuthToken(dbauthz.AsSystemRestricted(ctx), token)
row, err := opts.DB.GetWorkspaceAgentAndOwnerByAuthToken(dbauthz.AsSystemRestricted(ctx), token)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
optionalWrite(http.StatusUnauthorized, codersdk.Response{
@ -86,56 +87,23 @@ func ExtractWorkspaceAgent(opts ExtractWorkspaceAgentConfig) func(http.Handler)
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace agent.",
Message: "Internal error checking workspace agent authorization.",
Detail: err.Error(),
})
return
}
//nolint:gocritic // System needs to be able to get workspace agents.
subject, err := getAgentSubject(dbauthz.AsSystemRestricted(ctx), opts.DB, agent)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace agent.",
Detail: err.Error(),
})
return
}
subject := rbac.Subject{
ID: row.OwnerID.String(),
Roles: rbac.RoleNames(row.OwnerRoles),
Groups: row.OwnerGroups,
Scope: rbac.WorkspaceAgentScope(row.WorkspaceID, row.OwnerID),
}.WithCachedASTValue()
ctx = context.WithValue(ctx, workspaceAgentContextKey{}, agent)
ctx = context.WithValue(ctx, workspaceAgentContextKey{}, row.WorkspaceAgent)
// Also set the dbauthz actor for the request.
ctx = dbauthz.As(ctx, subject)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}
func getAgentSubject(ctx context.Context, db database.Store, agent database.WorkspaceAgent) (rbac.Subject, error) {
// TODO: make a different query that gets the workspace owner and roles along with the agent.
workspace, err := db.GetWorkspaceByAgentID(ctx, agent.ID)
if err != nil {
return rbac.Subject{}, err
}
user, err := db.GetUserByID(ctx, workspace.OwnerID)
if err != nil {
return rbac.Subject{}, err
}
roles, err := db.GetAuthorizationUserRoles(ctx, user.ID)
if err != nil {
return rbac.Subject{}, err
}
// A user that creates a workspace can use this agent auth token and
// impersonate the workspace. So to prevent privilege escalation, the
// subject inherits the roles of the user that owns the workspace.
// We then add a workspace-agent scope to limit the permissions
// to only what the workspace agent needs.
return rbac.Subject{
ID: user.ID.String(),
Roles: rbac.RoleNames(roles.Roles),
Groups: roles.Groups,
Scope: rbac.WorkspaceAgentScope(workspace.ID, user.ID),
}.WithCachedASTValue(), nil
}

View File

@ -10,8 +10,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
)
@ -19,26 +19,19 @@ import (
func TestWorkspaceAgent(t *testing.T) {
t.Parallel()
setup := func(db database.Store, token uuid.UUID) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set(codersdk.SessionTokenHeader, token.String())
return r
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{
db, _ := dbtestutil.NewDB(t)
req, rtr := setup(t, db, uuid.New(), httpmw.ExtractWorkspaceAgent(
httpmw.ExtractWorkspaceAgentConfig{
DB: db,
Optional: false,
}),
)
rtr.Get("/", nil)
r := setup(db, uuid.New())
}))
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
req.Header.Set(codersdk.SessionTokenHeader, uuid.New().String())
rtr.ServeHTTP(rw, req)
res := rw.Result()
defer res.Body.Close()
@ -47,42 +40,71 @@ func TestWorkspaceAgent(t *testing.T) {
t.Run("Found", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
var (
user = dbgen.User(t, db, database.User{})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
})
job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{})
resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
JobID: job.ID,
})
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resource.ID,
})
)
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{
db, _ := dbtestutil.NewDB(t)
authToken := uuid.New()
req, rtr := setup(t, db, authToken, httpmw.ExtractWorkspaceAgent(
httpmw.ExtractWorkspaceAgentConfig{
DB: db,
Optional: false,
}),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.WorkspaceAgent(r)
rw.WriteHeader(http.StatusOK)
})
r := setup(db, agent.AuthToken)
}))
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
req.Header.Set(codersdk.SessionTokenHeader, authToken.String())
rtr.ServeHTTP(rw, req)
res := rw.Result()
defer res.Body.Close()
t.Cleanup(func() { _ = res.Body.Close() })
require.Equal(t, http.StatusOK, res.StatusCode)
})
}
func setup(t testing.TB, db database.Store, authToken uuid.UUID, mw func(http.Handler) http.Handler) (*http.Request, http.Handler) {
t.Helper()
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{
Status: database.UserStatusActive,
})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
template := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
ActiveVersionID: templateVersion.ID,
CreatedBy: user.ID,
})
workspace := dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: template.ID,
})
job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
OrganizationID: org.ID,
})
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
JobID: job.ID,
TemplateVersionID: templateVersion.ID,
})
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resource.ID,
AuthToken: authToken,
})
req := httptest.NewRequest("GET", "/", nil)
rtr := chi.NewRouter()
rtr.Use(mw)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.WorkspaceAgent(r)
rw.WriteHeader(http.StatusOK)
})
return req, rtr
}

View File

@ -406,7 +406,9 @@ func TestWorkspaceAgentListen(t *testing.T) {
_, err = agentClient.Listen(ctx)
require.Error(t, err)
require.ErrorContains(t, err, "build is outdated")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusForbidden, sdkErr.StatusCode())
})
}