feat: Add initial AuthzQuerier implementation (#5919)

feat: Add initial AuthzQuerier implementation
- Adds package database/dbauthz that adds a database.Store implementation where each method goes through AuthZ checks
- Implements all database.Store methods on AuthzQuerier
- Updates and fixes unit tests where required
- Updates coderd initialization to use AuthzQuerier if codersdk.ExperimentAuthzQuerier is enabled
This commit is contained in:
Steven Masley 2023-02-14 08:27:06 -06:00 committed by GitHub
parent ebdfdc749d
commit 6fb8aff6d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 5013 additions and 136 deletions

View File

@ -15,10 +15,10 @@ import (
// activityBumpWorkspace automatically bumps the workspace's auto-off timer
// if it is set to expire soon.
func activityBumpWorkspace(log slog.Logger, db database.Store, workspaceID uuid.UUID) {
func activityBumpWorkspace(ctx context.Context, log slog.Logger, db database.Store, workspaceID uuid.UUID) {
// We set a short timeout so if the app is under load, these
// low priority operations fail first.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
ctx, cancel := context.WithTimeout(ctx, time.Second*15)
defer cancel()
err := db.InTx(func(s database.Store) error {
@ -82,9 +82,12 @@ func activityBumpWorkspace(log slog.Logger, db database.Store, workspaceID uuid.
return nil
}, nil)
if err != nil {
log.Error(ctx, "bump failed", slog.Error(err),
slog.F("workspace_id", workspaceID),
)
if !xerrors.Is(err, context.Canceled) {
// Bump will fail if the context is cancelled, but this is ok.
log.Error(ctx, "bump failed", slog.Error(err),
slog.F("workspace_id", workspaceID),
)
}
return
}

View File

@ -51,6 +51,28 @@ type HTTPAuthorizer struct {
// return
// }
func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
// The experiment does not replace ALL rbac checks, but does replace most.
// This statement aborts early on the checks that will be removed in the
// future when this experiment is default.
if api.Experiments.Enabled(codersdk.ExperimentAuthzQuerier) {
// Some resource types do not interact with the persistent layer and
// we need to keep these checks happening in the API layer.
switch object.RBACObject().Type {
case rbac.ResourceWorkspaceExecution.Type:
// This is not a db resource, always in API layer
case rbac.ResourceDeploymentConfig.Type:
// For metric cache items like DAU, we do not hit the DB.
// Some db actions are in asserted in the authz layer.
case rbac.ResourceReplicas.Type:
// Replica rbac is checked for adding and removing replicas.
case rbac.ResourceProvisionerDaemon.Type:
// Provisioner rbac is checked for adding and removing provisioners.
case rbac.ResourceDebugInfo.Type:
// This is not a db resource, always in API layer.
default:
return true
}
}
return api.HTTPAuth.Authorize(r, action, object)
}

View File

@ -12,6 +12,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/autobuild/schedule"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
)
// Executor automatically starts or stops workspaces.
@ -33,7 +34,8 @@ type Stats struct {
// New returns a new autobuild executor.
func New(ctx context.Context, db database.Store, log slog.Logger, tick <-chan time.Time) *Executor {
le := &Executor{
ctx: ctx,
//nolint:gocritic // TODO: make an autostart role instead of using System
ctx: dbauthz.AsSystem(ctx),
db: db,
tick: tick,
log: log,

View File

@ -42,6 +42,7 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/awsidentity"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbtype"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/gitsshkey"
@ -157,13 +158,6 @@ func New(options *Options) *API {
options = &Options{}
}
experiments := initExperiments(options.Logger, options.DeploymentConfig.Experiments.Value, options.DeploymentConfig.Experimental.Value)
// TODO: remove this once we promote authz_querier out of experiments.
if experiments.Enabled(codersdk.ExperimentAuthzQuerier) {
panic("Coming soon!")
// if _, ok := (options.Database).(*authzquery.AuthzQuerier); !ok {
// options.Database = authzquery.NewAuthzQuerier(options.Database, options.Authorizer)
// }
}
if options.AppHostname != "" && options.AppHostnameRegex == nil || options.AppHostname == "" && options.AppHostnameRegex != nil {
panic("coderd: both AppHostname and AppHostnameRegex must be set or unset")
}
@ -204,6 +198,14 @@ func New(options *Options) *API {
if options.Auditor == nil {
options.Auditor = audit.NewNop()
}
// TODO: remove this once we promote authz_querier out of experiments.
if experiments.Enabled(codersdk.ExperimentAuthzQuerier) {
options.Database = dbauthz.New(
options.Database,
options.Authorizer,
options.Logger.Named("authz_querier"),
)
}
if options.SetUserGroups == nil {
options.SetUserGroups = func(context.Context, database.Store, uuid.UUID, []string) error { return nil }
}
@ -304,8 +306,10 @@ func New(options *Options) *API {
DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value,
Optional: true,
}),
httpmw.ExtractUserParam(api.Database, false),
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
httpmw.AsAuthzSystem(
httpmw.ExtractUserParam(api.Database, false),
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
),
),
// Build-Version is helpful for debugging.
func(next http.Handler) http.Handler {
@ -332,11 +336,13 @@ func New(options *Options) *API {
DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value,
Optional: true,
}),
// Redirect to the login page if the user tries to open an app with
// "me" as the username and they are not logged in.
httpmw.ExtractUserParam(api.Database, true),
// Extracts the <workspace.agent> from the url
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
httpmw.AsAuthzSystem(
// Redirect to the login page if the user tries to open an app with
// "me" as the username and they are not logged in.
httpmw.ExtractUserParam(api.Database, true),
// Extracts the <workspace.agent> from the url
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
),
)
r.HandleFunc("/*", api.workspaceAppsProxyPath)
}

View File

@ -12,7 +12,6 @@ import (
"testing"
"time"
"github.com/coder/coder/cryptorand"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
@ -20,8 +19,9 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/codersdk"
@ -30,12 +30,6 @@ import (
)
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
// For any route using SQL filters, we need to know if the database is an
// in memory fake. This is because the in memory fake does not use SQL, and
// still uses rego. So this boolean indicates how to assert the expected
// behavior.
_, isMemoryDB := a.api.Database.(dbfake.FakeDatabase)
// Some quick reused objects
workspaceRBACObj := rbac.ResourceWorkspace.WithID(a.Workspace.ID).InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
workspaceExecObj := rbac.ResourceWorkspaceExecution.WithID(a.Workspace.ID).InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
@ -269,16 +263,17 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
"POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
// Endpoints that use the SQLQuery filter.
// For any route using SQL filters, we do not check authorization.
// This is because the in memory fake does not use SQL.
"GET:/api/v2/workspaces/": {
StatusCode: http.StatusOK,
NoAuthorize: !isMemoryDB,
NoAuthorize: true,
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceWorkspace,
},
"GET:/api/v2/organizations/{organization}/templates": {
StatusCode: http.StatusOK,
NoAuthorize: !isMemoryDB,
NoAuthorize: true,
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceTemplate,
},

View File

@ -2,15 +2,21 @@ package coderdtest_test
import (
"context"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
)
func TestAuthorizeAllEndpoints(t *testing.T) {
if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) {
t.Skip("Skipping TestAuthorizeAllEndpoints for authz_querier experiment")
}
t.Parallel()
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
// Required for any subdomain-based proxy tests to pass.

View File

@ -35,6 +35,7 @@ import (
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/afero"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
@ -58,6 +59,7 @@ import (
"github.com/coder/coder/coderd/autobuild/executor"
"github.com/coder/coder/coderd/awsidentity"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/gitsshkey"
@ -179,12 +181,13 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can
options.Database, options.Pubsub = dbtestutil.NewDB(t)
}
// TODO: remove this once we're ready to enable authz querier by default.
if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), "authz_querier") {
panic("Coming soon!")
// if options.Authorizer != nil {
// options.Authorizer = &RecordingAuthorizer{}
// }
// options.Database = authzquery.NewAuthzQuerier(options.Database, options.Authorizer)
if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) {
if options.Authorizer == nil {
options.Authorizer = &RecordingAuthorizer{
Wrapped: rbac.NewAuthorizer(prometheus.NewRegistry()),
}
}
options.Database = dbauthz.New(options.Database, options.Authorizer, slogtest.Make(t, nil).Leveled(slog.LevelDebug))
}
if options.DeploymentConfig == nil {
options.DeploymentConfig = DeploymentConfig(t)

View File

@ -0,0 +1,387 @@
package dbauthz
import (
"context"
"database/sql"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/open-policy-agent/opa/topdown"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/rbac"
)
var _ database.Store = (*querier)(nil)
var (
// NoActorError wraps ErrNoRows for the api to return a 404. This is the correct
// response when the user is not authorized.
NoActorError = xerrors.Errorf("no authorization actor in context: %w", sql.ErrNoRows)
)
// NotAuthorizedError is a sentinel error that unwraps to sql.ErrNoRows.
// This allows the internal error to be read by the caller if needed. Otherwise
// it will be handled as a 404.
type NotAuthorizedError struct {
Err error
}
func (e NotAuthorizedError) Error() string {
return fmt.Sprintf("unauthorized: %s", e.Err.Error())
}
// Unwrap will always unwrap to a sql.ErrNoRows so the API returns a 404.
// So 'errors.Is(err, sql.ErrNoRows)' will always be true.
func (NotAuthorizedError) Unwrap() error {
return sql.ErrNoRows
}
func logNotAuthorizedError(ctx context.Context, logger slog.Logger, err error) error {
// Only log the errors if it is an UnauthorizedError error.
internalError := new(rbac.UnauthorizedError)
if err != nil && xerrors.As(err, &internalError) {
e := new(topdown.Error)
if xerrors.As(err, &e) || e.Code == topdown.CancelErr {
// For some reason rego changes a cancelled context to a topdown.CancelErr. We
// expect to check for cancelled context errors if the user cancels the request,
// so we should change the error to a context.Canceled error.
//
// NotAuthorizedError is == to sql.ErrNoRows, which is not correct
// if it's actually a cancelled context.
internalError.SetInternal(context.Canceled)
return internalError
}
logger.Debug(ctx, "unauthorized",
slog.F("internal", internalError.Internal()),
slog.F("input", internalError.Input()),
slog.Error(err),
)
}
return NotAuthorizedError{
Err: err,
}
}
// querier is a wrapper around the database store that performs authorization
// checks before returning data. All querier methods expect an authorization
// subject present in the context. If no subject is present, most methods will
// fail.
//
// Use WithAuthorizeContext to set the authorization subject in the context for
// the common user case.
type querier struct {
db database.Store
auth rbac.Authorizer
log slog.Logger
}
func New(db database.Store, authorizer rbac.Authorizer, logger slog.Logger) database.Store {
// If the underlying db store is already a querier, return it.
// Do not double wrap.
if _, ok := db.(*querier); ok {
return db
}
return &querier{
db: db,
auth: authorizer,
log: logger,
}
}
// authorizeContext is a helper function to authorize an action on an object.
func (q *querier) authorizeContext(ctx context.Context, action rbac.Action, object rbac.Objecter) error {
act, ok := ActorFromContext(ctx)
if !ok {
return NoActorError
}
err := q.auth.Authorize(ctx, act, action, object.RBACObject())
if err != nil {
return logNotAuthorizedError(ctx, q.log, err)
}
return nil
}
type authContextKey struct{}
// ActorFromContext returns the authorization subject from the context.
// All authentication flows should set the authorization subject in the context.
// If no actor is present, the function returns false.
func ActorFromContext(ctx context.Context) (rbac.Subject, bool) {
a, ok := ctx.Value(authContextKey{}).(rbac.Subject)
return a, ok
}
// AsSystem returns a context with a system actor. This is used for internal
// system operations that are not tied to any particular actor.
// When you use this function, be sure to add a //nolint comment
// explaining why it is necessary.
//
// We trust you have received the usual lecture from the local System
// Administrator. It usually boils down to these three things:
// #1) Respect the privacy of others.
// #2) Think before you type.
// #3) With great power comes great responsibility.
func AsSystem(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, rbac.Subject{
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Name: "system",
DisplayName: "System",
Site: []rbac.Permission{
{
ResourceType: rbac.ResourceWildcard.Type,
Action: rbac.WildcardSymbol,
},
},
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
},
}),
Scope: rbac.ScopeAll,
},
)
}
var AsRemoveActor = rbac.Subject{
ID: "remove-actor",
}
// As returns a context with the given actor stored in the context.
// This is used for cases where the actor touching the database is not the
// actor stored in the context.
// When you use this function, be sure to add a //nolint comment
// explaining why it is necessary.
func As(ctx context.Context, actor rbac.Subject) context.Context {
if actor.Equal(AsRemoveActor) {
// AsRemoveActor is a special case that is used to indicate that the actor
// should be removed from the context.
return context.WithValue(ctx, authContextKey{}, nil)
}
return context.WithValue(ctx, authContextKey{}, actor)
}
//
// Generic functions used to implement the database.Store methods.
//
// insert runs an rbac.ActionCreate on the rbac object argument before
// running the insertFunc. The insertFunc is expected to return the object that
// was inserted.
func insert[
ObjectType any,
ArgumentType any,
Insert func(ctx context.Context, arg ArgumentType) (ObjectType, error),
](
logger slog.Logger,
authorizer rbac.Authorizer,
object rbac.Objecter,
insertFunc Insert,
) Insert {
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
// Fetch the rbac subject
act, ok := ActorFromContext(ctx)
if !ok {
return empty, NoActorError
}
// Authorize the action
err = authorizer.Authorize(ctx, act, rbac.ActionCreate, object.RBACObject())
if err != nil {
return empty, logNotAuthorizedError(ctx, logger, err)
}
// Insert the database object
return insertFunc(ctx, arg)
}
}
func deleteQ[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
Delete func(ctx context.Context, arg ArgumentType) error,
](
logger slog.Logger,
authorizer rbac.Authorizer,
fetchFunc Fetch,
deleteFunc Delete,
) Delete {
return fetchAndExec(logger, authorizer,
rbac.ActionDelete, fetchFunc, deleteFunc)
}
func updateWithReturn[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
UpdateQuery func(ctx context.Context, arg ArgumentType) (ObjectType, error),
](
logger slog.Logger,
authorizer rbac.Authorizer,
fetchFunc Fetch,
updateQuery UpdateQuery,
) UpdateQuery {
return fetchAndQuery(logger, authorizer, rbac.ActionUpdate, fetchFunc, updateQuery)
}
func update[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
Exec func(ctx context.Context, arg ArgumentType) error,
](
logger slog.Logger,
authorizer rbac.Authorizer,
fetchFunc Fetch,
updateExec Exec,
) Exec {
return fetchAndExec(logger, authorizer, rbac.ActionUpdate, fetchFunc, updateExec)
}
// fetch is a generic function that wraps a database
// query function (returns an object and an error) with authorization. The
// returned function has the same arguments as the database function.
//
// The database query function will **ALWAYS** hit the database, even if the
// user cannot read the resource. This is because the resource details are
// required to run a proper authorization check.
func fetch[
ArgumentType any,
ObjectType rbac.Objecter,
DatabaseFunc func(ctx context.Context, arg ArgumentType) (ObjectType, error),
](
logger slog.Logger,
authorizer rbac.Authorizer,
f DatabaseFunc,
) DatabaseFunc {
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
// Fetch the rbac subject
act, ok := ActorFromContext(ctx)
if !ok {
return empty, NoActorError
}
// Fetch the database object
object, err := f(ctx, arg)
if err != nil {
return empty, xerrors.Errorf("fetch object: %w", err)
}
// Authorize the action
err = authorizer.Authorize(ctx, act, rbac.ActionRead, object.RBACObject())
if err != nil {
return empty, logNotAuthorizedError(ctx, logger, err)
}
return object, nil
}
}
// fetchAndExec uses fetchAndQuery but only returns the error. The naming comes
// from SQL 'exec' functions which only return an error.
// See fetchAndQuery for more information.
func fetchAndExec[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
Exec func(ctx context.Context, arg ArgumentType) error,
](
logger slog.Logger,
authorizer rbac.Authorizer,
action rbac.Action,
fetchFunc Fetch,
execFunc Exec,
) Exec {
f := fetchAndQuery(logger, authorizer, action, fetchFunc, func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
return empty, execFunc(ctx, arg)
})
return func(ctx context.Context, arg ArgumentType) error {
_, err := f(ctx, arg)
return err
}
}
// fetchAndQuery is a generic function that wraps a database fetch and query.
// A query has potential side effects in the database (update, delete, etc).
// The fetch is used to know which rbac object the action should be asserted on
// **before** the query runs. The returns from the fetch are only used to
// assert rbac. The final return of this function comes from the Query function.
func fetchAndQuery[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
Query func(ctx context.Context, arg ArgumentType) (ObjectType, error),
](
logger slog.Logger,
authorizer rbac.Authorizer,
action rbac.Action,
fetchFunc Fetch,
queryFunc Query,
) Query {
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
// Fetch the rbac subject
act, ok := ActorFromContext(ctx)
if !ok {
return empty, NoActorError
}
// Fetch the database object
object, err := fetchFunc(ctx, arg)
if err != nil {
return empty, xerrors.Errorf("fetch object: %w", err)
}
// Authorize the action
err = authorizer.Authorize(ctx, act, action, object.RBACObject())
if err != nil {
return empty, logNotAuthorizedError(ctx, logger, err)
}
return queryFunc(ctx, arg)
}
}
// fetchWithPostFilter is like fetch, but works with lists of objects.
// SQL filters are much more optimal.
func fetchWithPostFilter[
ArgumentType any,
ObjectType rbac.Objecter,
DatabaseFunc func(ctx context.Context, arg ArgumentType) ([]ObjectType, error),
](
authorizer rbac.Authorizer,
f DatabaseFunc,
) DatabaseFunc {
return func(ctx context.Context, arg ArgumentType) (empty []ObjectType, err error) {
// Fetch the rbac subject
act, ok := ActorFromContext(ctx)
if !ok {
return empty, NoActorError
}
// Fetch the database object
objects, err := f(ctx, arg)
if err != nil {
return nil, xerrors.Errorf("fetch object: %w", err)
}
// Authorize the action
return rbac.Filter(ctx, authorizer, act, rbac.ActionRead, objects)
}
}
// prepareSQLFilter is a helper function that prepares a SQL filter using the
// given authorization context.
func prepareSQLFilter(ctx context.Context, authorizer rbac.Authorizer, action rbac.Action, resourceType string) (rbac.PreparedAuthorized, error) {
act, ok := ActorFromContext(ctx)
if !ok {
return nil, NoActorError
}
return authorizer.Prepare(ctx, act, action, resourceType)
}

View File

@ -0,0 +1,151 @@
package dbauthz_test
import (
"context"
"reflect"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/rbac"
)
func TestAsNoActor(t *testing.T) {
t.Parallel()
t.Run("AsRemoveActor", func(t *testing.T) {
t.Parallel()
_, ok := dbauthz.ActorFromContext(context.Background())
require.False(t, ok, "no actor should be present")
})
t.Run("AsActor", func(t *testing.T) {
t.Parallel()
ctx := dbauthz.As(context.Background(), coderdtest.RandomRBACSubject())
_, ok := dbauthz.ActorFromContext(ctx)
require.True(t, ok, "actor present")
})
t.Run("DeleteActor", func(t *testing.T) {
t.Parallel()
// First set an actor
ctx := dbauthz.As(context.Background(), coderdtest.RandomRBACSubject())
_, ok := dbauthz.ActorFromContext(ctx)
require.True(t, ok, "actor present")
// Delete the actor
ctx = dbauthz.As(ctx, dbauthz.AsRemoveActor)
_, ok = dbauthz.ActorFromContext(ctx)
require.False(t, ok, "actor should be deleted")
})
}
func TestPing(t *testing.T) {
t.Parallel()
q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{}, slog.Make())
_, err := q.Ping(context.Background())
require.NoError(t, err, "must not error")
}
// TestInTX is not perfect, just checks that it properly checks auth.
func TestInTX(t *testing.T) {
t.Parallel()
db := dbfake.New()
q := dbauthz.New(db, &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: xerrors.New("custom error")},
}, slog.Make())
actor := rbac.Subject{
ID: uuid.NewString(),
Roles: rbac.RoleNames{rbac.RoleOwner()},
Groups: []string{},
Scope: rbac.ScopeAll,
}
w := dbgen.Workspace(t, db, database.Workspace{})
ctx := dbauthz.As(context.Background(), actor)
err := q.InTx(func(tx database.Store) error {
// The inner tx should use the parent's authz
_, err := tx.GetWorkspaceByID(ctx, w.ID)
return err
}, nil)
require.Error(t, err, "must error")
require.ErrorAs(t, err, &dbauthz.NotAuthorizedError{}, "must be an authorized error")
}
// TestNew should not double wrap a querier.
func TestNew(t *testing.T) {
t.Parallel()
var (
db = dbfake.New()
exp = dbgen.Workspace(t, db, database.Workspace{})
rec = &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
}
subj = rbac.Subject{}
ctx = dbauthz.As(context.Background(), rbac.Subject{})
)
// Double wrap should not cause an actual double wrap. So only 1 rbac call
// should be made.
az := dbauthz.New(db, rec, slog.Make())
az = dbauthz.New(az, rec, slog.Make())
w, err := az.GetWorkspaceByID(ctx, exp.ID)
require.NoError(t, err, "must not error")
require.Equal(t, exp, w, "must be equal")
rec.AssertActor(t, subj, rec.Pair(rbac.ActionRead, exp))
require.NoError(t, rec.AllAsserted(), "should only be 1 rbac call")
}
// TestDBAuthzRecursive is a simple test to search for infinite recursion
// bugs. It isn't perfect, and only catches a subset of the possible bugs
// as only the first db call will be made. But it is better than nothing.
func TestDBAuthzRecursive(t *testing.T) {
t.Parallel()
q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
}, slog.Make())
actor := rbac.Subject{
ID: uuid.NewString(),
Roles: rbac.RoleNames{rbac.RoleOwner()},
Groups: []string{},
Scope: rbac.ScopeAll,
}
for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ {
var ins []reflect.Value
ctx := dbauthz.As(context.Background(), actor)
ins = append(ins, reflect.ValueOf(ctx))
method := reflect.TypeOf(q).Method(i)
for i := 2; i < method.Type.NumIn(); i++ {
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
}
if method.Name == "InTx" || method.Name == "Ping" {
continue
}
// Log the name of the last method, so if there is a panic, it is
// easy to know which method failed.
// t.Log(method.Name)
// Call the function. Any infinite recursion will stack overflow.
reflect.ValueOf(q).Method(i).Call(ins)
}
}
func must[T any](value T, err error) T {
if err != nil {
panic(err)
}
return value
}

View File

@ -0,0 +1,17 @@
// Package dbauthz provides an authorization layer on top of the database. This
// package exposes an interface that is currently a 1:1 mapping with
// database.Store.
//
// The same cultural rules apply to this package as they do to database.Store.
// Meaning that each method implemented should keep the number of database
// queries as close to 1 as possible. Each method should do 1 thing, with no
// unexpected side effects (eg: updating multiple tables in a single method).
//
// Do not implement business logic in this package. Only authorization related
// logic should be implemented here. In most cases, this should only be a call to
// the rbac authorizer.
//
// When a new database method is added to database.Store, it should be added to
// this package as well. The unit test "Accounting" will ensure all methods are
// tested. See other unit tests for examples on how to write these.
package dbauthz

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,377 @@
package dbauthz_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"testing"
"golang.org/x/xerrors"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"cdr.dev/slog"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/coderd/util/slice"
)
var (
skipMethods = map[string]string{
"InTx": "Not relevant",
"Ping": "Not relevant",
}
)
// TestMethodTestSuite runs MethodTestSuite.
// In order for 'go test' to run this suite, we need to create
// a normal test function and pass our suite to suite.Run
// nolint: paralleltest
func TestMethodTestSuite(t *testing.T) {
suite.Run(t, new(MethodTestSuite))
}
// MethodTestSuite runs all methods tests for querier. We use
// a test suite so we can account for all functions tested on the querier.
// We can then assert all methods were tested and asserted for proper RBAC
// checks. This forces RBAC checks to be written for all methods.
// Additionally, the way unit tests are written allows for easily executing
// a single test for debugging.
type MethodTestSuite struct {
suite.Suite
// methodAccounting counts all methods called by a 'RunMethodTest'
methodAccounting map[string]int
}
// SetupSuite sets up the suite by creating a map of all methods on querier
// and setting their count to 0.
func (s *MethodTestSuite) SetupSuite() {
az := dbauthz.New(nil, nil, slog.Make())
// Take the underlying type of the interface.
azt := reflect.TypeOf(az).Elem()
s.methodAccounting = make(map[string]int)
for i := 0; i < azt.NumMethod(); i++ {
method := azt.Method(i)
if _, ok := skipMethods[method.Name]; ok {
// We can't use s.T().Skip as this will skip the entire suite.
s.T().Logf("Skipping method %q: %s", method.Name, skipMethods[method.Name])
continue
}
s.methodAccounting[method.Name] = 0
}
}
// TearDownSuite asserts that all methods were called at least once.
func (s *MethodTestSuite) TearDownSuite() {
s.Run("Accounting", func() {
t := s.T()
notCalled := []string{}
for m, c := range s.methodAccounting {
if c <= 0 {
notCalled = append(notCalled, m)
}
}
sort.Strings(notCalled)
for _, m := range notCalled {
t.Errorf("Method never called: %q", m)
}
})
}
// Subtest is a helper function that returns a function that can be passed to
// s.Run(). This function will run the test case for the method that is being
// tested. The check parameter is used to assert the results of the method.
// If the caller does not use the `check` parameter, the test will fail.
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
return func() {
t := s.T()
testName := s.T().Name()
names := strings.Split(testName, "/")
methodName := names[len(names)-1]
s.methodAccounting[methodName]++
db := dbfake.New()
fakeAuthorizer := &coderdtest.FakeAuthorizer{
AlwaysReturn: nil,
}
rec := &coderdtest.RecordingAuthorizer{
Wrapped: fakeAuthorizer,
}
az := dbauthz.New(db, rec, slog.Make())
actor := rbac.Subject{
ID: uuid.NewString(),
Roles: rbac.RoleNames{rbac.RoleOwner()},
Groups: []string{},
Scope: rbac.ScopeAll,
}
ctx := dbauthz.As(context.Background(), actor)
var testCase expects
testCaseF(db, &testCase)
// Check the developer added assertions. If there are no assertions,
// an empty list should be passed.
s.Require().False(testCase.assertions == nil, "rbac assertions not set, use the 'check' parameter")
// Find the method with the name of the test.
var callMethod func(ctx context.Context) ([]reflect.Value, error)
azt := reflect.TypeOf(az)
MethodLoop:
for i := 0; i < azt.NumMethod(); i++ {
method := azt.Method(i)
if method.Name == methodName {
methodF := reflect.ValueOf(az).Method(i)
callMethod = func(ctx context.Context) ([]reflect.Value, error) {
resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.inputs...))
return splitResp(t, resp)
}
break MethodLoop
}
}
require.NotNil(t, callMethod, "method %q does not exist", methodName)
if len(testCase.assertions) > 0 {
// Only run these tests if we know the underlying call makes
// rbac assertions.
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
}
if len(testCase.assertions) > 0 ||
slice.Contains([]string{
"GetAuthorizedWorkspaces",
"GetAuthorizedTemplates",
}, methodName) {
// Some methods do not make RBAC assertions because they use
// SQL. We still want to test that they return an error if the
// actor is not set.
s.NoActorErrorTest(callMethod)
}
// Always run
s.Run("Success", func() {
rec.Reset()
fakeAuthorizer.AlwaysReturn = nil
outputs, err := callMethod(ctx)
s.NoError(err, "method %q returned an error", methodName)
// Some tests may not care about the outputs, so we only assert if
// they are provided.
if testCase.outputs != nil {
// Assert the required outputs
s.Equal(len(testCase.outputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
for i := range outputs {
a, b := testCase.outputs[i].Interface(), outputs[i].Interface()
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
// Order does not matter
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
} else {
s.Equal(a, b, "method %q returned unexpected output %d", methodName, i)
}
}
}
var pairs []coderdtest.ActionObjectPair
for _, assrt := range testCase.assertions {
for _, action := range assrt.Actions {
pairs = append(pairs, coderdtest.ActionObjectPair{
Action: action,
Object: assrt.Object,
})
}
}
rec.AssertActor(s.T(), actor, pairs...)
s.NoError(rec.AllAsserted(), "all rbac calls must be asserted")
})
}
}
func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context) ([]reflect.Value, error)) {
s.Run("AsRemoveActor", func() {
// Call without any actor
_, err := callMethod(context.Background())
s.ErrorIs(err, dbauthz.NoActorError, "method should return NoActorError error when no actor is provided")
})
}
// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz.
// Asserts that the error returned is a NotAuthorizedError.
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
s.Run("NotAuthorized", func() {
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
// If we have assertions, that means the method should FAIL
// if RBAC will disallow the request. The returned error should
// be expected to be a NotAuthorizedError.
resp, err := callMethod(ctx)
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
// any case where the error is nil and the response is an empty slice.
if err != nil || !hasEmptySliceResponse(resp) {
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
s.Errorf(err, "method should an error with disallow authz")
s.ErrorIsf(err, sql.ErrNoRows, "error should match sql.ErrNoRows")
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
}
})
}
func hasEmptySliceResponse(values []reflect.Value) bool {
for _, r := range values {
if r.Kind() == reflect.Slice || r.Kind() == reflect.Array {
if r.Len() == 0 {
return true
}
}
}
return false
}
func splitResp(t *testing.T, values []reflect.Value) ([]reflect.Value, error) {
outputs := []reflect.Value{}
for _, r := range values {
if r.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
if r.IsNil() {
// Error is found, but it's nil!
return outputs, nil
}
err, ok := r.Interface().(error)
if !ok {
t.Fatal("error is not an error?!")
}
return outputs, err
}
outputs = append(outputs, r)
} //nolint: unreachable
t.Fatal("no expected error value found in responses (error can be nil)")
return nil, nil // unreachable, required to compile
}
// expects is used to build a test case for a method.
// It includes the expected inputs, rbac assertions, and expected outputs.
type expects struct {
inputs []reflect.Value
assertions []AssertRBAC
// outputs is optional. Can assert non-error return values.
outputs []reflect.Value
}
// Asserts is required. Asserts the RBAC authorize calls that should be made.
// If no RBAC calls are expected, pass an empty list: 'm.Asserts()'
func (m *expects) Asserts(pairs ...any) *expects {
m.assertions = asserts(pairs...)
return m
}
// Args is required. The arguments to be provided to the method.
// If there are no arguments, pass an empty list: 'm.Args()'
// The first context argument should not be included, as the test suite
// will provide it.
func (m *expects) Args(args ...any) *expects {
m.inputs = values(args...)
return m
}
// Returns is optional. If it is never called, it will not be asserted.
func (m *expects) Returns(rets ...any) *expects {
m.outputs = values(rets...)
return m
}
// AssertRBAC contains the object and actions to be asserted.
type AssertRBAC struct {
Object rbac.Object
Actions []rbac.Action
}
// values is a convenience method for creating []reflect.Value.
//
// values(workspace, template, ...)
//
// is equivalent to
//
// []reflect.Value{
// reflect.ValueOf(workspace),
// reflect.ValueOf(template),
// ...
// }
func values(ins ...any) []reflect.Value {
out := make([]reflect.Value, 0)
for _, input := range ins {
input := input
out = append(out, reflect.ValueOf(input))
}
return out
}
// asserts is a convenience method for creating AssertRBACs.
//
// The number of inputs must be an even number.
// asserts() will panic if this is not the case.
//
// Even-numbered inputs are the objects, and odd-numbered inputs are the actions.
// Objects must implement rbac.Objecter.
// Inputs can be a single rbac.Action, or a slice of rbac.Action.
//
// asserts(workspace, rbac.ActionRead, template, slice(rbac.ActionRead, rbac.ActionWrite), ...)
//
// is equivalent to
//
// []AssertRBAC{
// {Object: workspace, Actions: []rbac.Action{rbac.ActionRead}},
// {Object: template, Actions: []rbac.Action{rbac.ActionRead, rbac.ActionWrite)}},
// ...
// }
func asserts(inputs ...any) []AssertRBAC {
if len(inputs)%2 != 0 {
panic(fmt.Sprintf("Must be an even length number of args, found %d", len(inputs)))
}
out := make([]AssertRBAC, 0)
for i := 0; i < len(inputs); i += 2 {
obj, ok := inputs[i].(rbac.Objecter)
if !ok {
panic(fmt.Sprintf("object type '%T' does not implement rbac.Objecter", inputs[i]))
}
rbacObj := obj.RBACObject()
var actions []rbac.Action
actions, ok = inputs[i+1].([]rbac.Action)
if !ok {
action, ok := inputs[i+1].(rbac.Action)
if !ok {
// Could be the string type.
actionAsString, ok := inputs[i+1].(string)
if !ok {
panic(fmt.Sprintf("action '%q' not a supported action", actionAsString))
}
action = rbac.Action(actionAsString)
}
actions = []rbac.Action{action}
}
out = append(out, AssertRBAC{
Object: rbacObj,
Actions: actions,
})
}
return out
}
type emptyPreparedAuthorized struct{}
func (emptyPreparedAuthorized) Authorize(_ context.Context, _ rbac.Object) error { return nil }
func (emptyPreparedAuthorized) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
return "", nil
}

View File

@ -0,0 +1,194 @@
package dbauthz
import (
"context"
"time"
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
)
// TODO: All these system functions should have rbac objects created to allow
// only system roles to call them. No user roles should ever have the permission
// to these objects. Might need a negative permission on the `Owner` role to
// prevent owners.
func (q *querier) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) {
return q.db.UpdateUserLinkedID(ctx, arg)
}
func (q *querier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
return q.db.GetUserLinkByLinkedID(ctx, linkedID)
}
func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
return q.db.GetUserLinkByUserIDLoginType(ctx, arg)
}
func (q *querier) GetLatestWorkspaceBuilds(ctx context.Context) ([]database.WorkspaceBuild, error) {
// This function is a system function until we implement a join for workspace builds.
// This is because we need to query for all related workspaces to the returned builds.
// This is a very inefficient method of fetching the latest workspace builds.
// We should just join the rbac properties.
return q.db.GetLatestWorkspaceBuilds(ctx)
}
// GetWorkspaceAgentByAuthToken is used in http middleware to get the workspace agent.
// This should only be used by a system user in that middleware.
func (q *querier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) {
return q.db.GetWorkspaceAgentByAuthToken(ctx, authToken)
}
func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) {
return q.db.GetActiveUserCount(ctx)
}
func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, error) {
return q.db.GetUnexpiredLicenses(ctx)
}
func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) {
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
// TODO Implement authz check for system user.
return q.db.GetDERPMeshKey(ctx)
}
func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error {
// TODO Implement authz check for system user.
return q.db.InsertDERPMeshKey(ctx, value)
}
func (q *querier) InsertDeploymentID(ctx context.Context, value string) error {
// TODO Implement authz check for system user.
return q.db.InsertDeploymentID(ctx, value)
}
func (q *querier) InsertReplica(ctx context.Context, arg database.InsertReplicaParams) (database.Replica, error) {
// TODO Implement authz check for system user.
return q.db.InsertReplica(ctx, arg)
}
func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaParams) (database.Replica, error) {
// TODO Implement authz check for system user.
return q.db.UpdateReplica(ctx, arg)
}
func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error {
// TODO Implement authz check for system user.
return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt)
}
func (q *querier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.Replica, error) {
// TODO Implement authz check for system user.
return q.db.GetReplicasUpdatedAfter(ctx, updatedAt)
}
func (q *querier) GetUserCount(ctx context.Context) (int64, error) {
return q.db.GetUserCount(ctx)
}
func (q *querier) GetTemplates(ctx context.Context) ([]database.Template, error) {
// TODO Implement authz check for system user.
return q.db.GetTemplates(ctx)
}
// UpdateWorkspaceBuildCostByID is used by the provisioning system to update the cost of a workspace build.
func (q *querier) UpdateWorkspaceBuildCostByID(ctx context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) (database.WorkspaceBuild, error) {
return q.db.UpdateWorkspaceBuildCostByID(ctx, arg)
}
func (q *querier) InsertOrUpdateLastUpdateCheck(ctx context.Context, value string) error {
return q.db.InsertOrUpdateLastUpdateCheck(ctx, value)
}
func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) {
return q.db.GetLastUpdateCheck(ctx)
}
// Telemetry related functions. These functions are system functions for returning
// telemetry data. Never called by a user.
func (q *querier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceBuild, error) {
return q.db.GetWorkspaceBuildsCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
return q.db.GetWorkspaceAgentsCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceAppsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceApp, error) {
return q.db.GetWorkspaceAppsCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceResourcesCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResource, error) {
return q.db.GetWorkspaceResourcesCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResourceMetadatum, error) {
return q.db.GetWorkspaceResourceMetadataCreatedAfter(ctx, createdAt)
}
func (q *querier) DeleteOldAgentStats(ctx context.Context) error {
return q.db.DeleteOldAgentStats(ctx)
}
func (q *querier) GetParameterSchemasCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ParameterSchema, error) {
return q.db.GetParameterSchemasCreatedAfter(ctx, createdAt)
}
func (q *querier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ProvisionerJob, error) {
return q.db.GetProvisionerJobsCreatedAfter(ctx, createdAt)
}
// Provisionerd server functions
func (q *querier) InsertWorkspaceAgent(ctx context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) {
return q.db.InsertWorkspaceAgent(ctx, arg)
}
func (q *querier) InsertWorkspaceApp(ctx context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) {
return q.db.InsertWorkspaceApp(ctx, arg)
}
func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) {
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
}
func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) {
return q.db.AcquireProvisionerJob(ctx, arg)
}
func (q *querier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error {
return q.db.UpdateProvisionerJobWithCompleteByID(ctx, arg)
}
func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
return q.db.UpdateProvisionerJobByID(ctx, arg)
}
func (q *querier) InsertProvisionerJob(ctx context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) {
return q.db.InsertProvisionerJob(ctx, arg)
}
func (q *querier) InsertProvisionerJobLogs(ctx context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) {
return q.db.InsertProvisionerJobLogs(ctx, arg)
}
func (q *querier) InsertProvisionerDaemon(ctx context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) {
return q.db.InsertProvisionerDaemon(ctx, arg)
}
func (q *querier) InsertTemplateVersionParameter(ctx context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) {
return q.db.InsertTemplateVersionParameter(ctx, arg)
}
func (q *querier) InsertWorkspaceResource(ctx context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) {
return q.db.InsertWorkspaceResource(ctx, arg)
}
func (q *querier) InsertParameterSchema(ctx context.Context, arg database.InsertParameterSchemaParams) (database.ParameterSchema, error) {
return q.db.InsertParameterSchema(ctx, arg)
}

View File

@ -0,0 +1,219 @@
package dbauthz_test
import (
"context"
"database/sql"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbgen"
)
func (s *MethodTestSuite) TestSystemFunctions() {
s.Run("UpdateUserLinkedID", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
l := dbgen.UserLink(s.T(), db, database.UserLink{UserID: u.ID})
check.Args(database.UpdateUserLinkedIDParams{
UserID: u.ID,
LinkedID: l.LinkedID,
LoginType: database.LoginTypeGithub,
}).Asserts().Returns(l)
}))
s.Run("GetUserLinkByLinkedID", s.Subtest(func(db database.Store, check *expects) {
l := dbgen.UserLink(s.T(), db, database.UserLink{})
check.Args(l.LinkedID).Asserts().Returns(l)
}))
s.Run("GetUserLinkByUserIDLoginType", s.Subtest(func(db database.Store, check *expects) {
l := dbgen.UserLink(s.T(), db, database.UserLink{})
check.Args(database.GetUserLinkByUserIDLoginTypeParams{
UserID: l.UserID,
LoginType: l.LoginType,
}).Asserts().Returns(l)
}))
s.Run("GetLatestWorkspaceBuilds", s.Subtest(func(db database.Store, check *expects) {
dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
check.Args().Asserts()
}))
s.Run("GetWorkspaceAgentByAuthToken", s.Subtest(func(db database.Store, check *expects) {
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{})
check.Args(agt.AuthToken).Asserts().Returns(agt)
}))
s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts().Returns(int64(0))
}))
s.Run("GetUnexpiredLicenses", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts()
}))
s.Run("GetAuthorizationUserRoles", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
check.Args(u.ID).Asserts()
}))
s.Run("GetDERPMeshKey", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts()
}))
s.Run("InsertDERPMeshKey", s.Subtest(func(db database.Store, check *expects) {
check.Args("value").Asserts().Returns()
}))
s.Run("InsertDeploymentID", s.Subtest(func(db database.Store, check *expects) {
check.Args("value").Asserts().Returns()
}))
s.Run("InsertReplica", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertReplicaParams{
ID: uuid.New(),
}).Asserts()
}))
s.Run("UpdateReplica", s.Subtest(func(db database.Store, check *expects) {
replica, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New()})
require.NoError(s.T(), err)
check.Args(database.UpdateReplicaParams{
ID: replica.ID,
DatabaseLatency: 100,
}).Asserts()
}))
s.Run("DeleteReplicasUpdatedBefore", s.Subtest(func(db database.Store, check *expects) {
_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()})
require.NoError(s.T(), err)
check.Args(time.Now().Add(time.Hour)).Asserts()
}))
s.Run("GetReplicasUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()})
require.NoError(s.T(), err)
check.Args(time.Now().Add(time.Hour * -1)).Asserts()
}))
s.Run("GetUserCount", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts().Returns(int64(0))
}))
s.Run("GetTemplates", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.Template(s.T(), db, database.Template{})
check.Args().Asserts()
}))
s.Run("UpdateWorkspaceBuildCostByID", s.Subtest(func(db database.Store, check *expects) {
b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
o := b
o.DailyCost = 10
check.Args(database.UpdateWorkspaceBuildCostByIDParams{
ID: b.ID,
DailyCost: 10,
}).Asserts().Returns(o)
}))
s.Run("InsertOrUpdateLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) {
check.Args("value").Asserts()
}))
s.Run("GetLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) {
err := db.InsertOrUpdateLastUpdateCheck(context.Background(), "value")
require.NoError(s.T(), err)
check.Args().Asserts()
}))
s.Run("GetWorkspaceBuildsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetWorkspaceAgentsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetWorkspaceAppsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetWorkspaceResourcesCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetWorkspaceResourceMetadataCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceResourceMetadatums(s.T(), db, database.WorkspaceResourceMetadatum{})
check.Args(time.Now()).Asserts()
}))
s.Run("DeleteOldAgentStats", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts()
}))
s.Run("GetParameterSchemasCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.ParameterSchema(s.T(), db, database.ParameterSchema{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetProvisionerJobsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("InsertWorkspaceAgent", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertWorkspaceAgentParams{
ID: uuid.New(),
}).Asserts()
}))
s.Run("InsertWorkspaceApp", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertWorkspaceAppParams{
ID: uuid.New(),
Health: database.WorkspaceAppHealthDisabled,
SharingLevel: database.AppSharingLevelOwner,
}).Asserts()
}))
s.Run("InsertWorkspaceResourceMetadata", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertWorkspaceResourceMetadataParams{
WorkspaceResourceID: uuid.New(),
}).Asserts()
}))
s.Run("AcquireProvisionerJob", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
StartedAt: sql.NullTime{Valid: false},
})
check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}}).
Asserts()
}))
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
check.Args(database.UpdateProvisionerJobWithCompleteByIDParams{
ID: j.ID,
}).Asserts()
}))
s.Run("UpdateProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
check.Args(database.UpdateProvisionerJobByIDParams{
ID: j.ID,
UpdatedAt: time.Now(),
}).Asserts()
}))
s.Run("InsertProvisionerJob", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
}).Asserts()
}))
s.Run("InsertProvisionerJobLogs", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
check.Args(database.InsertProvisionerJobLogsParams{
JobID: j.ID,
}).Asserts()
}))
s.Run("InsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertProvisionerDaemonParams{
ID: uuid.New(),
}).Asserts()
}))
s.Run("InsertTemplateVersionParameter", s.Subtest(func(db database.Store, check *expects) {
v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{})
check.Args(database.InsertTemplateVersionParameterParams{
TemplateVersionID: v.ID,
}).Asserts()
}))
s.Run("InsertWorkspaceResource", s.Subtest(func(db database.Store, check *expects) {
r := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{})
check.Args(database.InsertWorkspaceResourceParams{
ID: r.ID,
Transition: database.WorkspaceTransitionStart,
}).Asserts()
}))
s.Run("InsertParameterSchema", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertParameterSchemaParams{
ID: uuid.New(),
DefaultSourceScheme: database.ParameterSourceSchemeNone,
DefaultDestinationScheme: database.ParameterDestinationSchemeNone,
ValidationTypeSystem: database.ParameterTypeSystemNone,
}).Asserts()
}))
}

View File

@ -614,6 +614,14 @@ func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params databas
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return -1, err
}
}
users := make([]database.User, 0, len(q.users))
for _, user := range q.users {
@ -892,6 +900,14 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
q.mutex.RLock()
defer q.mutex.RUnlock()
if prepared != nil {
// Call this to match the same function calls as the SQL implementation.
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return nil, err
}
}
workspaces := make([]database.Workspace, 0)
for _, workspace := range q.workspaces {
if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID {
@ -1230,6 +1246,23 @@ func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg databa
return database.Workspace{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) {
if err := validateDatabaseType(workspaceAppID); err != nil {
return database.Workspace{}, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, workspaceApp := range q.workspaceApps {
workspaceApp := workspaceApp
if workspaceApp.ID == workspaceAppID {
return q.GetWorkspaceByAgentID(context.Background(), workspaceApp.AgentID)
}
}
return database.Workspace{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -1646,6 +1679,14 @@ func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.G
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL())
if err != nil {
return nil, err
}
}
var templates []database.Template
for _, template := range q.templates {
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
@ -3819,6 +3860,18 @@ func (q *fakeQuerier) InsertLicense(
return l, nil
}
func (q *fakeQuerier) GetLicenseByID(_ context.Context, id int32) (database.License, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, license := range q.licenses {
if license.ID == id {
return license, nil
}
}
return database.License{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

View File

@ -66,7 +66,7 @@ func Template(t testing.TB, db database.Store, seed database.Template) database.
UserACL: seed.UserACL,
GroupACL: seed.GroupACL,
DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)),
AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true),
AllowUserCancelWorkspaceJobs: seed.AllowUserCancelWorkspaceJobs,
})
require.NoError(t, err, "insert template")
return template
@ -369,11 +369,8 @@ func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) dat
func TemplateVersion(t testing.TB, db database.Store, orig database.TemplateVersion) database.TemplateVersion {
version, err := db.InsertTemplateVersion(context.Background(), database.InsertTemplateVersionParams{
ID: takeFirst(orig.ID, uuid.New()),
TemplateID: uuid.NullUUID{
UUID: takeFirst(orig.TemplateID.UUID, uuid.New()),
Valid: takeFirst(orig.TemplateID.Valid, true),
},
ID: takeFirst(orig.ID, uuid.New()),
TemplateID: orig.TemplateID,
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
CreatedAt: takeFirst(orig.CreatedAt, database.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, database.Now()),

View File

@ -68,7 +68,7 @@ func TestGenerator(t *testing.T) {
require.Equal(t, exp, must(db.GetWorkspaceAppsByAgentID(context.Background(), exp.AgentID))[0])
})
t.Run("WorkspaceResourceMetadatum", func(t *testing.T) {
t.Run("WorkspaceResourceMetadata", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
exp := dbgen.WorkspaceResourceMetadatums(t, db, database.WorkspaceResourceMetadatum{})

View File

@ -2,6 +2,7 @@ package database
import (
"sort"
"strconv"
"github.com/coder/coder/coderd/rbac"
)
@ -63,6 +64,11 @@ func (TemplateVersion) RBACObject(template Template) rbac.Object {
return template.RBACObject()
}
// RBACObjectNoTemplate is for orphaned template versions.
func (v TemplateVersion) RBACObjectNoTemplate() rbac.Object {
return rbac.ResourceTemplate.InOrg(v.OrganizationID)
}
func (g Group) RBACObject() rbac.Object {
return rbac.ResourceGroup.WithID(g.ID).
InOrg(g.OrganizationID)
@ -94,6 +100,13 @@ func (m OrganizationMember) RBACObject() rbac.Object {
InOrg(m.OrganizationID)
}
func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object {
// TODO: This feels incorrect as we are really returning a list of orgmembers.
// This return type should be refactored to return a list of orgmembers, not this
// special type.
return rbac.ResourceUser.WithID(m.UserID)
}
func (o Organization) RBACObject() rbac.Object {
return rbac.ResourceOrganization.
WithID(o.ID).
@ -118,11 +131,29 @@ func (u User) RBACObject() rbac.Object {
}
func (u User) UserDataRBACObject() rbac.Object {
return rbac.ResourceUser.WithID(u.ID).WithOwner(u.ID.String())
return rbac.ResourceUserData.WithID(u.ID).WithOwner(u.ID.String())
}
func (License) RBACObject() rbac.Object {
return rbac.ResourceLicense
func (u GetUsersRow) RBACObject() rbac.Object {
return rbac.ResourceUser.WithID(u.ID)
}
func (u GitSSHKey) RBACObject() rbac.Object {
return rbac.ResourceUserData.WithID(u.UserID).WithOwner(u.UserID.String())
}
func (u GitAuthLink) RBACObject() rbac.Object {
// I assume UserData is ok?
return rbac.ResourceUserData.WithID(u.UserID).WithOwner(u.UserID.String())
}
func (u UserLink) RBACObject() rbac.Object {
// I assume UserData is ok?
return rbac.ResourceUserData.WithOwner(u.UserID.String()).WithID(u.UserID)
}
func (l License) RBACObject() rbac.Object {
return rbac.ResourceLicense.WithIDString(strconv.FormatInt(int64(l.ID), 10))
}
func ConvertUserRows(rows []GetUsersRow) []User {

View File

@ -56,6 +56,7 @@ type sqlcQuerier interface {
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error)
GetLicenseByID(ctx context.Context, id int32) (License, error)
GetLicenses(ctx context.Context) ([]License, error)
GetLogoURL(ctx context.Context) (string, error)
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
@ -121,6 +122,7 @@ type sqlcQuerier interface {
GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (Workspace, error)
GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Workspace, error)
GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWorkspaceByOwnerIDAndNameParams) (Workspace, error)
GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (Workspace, error)
GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (WorkspaceResource, error)
GetWorkspaceResourceMetadataByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceResourceMetadatum, error)
GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceResourceMetadatum, error)

View File

@ -1343,6 +1343,30 @@ func (q *sqlQuerier) DeleteLicense(ctx context.Context, id int32) (int32, error)
return id, err
}
const getLicenseByID = `-- name: GetLicenseByID :one
SELECT
id, uploaded_at, jwt, exp, uuid
FROM
licenses
WHERE
id = $1
LIMIT
1
`
func (q *sqlQuerier) GetLicenseByID(ctx context.Context, id int32) (License, error) {
row := q.db.QueryRowContext(ctx, getLicenseByID, id)
var i License
err := row.Scan(
&i.ID,
&i.UploadedAt,
&i.JWT,
&i.Exp,
&i.UUID,
)
return i, err
}
const getLicenses = `-- name: GetLicenses :many
SELECT id, uploaded_at, jwt, exp, uuid
FROM licenses
@ -6513,6 +6537,62 @@ func (q *sqlQuerier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWo
return i, err
}
const getWorkspaceByWorkspaceAppID = `-- name: GetWorkspaceByWorkspaceAppID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at
FROM
workspaces
WHERE
workspaces.id = (
SELECT
workspace_id
FROM
workspace_builds
WHERE
workspace_builds.job_id = (
SELECT
job_id
FROM
workspace_resources
WHERE
workspace_resources.id = (
SELECT
resource_id
FROM
workspace_agents
WHERE
workspace_agents.id = (
SELECT
agent_id
FROM
workspace_apps
WHERE
workspace_apps.id = $1
)
)
)
)
`
func (q *sqlQuerier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (Workspace, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceByWorkspaceAppID, workspaceAppID)
var i Workspace
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OwnerID,
&i.OrganizationID,
&i.TemplateID,
&i.Deleted,
&i.Name,
&i.AutostartSchedule,
&i.Ttl,
&i.LastUsedAt,
)
return i, err
}
const getWorkspaces = `-- name: GetWorkspaces :many
SELECT
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, COUNT(*) OVER () as count

View File

@ -14,6 +14,16 @@ SELECT *
FROM licenses
ORDER BY (id);
-- name: GetLicenseByID :one
SELECT
*
FROM
licenses
WHERE
id = $1
LIMIT
1;
-- name: GetUnexpiredLicenses :many
SELECT *
FROM licenses

View File

@ -8,6 +8,42 @@ WHERE
LIMIT
1;
-- name: GetWorkspaceByWorkspaceAppID :one
SELECT
*
FROM
workspaces
WHERE
workspaces.id = (
SELECT
workspace_id
FROM
workspace_builds
WHERE
workspace_builds.job_id = (
SELECT
job_id
FROM
workspace_resources
WHERE
workspace_resources.id = (
SELECT
resource_id
FROM
workspace_agents
WHERE
workspace_agents.id = (
SELECT
agent_id
FROM
workspace_apps
WHERE
workspace_apps.id = @workspace_app_id
)
)
)
);
-- name: GetWorkspaceByAgentID :one
SELECT
*

View File

@ -76,7 +76,14 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) {
ID: file.ID,
})
return
} else if !errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error getting file.",
Detail: err.Error(),
})
return
}
id := uuid.New()
file, err = api.Database.InsertFile(ctx, database.InsertFileParams{
ID: id,

View File

@ -19,6 +19,7 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
@ -159,7 +160,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
return
}
key, err := cfg.DB.GetAPIKeyByID(r.Context(), keyID)
//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystem(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
optionalWrite(http.StatusUnauthorized, codersdk.Response{
@ -192,7 +194,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
changed = false
)
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
link, err = cfg.DB.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystem(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
@ -275,7 +278,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
}
}
if changed {
err := cfg.DB.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
//nolint:gocritic // System needs to update API Key LastUsed
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystem(ctx), database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
@ -291,7 +295,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
// If the API Key is associated with a user_link (e.g. Github/OIDC)
// then we want to update the relevant oauth fields.
if link.UserID != uuid.Nil {
link, err = cfg.DB.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
// nolint:gocritic
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystem(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
@ -310,7 +315,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
// We only want to update this occasionally to reduce DB write
// load. We update alongside the UserLink and APIKey since it's
// easier on the DB to colocate writes.
_, err = cfg.DB.UpdateUserLastSeenAt(ctx, database.UpdateUserLastSeenAtParams{
// nolint:gocritic
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystem(ctx), database.UpdateUserLastSeenAtParams{
ID: key.UserID,
LastSeenAt: database.Now(),
UpdatedAt: database.Now(),
@ -327,7 +333,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
// If the key is valid, we also fetch the user roles and status.
// The roles are used for RBAC authorize checks, and the status
// is to block 'suspended' users from accessing the platform.
roles, err := cfg.DB.GetAuthorizationUserRoles(r.Context(), key.UserID)
// nolint:gocritic
roles, err := cfg.DB.GetAuthorizationUserRoles(dbauthz.AsSystem(ctx), key.UserID)
if err != nil {
write(http.StatusUnauthorized, codersdk.Response{
Message: internalErrorMessage,
@ -343,16 +350,20 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
return
}
// Actor is the user's authorization context.
actor := rbac.Subject{
ID: key.UserID.String(),
Roles: rbac.RoleNames(roles.Roles),
Groups: roles.Groups,
Scope: rbac.ScopeName(key.Scope),
}
ctx = context.WithValue(ctx, apiKeyContextKey{}, key)
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
Username: roles.Username,
Actor: rbac.Subject{
ID: key.UserID.String(),
Roles: rbac.RoleNames(roles.Roles),
Groups: roles.Groups,
Scope: rbac.ScopeName(key.Scope),
},
Actor: actor,
})
// Set the auth context for the authzquerier as well.
ctx = dbauthz.As(ctx, actor)
next.ServeHTTP(rw, r.WithContext(ctx))
})

37
coderd/httpmw/authz.go Normal file
View File

@ -0,0 +1,37 @@
package httpmw
import (
"net/http"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/go-chi/chi/v5"
)
// AsAuthzSystem is a chained handler that temporarily sets the dbauthz context
// to System for the inner handlers, and resets the context afterwards.
//
// TODO: Refactor the middleware functions to not require this.
// This is a bit of a kludge for now as some middleware functions require
// usage as a system user in some cases, but not all cases. To avoid large
// refactors, we use this middleware to temporarily set the context to a system.
func AsAuthzSystem(mws ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
chain := chi.Chain(mws...)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
before, beforeExists := dbauthz.ActorFromContext(r.Context())
if !beforeExists {
// AsRemoveActor will actually remove the actor from the context.
before = dbauthz.AsRemoveActor
}
// nolint:gocritic // AsAuthzSystem needs to do this.
r = r.WithContext(dbauthz.AsSystem(ctx))
chain.Handler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
r = r.WithContext(dbauthz.As(r.Context(), before))
next.ServeHTTP(rw, r)
})).ServeHTTP(rw, r)
})
}
}

View File

@ -0,0 +1,97 @@
package httpmw_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpmw"
)
func TestAsAuthzSystem(t *testing.T) {
t.Parallel()
userActor := coderdtest.RandomRBACSubject()
base := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
actor, ok := dbauthz.ActorFromContext(r.Context())
assert.True(t, ok, "actor should exist")
assert.True(t, userActor.Equal(actor), "actor should be the user actor")
})
mwSetUser := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
r = r.WithContext(dbauthz.As(r.Context(), userActor))
next.ServeHTTP(rw, r)
})
}
mwAssertSystem := mwAssert(func(req *http.Request) {
actor, ok := dbauthz.ActorFromContext(req.Context())
assert.True(t, ok, "actor should exist")
assert.False(t, userActor.Equal(actor), "systemActor should not be the user actor")
assert.Contains(t, actor.Roles.Names(), "system", "should have system role")
})
mwAssertUser := mwAssert(func(req *http.Request) {
actor, ok := dbauthz.ActorFromContext(req.Context())
assert.True(t, ok, "actor should exist")
assert.True(t, userActor.Equal(actor), "should be the useractor")
})
mwAssertNoUser := mwAssert(func(req *http.Request) {
_, ok := dbauthz.ActorFromContext(req.Context())
assert.False(t, ok, "actor should not exist")
})
// Request as the user actor
const pattern = "/"
req := httptest.NewRequest("GET", pattern, nil)
res := httptest.NewRecorder()
handler := chi.NewRouter()
handler.Route(pattern, func(r chi.Router) {
r.Use(
// First assert there is no actor context
mwAssertNoUser,
httpmw.AsAuthzSystem(
// Assert the system actor
mwAssertSystem,
mwAssertSystem,
),
// Assert no user present outside of the AsAuthzSystem chain
mwAssertNoUser,
// ----
// Set to the user actor
mwSetUser,
// Assert the user actor
mwAssertUser,
httpmw.AsAuthzSystem(
// Assert the system actor
mwAssertSystem,
mwAssertSystem,
),
// Check the user actor was returned to the context
mwAssertUser,
)
r.Handle("/", base)
r.NotFound(func(writer http.ResponseWriter, request *http.Request) {
assert.Fail(t, "should not hit not found, the route should be correct")
})
})
handler.ServeHTTP(res, req)
}
func mwAssert(assertF func(req *http.Request)) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assertF(r)
next.ServeHTTP(rw, r)
})
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
@ -68,7 +69,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
})
return
}
user, err = db.GetUserByID(ctx, apiKey.UserID)
//nolint:gocritic // System needs to be able to get user from param.
user, err = db.GetUserByID(dbauthz.AsSystem(ctx), apiKey.UserID)
if xerrors.Is(err, sql.ErrNoRows) {
httpapi.ResourceNotFound(rw)
return
@ -81,8 +83,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
return
}
} else if userID, err := uuid.Parse(userQuery); err == nil {
// If the userQuery is a valid uuid
user, err = db.GetUserByID(ctx, userID)
//nolint:gocritic // If the userQuery is a valid uuid
user, err = db.GetUserByID(dbauthz.AsSystem(ctx), userID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: userErrorMessage,
@ -90,8 +92,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
return
}
} else {
// Try as a username last
user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
// nolint:gocritic // Try as a username last
user, err = db.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
Username: userQuery,
})
if err != nil {

View File

@ -10,7 +10,9 @@ import (
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
)
@ -45,7 +47,8 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
})
return
}
agent, err := db.GetWorkspaceAgentByAuthToken(ctx, token)
//nolint:gocritic // System needs to be able to get workspace agents.
agent, err := db.GetWorkspaceAgentByAuthToken(dbauthz.AsSystem(ctx), token)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
@ -62,8 +65,50 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
return
}
//nolint:gocritic // System needs to be able to get workspace agents.
subject, err := getAgentSubject(dbauthz.AsSystem(ctx), db, agent)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace agent.",
Detail: err.Error(),
})
return
}
ctx = context.WithValue(ctx, workspaceAgentContextKey{}, agent)
// Also set the dbauthz actor for the request.
ctx = dbauthz.As(ctx, subject)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}
func getAgentSubject(ctx context.Context, db database.Store, agent database.WorkspaceAgent) (rbac.Subject, error) {
// TODO: make a different query that gets the workspace owner and roles along with the agent.
workspace, err := db.GetWorkspaceByAgentID(ctx, agent.ID)
if err != nil {
return rbac.Subject{}, err
}
user, err := db.GetUserByID(ctx, workspace.OwnerID)
if err != nil {
return rbac.Subject{}, err
}
roles, err := db.GetAuthorizationUserRoles(ctx, user.ID)
if err != nil {
return rbac.Subject{}, err
}
// A user that creates a workspace can use this agent auth token and
// impersonate the workspace. So to prevent privilege escalation, the
// subject inherits the roles of the user that owns the workspace.
// We then add a workspace-agent scope to limit the permissions
// to only what the workspace agent needs.
return rbac.Subject{
ID: user.ID.String(),
Roles: rbac.RoleNames(roles.Roles),
Groups: roles.Groups,
Scope: rbac.WorkspaceAgentScope(workspace.ID, user.ID),
}, nil
}

View File

@ -19,11 +19,10 @@ import (
func TestWorkspaceAgent(t *testing.T) {
t.Parallel()
setup := func(db database.Store) (*http.Request, uuid.UUID) {
token := uuid.New()
setup := func(db database.Store, token uuid.UUID) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set(codersdk.SessionTokenHeader, token.String())
return r, token
return r
}
t.Run("None", func(t *testing.T) {
@ -34,7 +33,7 @@ func TestWorkspaceAgent(t *testing.T) {
httpmw.ExtractWorkspaceAgent(db),
)
rtr.Get("/", nil)
r, _ := setup(db)
r := setup(db, uuid.New())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
@ -46,6 +45,24 @@ func TestWorkspaceAgent(t *testing.T) {
t.Run("Found", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
var (
user = dbgen.User(t, db, database.User{})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
})
job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{})
resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
JobID: job.ID,
})
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resource.ID,
})
)
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceAgent(db),
@ -54,10 +71,7 @@ func TestWorkspaceAgent(t *testing.T) {
_ = httpmw.WorkspaceAgent(r)
rw.WriteHeader(http.StatusOK)
})
r, token := setup(db)
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
AuthToken: token,
})
r := setup(db, agent.AuthToken)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)

View File

@ -55,20 +55,20 @@ func (api *API) putMemberRoles(rw http.ResponseWriter, r *http.Request) {
// Assigning a role requires the create permission.
if len(added) > 0 && !api.Authorize(r, rbac.ActionCreate, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) {
httpapi.Forbidden(rw)
httpapi.ResourceNotFound(rw)
return
}
// Removing a role requires the delete permission.
if len(removed) > 0 && !api.Authorize(r, rbac.ActionDelete, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) {
httpapi.Forbidden(rw)
httpapi.ResourceNotFound(rw)
return
}
// Just treat adding & removing as "assigning" for now.
for _, roleName := range append(added, removed...) {
if !rbac.CanAssignRole(actorRoles.Actor.Roles, roleName) {
httpapi.Forbidden(rw)
httpapi.ResourceNotFound(rw)
return
}
}

View File

@ -14,6 +14,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/codersdk"
"github.com/coder/retry"
)
@ -142,6 +143,8 @@ func countUniqueUsers(rows []database.GetTemplateDAUsRow) int {
}
func (c *Cache) refresh(ctx context.Context) error {
//nolint:gocritic // This is a system service.
ctx = dbauthz.AsSystem(ctx)
err := c.database.DeleteOldAgentStats(ctx)
if err != nil {
return xerrors.Errorf("delete old stats: %w", err)

View File

@ -25,6 +25,7 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/codersdk"
@ -56,6 +57,8 @@ type Server struct {
// AcquireJob queries the database to lock a job.
func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
//nolint:gocritic //TODO: make a provisionerd role
ctx = dbauthz.AsSystem(ctx)
// This prevents loads of provisioner daemons from consistently
// querying the database when no jobs are available.
//
@ -270,6 +273,8 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
}
func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
//nolint:gocritic //TODO: make a provisionerd role
ctx = dbauthz.AsSystem(ctx)
jobID, err := uuid.Parse(request.JobId)
if err != nil {
return nil, xerrors.Errorf("parse job id: %w", err)
@ -299,6 +304,8 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot
}
func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
//nolint:gocritic //TODO: make a provisionerd role
ctx = dbauthz.AsSystem(ctx)
parsedID, err := uuid.Parse(request.JobId)
if err != nil {
return nil, xerrors.Errorf("parse job id: %w", err)
@ -345,7 +352,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
slog.F("stage", log.Stage),
slog.F("output", log.Output))
}
logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams)
//nolint:gocritic //TODO: make a provisionerd role
logs, err := server.Database.InsertProvisionerJobLogs(dbauthz.AsSystem(context.Background()), insertParams)
if err != nil {
server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err))
return nil, xerrors.Errorf("insert job logs: %w", err)
@ -470,6 +478,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
}
func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) {
//nolint:gocritic // TODO: make a provisionerd role
ctx = dbauthz.AsSystem(ctx)
jobID, err := uuid.Parse(failJob.JobId)
if err != nil {
return nil, xerrors.Errorf("parse job id: %w", err)
@ -596,6 +606,8 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
// CompleteJob is triggered by a provision daemon to mark a provisioner job as completed.
func (server *Server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
//nolint:gocritic // TODO: make a provisionerd role
ctx = dbauthz.AsSystem(ctx)
jobID, err := uuid.Parse(completed.JobId)
if err != nil {
return nil, xerrors.Errorf("parse job id: %w", err)

View File

@ -16,9 +16,10 @@ import (
"nhooyr.io/websocket"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
)
@ -32,6 +33,7 @@ import (
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
var (
ctx = r.Context()
actor, _ = dbauthz.ActorFromContext(ctx)
logger = api.Logger.With(slog.F("job_id", job.ID))
follow = r.URL.Query().Has("follow")
afterRaw = r.URL.Query().Get("after")
@ -49,7 +51,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
// of processed IDs.
var bufferedLogs <-chan database.ProvisionerJobLog
if follow {
bl, closeFollow, err := api.followLogs(job.ID)
bl, closeFollow, err := api.followLogs(actor, job.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error watching provisioner logs.",
@ -367,7 +369,7 @@ type provisionerJobLogsMessage struct {
EndOfLogs bool `json:"end_of_logs,omitempty"`
}
func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
func (api *API) followLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
logger := api.Logger.With(slog.F("job_id", jobID))
var (
@ -392,7 +394,7 @@ func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog,
}
if jlMsg.CreatedAfter != 0 {
logs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{
logs, err := api.Database.GetProvisionerLogsByIDBetween(dbauthz.As(ctx, actor), database.GetProvisionerLogsByIDBetweenParams{
JobID: jobID,
CreatedAfter: jlMsg.CreatedAfter,
})

View File

@ -1039,7 +1039,6 @@ func testAuthorize(t *testing.T, name string, subject Subject, sets ...[]authTes
}
}
func must[T any](value T, err error) T {
if err != nil {
panic(err)

View File

@ -133,6 +133,8 @@ var (
ResourceWorkspace.Type: {ActionRead},
// CRUD to provisioner daemons for now.
ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
// Needs to read all organizations since
ResourceOrganization.Type: {ActionRead},
}),
Org: map[string][]Permission{},
User: []Permission{},
@ -217,6 +219,12 @@ var (
// The first key is the actor role, the second is the roles they can assign.
// map[actor_role][assign_role]<can_assign>
assignRoles = map[string]map[string]bool{
"system": {
owner: true,
member: true,
orgAdmin: true,
orgMember: true,
},
owner: {
owner: true,
auditor: true,

View File

@ -10,7 +10,7 @@ import (
// BenchmarkRBACValueAllocation benchmarks the cost of allocating a rego input
// value. By default, `ast.InterfaceToValue` is used to convert the input,
// which uses json marshalling under the hood.
// which uses json marshaling under the hood.
//
// Currently ast.Object.insert() is the slowest part of the process and allocates
// the most amount of bytes. This general approach copies all of our struct

View File

@ -19,6 +19,7 @@ type authSubject struct {
Actor rbac.Subject
}
// TODO: add the SYSTEM to the MATRIX
func TestRolePermissions(t *testing.T) {
t.Parallel()
@ -183,8 +184,8 @@ func TestRolePermissions(t *testing.T) {
Actions: []rbac.Action{rbac.ActionRead},
Resource: rbac.ResourceOrganization.WithID(orgID).InOrg(orgID),
AuthorizeMap: map[bool][]authSubject{
true: {owner, orgAdmin, orgMemberMe},
false: {otherOrgAdmin, otherOrgMember, memberMe, templateAdmin, userAdmin},
true: {owner, orgAdmin, orgMemberMe, templateAdmin},
false: {otherOrgAdmin, otherOrgMember, memberMe, userAdmin},
},
},
{

View File

@ -1,6 +1,10 @@
package rbac
import "github.com/open-policy-agent/opa/rego"
import (
"errors"
"github.com/open-policy-agent/opa/rego"
)
const (
// errUnauthorized is the error message that should be returned to
@ -24,6 +28,12 @@ type UnauthorizedError struct {
output rego.ResultSet
}
// IsUnauthorizedError is a convenience function to check if err is UnauthorizedError.
// It is equivalent to errors.As(err, &UnauthorizedError{}).
func IsUnauthorizedError(err error) bool {
return errors.As(err, &UnauthorizedError{})
}
// ForbiddenWithInternal creates a new error that will return a simple
// "forbidden" to the client, logging internally the more detailed message
// provided.
@ -37,6 +47,10 @@ func ForbiddenWithInternal(internal error, subject Subject, action Action, objec
}
}
func (e UnauthorizedError) Unwrap() error {
return e.internal
}
// Error implements the error interface.
func (UnauthorizedError) Error() string {
return errUnauthorized
@ -47,6 +61,10 @@ func (e *UnauthorizedError) Internal() error {
return e.internal
}
func (e *UnauthorizedError) SetInternal(err error) {
e.internal = err
}
func (e *UnauthorizedError) Input() map[string]interface{} {
return map[string]interface{}{
"subject": e.subject,
@ -59,3 +77,11 @@ func (e *UnauthorizedError) Input() map[string]interface{} {
func (e *UnauthorizedError) Output() rego.ResultSet {
return e.output
}
// As implements the errors.As interface.
func (*UnauthorizedError) As(target interface{}) bool {
if _, ok := target.(*UnauthorizedError); ok {
return true
}
return false
}

32
coderd/rbac/error_test.go Normal file
View File

@ -0,0 +1,32 @@
package rbac_test
import (
"testing"
"github.com/coder/coder/coderd/rbac"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
func TestIsUnauthorizedError(t *testing.T) {
t.Parallel()
t.Run("NotWrapped", func(t *testing.T) {
t.Parallel()
errFunc := func() error {
return rbac.UnauthorizedError{}
}
err := errFunc()
require.True(t, rbac.IsUnauthorizedError(err))
})
t.Run("Wrapped", func(t *testing.T) {
t.Parallel()
errFunc := func() error {
return xerrors.Errorf("test error: %w", rbac.UnauthorizedError{})
}
err := errFunc()
require.True(t, rbac.IsUnauthorizedError(err))
})
}

View File

@ -3,6 +3,8 @@ package rbac
import (
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
)
@ -41,6 +43,29 @@ func (s Scope) Name() string {
return s.Role.Name
}
// WorkspaceAgentScope returns a scope that is the same as ScopeAll but can only
// affect resources in the allow list. Only a scope is returned as the roles
// should come from the workspace owner.
func WorkspaceAgentScope(workspaceID, ownerID uuid.UUID) Scope {
allScope, err := ScopeAll.Expand()
if err != nil {
panic("failed to expand scope all, this should never happen")
}
return Scope{
// TODO: We want to limit the role too to be extra safe.
// Even though the allowlist blocks anything else, it is still good
// incase we change the behavior of the allowlist. The allowlist is new
// and evolving.
Role: allScope.Role,
// This prevents the agent from being able to access any other resource.
AllowIDList: []string{
workspaceID.String(),
ownerID.String(),
// TODO: Might want to include the template the workspace uses too?
},
}
}
const (
ScopeAll ScopeName = "all"
ScopeApplicationConnect ScopeName = "application_connect"

View File

@ -47,7 +47,7 @@ func (api *API) assignableOrgRoles(rw http.ResponseWriter, r *http.Request) {
actorRoles := httpmw.UserAuthorization(r)
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) {
httpapi.Forbidden(rw)
httpapi.ResourceNotFound(rw)
return
}

View File

@ -30,7 +30,7 @@ func TestListRoles(t *testing.T) {
})
require.NoError(t, err, "create org")
const forbidden = "Forbidden"
const notFound = "Resource not found"
testCases := []struct {
Name string
Client *codersdk.Client
@ -66,7 +66,7 @@ func TestListRoles(t *testing.T) {
APICall: func(ctx context.Context) ([]codersdk.AssignableRoles, error) {
return member.ListOrganizationRoles(ctx, otherOrg.ID)
},
AuthorizedError: forbidden,
AuthorizedError: notFound,
},
// Org admin
{
@ -95,7 +95,7 @@ func TestListRoles(t *testing.T) {
APICall: func(ctx context.Context) ([]codersdk.AssignableRoles, error) {
return orgAdmin.ListOrganizationRoles(ctx, otherOrg.ID)
},
AuthorizedError: forbidden,
AuthorizedError: notFound,
},
// Admin
{
@ -133,7 +133,7 @@ func TestListRoles(t *testing.T) {
if c.AuthorizedError != "" {
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
require.Contains(t, apiErr.Message, c.AuthorizedError)
} else {
require.NoError(t, err)

View File

@ -82,6 +82,10 @@ func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
return
}
// TODO: This just returns the workspaces a user can view. We should use
// a system function to get all workspaces that use this template.
// This data should never be exposed to the user aside from a non-zero count.
// Or we move this into a postgres constraint.
workspaces, err := api.Database.GetWorkspaces(ctx, database.GetWorkspacesParams{
TemplateIds: []uuid.UUID{template.ID},
})

View File

@ -18,9 +18,9 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
@ -57,7 +57,8 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
return
}
user, err := api.Database.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
//nolint:gocritic // In order to login, we need to get the user first!
user, err := api.Database.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
Email: loginWithPassword.Email,
})
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
@ -111,15 +112,32 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
return
}
//nolint:gocritic // System needs to fetch user roles in order to login user.
roles, err := api.Database.GetAuthorizationUserRoles(dbauthz.AsSystem(ctx), user.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error.",
})
return
}
// If the user logged into a suspended account, reject the login request.
if user.Status != database.UserStatusActive {
if roles.Status != database.UserStatusActive {
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
Message: "Your account is suspended. Contact an admin to reactivate your account.",
})
return
}
cookie, key, err := api.createAPIKey(ctx, createAPIKeyParams{
userSubj := rbac.Subject{
ID: user.ID.String(),
Roles: rbac.RoleNames(roles.Roles),
Groups: roles.Groups,
Scope: rbac.ScopeAll,
}
//nolint:gocritic // Creating the API key as the user instead of as system.
cookie, key, err := api.createAPIKey(dbauthz.As(ctx, userSubj), createAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypePassword,
RemoteAddr: r.RemoteAddr,
@ -765,7 +783,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
// with OIDC for the first time.
if user.ID == uuid.Nil {
var organizationID uuid.UUID
organizations, _ := tx.GetOrganizations(ctx)
//nolint:gocritic
organizations, _ := tx.GetOrganizations(dbauthz.AsSystem(ctx))
if len(organizations) > 0 {
// Add the user to the first organization. Once multi-organization
// support is added, we should enable a configuration map of user
@ -773,7 +792,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
organizationID = organizations[0].ID
}
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
//nolint:gocritic
_, err := tx.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
Username: params.Username,
})
if err == nil {
@ -786,7 +806,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
params.Username = httpapi.UsernameFrom(alternate)
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
//nolint:gocritic
_, err := tx.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
Username: params.Username,
})
if xerrors.Is(err, sql.ErrNoRows) {
@ -805,7 +826,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
}
}
user, _, err = api.CreateUser(ctx, tx, CreateUserRequest{
//nolint:gocritic
user, _, err = api.CreateUser(dbauthz.AsSystem(ctx), tx, CreateUserRequest{
CreateUserRequest: codersdk.CreateUserRequest{
Email: params.Email,
Username: params.Username,
@ -819,7 +841,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
}
if link.UserID == uuid.Nil {
link, err = tx.InsertUserLink(ctx, database.InsertUserLinkParams{
//nolint:gocritic
link, err = tx.InsertUserLink(dbauthz.AsSystem(ctx), database.InsertUserLinkParams{
UserID: user.ID,
LoginType: params.LoginType,
LinkedID: params.LinkedID,
@ -833,7 +856,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
}
if link.UserID != uuid.Nil {
link, err = tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
//nolint:gocritic
link, err = tx.UpdateUserLink(dbauthz.AsSystem(ctx), database.UpdateUserLinkParams{
UserID: user.ID,
LoginType: params.LoginType,
OAuthAccessToken: params.State.Token.AccessToken,
@ -847,7 +871,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
// Ensure groups are correct.
if len(params.Groups) > 0 {
err := api.Options.SetUserGroups(ctx, tx, user.ID, params.Groups)
//nolint:gocritic
err := api.Options.SetUserGroups(dbauthz.AsSystem(ctx), tx, user.ID, params.Groups)
if err != nil {
return xerrors.Errorf("set user groups: %w", err)
}
@ -880,7 +905,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
// In such cases in the current implementation this user can now no
// longer sign in until an administrator finds the offending built-in
// user and changes their username.
user, err = tx.UpdateUserProfile(ctx, database.UpdateUserProfileParams{
//nolint:gocritic
user, err = tx.UpdateUserProfile(dbauthz.AsSystem(ctx), database.UpdateUserProfileParams{
ID: user.ID,
Email: user.Email,
Username: user.Username,
@ -898,7 +924,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
return nil, database.APIKey{}, xerrors.Errorf("in tx: %w", err)
}
cookie, key, err := api.createAPIKey(ctx, createAPIKeyParams{
//nolint:gocritic
cookie, key, err := api.createAPIKey(dbauthz.AsSystem(ctx), createAPIKeyParams{
UserID: user.ID,
LoginType: params.LoginType,
RemoteAddr: r.RemoteAddr,

View File

@ -16,6 +16,7 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
@ -37,7 +38,8 @@ import (
// @Router /users/first [get]
func (api *API) firstUser(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userCount, err := api.Database.GetUserCount(ctx)
//nolint:gocritic // needed for first user check
userCount, err := api.Database.GetUserCount(dbauthz.AsSystem(ctx))
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching user count.",
@ -70,6 +72,7 @@ func (api *API) firstUser(rw http.ResponseWriter, r *http.Request) {
// @Success 201 {object} codersdk.CreateFirstUserResponse
// @Router /users/first [post]
func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
// TODO: Should this admin system context be in a middleware?
ctx := r.Context()
var createUser codersdk.CreateFirstUserRequest
if !httpapi.Read(ctx, rw, r, &createUser) {
@ -77,7 +80,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
}
// This should only function for the first user.
userCount, err := api.Database.GetUserCount(ctx)
//nolint:gocritic // needed to create first user
userCount, err := api.Database.GetUserCount(dbauthz.AsSystem(ctx))
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching user count.",
@ -117,7 +121,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
return
}
user, organizationID, err := api.CreateUser(ctx, api.Database, CreateUserRequest{
//nolint:gocritic // needed to create first user
user, organizationID, err := api.CreateUser(dbauthz.AsSystem(ctx), api.Database, CreateUserRequest{
CreateUserRequest: codersdk.CreateUserRequest{
Email: createUser.Email,
Username: createUser.Username,
@ -146,7 +151,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
// the user. Maybe I add this ability to grant roles in the createUser api
// and add some rbac bypass when calling api functions this way??
// Add the admin role to this first user.
_, err = api.Database.UpdateUserRoles(ctx, database.UpdateUserRolesParams{
//nolint:gocritic // needed to create first user
_, err = api.Database.UpdateUserRoles(dbauthz.AsSystem(ctx), database.UpdateUserRolesParams{
GrantedRoles: []string{rbac.RoleOwner()},
ID: user.ID,
})
@ -987,7 +993,7 @@ func (api *API) organizationByUserAndName(rw http.ResponseWriter, r *http.Reques
ctx := r.Context()
organizationName := chi.URLParam(r, "organizationname")
organization, err := api.Database.GetOrganizationByName(ctx, organizationName)
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -854,7 +854,7 @@ func TestGrantSiteRoles(t *testing.T) {
AssignToUser: randOrgUser.ID.String(),
Roles: []string{rbac.RoleOrgMember(randOrg.ID)},
Error: true,
StatusCode: http.StatusForbidden,
StatusCode: http.StatusNotFound,
},
{
Name: "AdminUpdateOrgSelf",

View File

@ -1,10 +1,5 @@
package slice
// New is a convenience method for creating []T.
func New[T any](items ...T) []T {
return items
}
// SameElements returns true if the 2 lists have the same elements in any
// order.
func SameElements[T comparable](a []T, b []T) bool {
@ -67,3 +62,8 @@ func OverlapCompare[T any](a []T, b []T, equal func(a, b T) bool) bool {
}
return false
}
// New is a convenience method for creating []T.
func New[T any](items ...T) []T {
return items
}

View File

@ -26,6 +26,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
@ -625,14 +626,29 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
// inactive disconnect timeout we ensure that we don't block but
// also guarantee that the agent will be considered disconnected
// by normal status check.
ctx, cancel := context.WithTimeout(api.ctx, api.AgentInactiveDisconnectTimeout)
//
// Use a system context as the agent has disconnected and that token
// may no longer be valid.
//nolint:gocritic
ctx, cancel := context.WithTimeout(dbauthz.AsSystem(api.ctx), api.AgentInactiveDisconnectTimeout)
defer cancel()
disconnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
_ = updateConnectionTimes(ctx)
err := updateConnectionTimes(ctx)
if err != nil {
// This is a bug with unit tests that cancel the app context and
// cause this error log to be generated. We should fix the unit tests
// as this is a valid log.
if !xerrors.Is(err, context.Canceled) {
api.Logger.Error(ctx, "failed to update agent disconnect time",
slog.Error(err),
slog.F("workspace", build.WorkspaceID),
)
}
}
api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
}()
@ -907,7 +923,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
slog.F("payload", req),
)
activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace.ID)
activityBumpWorkspace(ctx, api.Logger.Named("activity_bump"), api.Database, workspace.ID)
payload, err := json.Marshal(req)
if err != nil {

View File

@ -24,6 +24,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
@ -330,7 +331,8 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request
// different auth formats, and tricks this endpoint into deleting an
// unchecked API key, we validate that the secret matches the secret
// we store in the database.
apiKey, err := api.Database.GetAPIKeyByID(ctx, id)
//nolint:gocritic // needed for workspace app logout
apiKey, err := api.Database.GetAPIKeyByID(dbauthz.AsSystem(ctx), id)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to lookup API key.",
@ -349,7 +351,8 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request
})
return
}
err = api.Database.DeleteAPIKeyByID(ctx, id)
//nolint:gocritic // needed for workspace app logout
err = api.Database.DeleteAPIKeyByID(dbauthz.AsSystem(ctx), id)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to delete API key.",
@ -409,7 +412,10 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request
// error while looking it up, an HTML error page is returned and false is
// returned so the caller can return early.
func (api *API) lookupWorkspaceApp(rw http.ResponseWriter, r *http.Request, agentID uuid.UUID, appSlug string) (database.WorkspaceApp, bool) {
app, err := api.Database.GetWorkspaceAppByAgentIDAndSlug(r.Context(), database.GetWorkspaceAppByAgentIDAndSlugParams{
// dbauthz.AsSystem is allowed here as the app authz is checked later.
// The app authz is determined by the sharing level.
//nolint:gocritic
app, err := api.Database.GetWorkspaceAppByAgentIDAndSlug(dbauthz.AsSystem(r.Context()), database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: agentID,
Slug: appSlug,
})
@ -1019,7 +1025,8 @@ func decryptAPIKey(ctx context.Context, db database.Store, encryptedAPIKey strin
// Lookup the API key so we can decrypt it.
keyID := object.Header.KeyID
key, err := db.GetAPIKeyByID(ctx, keyID)
//nolint:gocritic // needed to check API key
key, err := db.GetAPIKeyByID(dbauthz.AsSystem(ctx), keyID)
if err != nil {
return database.APIKey{}, "", xerrors.Errorf("get API key by key ID: %w", err)
}

View File

@ -10,6 +10,7 @@ import (
"github.com/coder/coder/coderd/awsidentity"
"github.com/coder/coder/coderd/azureidentity"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/codersdk"
@ -126,7 +127,8 @@ func (api *API) postWorkspaceAuthGoogleInstanceIdentity(rw http.ResponseWriter,
func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, instanceID string) {
ctx := r.Context()
agent, err := api.Database.GetWorkspaceAgentByInstanceID(ctx, instanceID)
//nolint:gocritic // needed for auth instance id
agent, err := api.Database.GetWorkspaceAgentByInstanceID(dbauthz.AsSystem(ctx), instanceID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Instance with id %q not found.", instanceID),
@ -140,7 +142,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
})
return
}
resource, err := api.Database.GetWorkspaceResourceByID(ctx, agent.ResourceID)
//nolint:gocritic // needed for auth instance id
resource, err := api.Database.GetWorkspaceResourceByID(dbauthz.AsSystem(ctx), agent.ResourceID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job resource.",
@ -148,7 +151,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
})
return
}
job, err := api.Database.GetProvisionerJobByID(ctx, resource.JobID)
//nolint:gocritic // needed for auth instance id
job, err := api.Database.GetProvisionerJobByID(dbauthz.AsSystem(ctx), resource.JobID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
@ -171,7 +175,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
})
return
}
resourceHistory, err := api.Database.GetWorkspaceBuildByID(ctx, jobData.WorkspaceBuildID)
//nolint:gocritic // needed for auth instance id
resourceHistory, err := api.Database.GetWorkspaceBuildByID(dbauthz.AsSystem(ctx), jobData.WorkspaceBuildID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace build.",
@ -182,7 +187,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
// This token should only be exchanged if the instance ID is valid
// for the latest history. If an instance ID is recycled by a cloud,
// we'd hate to leak access to a user's workspace.
latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, resourceHistory.WorkspaceID)
//nolint:gocritic // needed for auth instance id
latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(dbauthz.AsSystem(ctx), resourceHistory.WorkspaceID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching the latest workspace build.",

View File

@ -371,6 +371,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
return
}
// TODO: This should be a system call as the actor might not be able to
// read other workspaces. Ideally we check the error on create and look for
// a postgres conflict error.
workspace, err := api.Database.GetWorkspaceByOwnerIDAndName(ctx, database.GetWorkspaceByOwnerIDAndNameParams{
OwnerID: user.ID,
Name: createWorkspace.Name,

View File

@ -144,7 +144,9 @@ func New(ctx context.Context, options *Options) (*API, error) {
if len(options.SCIMAPIKey) != 0 {
api.AGPL.RootHandler.Route("/scim/v2", func(r chi.Router) {
r.Use(api.scimEnabledMW)
r.Use(
api.scimEnabledMW,
)
r.Post("/Users", api.scimPostUser)
r.Route("/Users", func(r chi.Router) {
r.Get("/", api.scimGetUsers)

View File

@ -6,6 +6,8 @@ import (
"testing"
"time"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
@ -100,7 +102,9 @@ func TestEntitlements(t *testing.T) {
require.NoError(t, err)
require.False(t, entitlements.HasLicense)
coderdtest.CreateFirstUser(t, client)
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
//nolint:gocritic // unit test
ctx := dbauthz.AsSystem(context.Background())
_, err = api.Database.InsertLicense(ctx, database.InsertLicenseParams{
UploadedAt: database.Now(),
Exp: database.Now().AddDate(1, 0, 0),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
@ -128,7 +132,9 @@ func TestEntitlements(t *testing.T) {
require.False(t, entitlements.HasLicense)
coderdtest.CreateFirstUser(t, client)
// Valid
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
ctx := context.Background()
//nolint:gocritic // unit test
_, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{
UploadedAt: database.Now(),
Exp: database.Now().AddDate(1, 0, 0),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
@ -139,7 +145,8 @@ func TestEntitlements(t *testing.T) {
})
require.NoError(t, err)
// Expired
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
//nolint:gocritic // unit test
_, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{
UploadedAt: database.Now(),
Exp: database.Now().AddDate(-1, 0, 0),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
@ -148,7 +155,8 @@ func TestEntitlements(t *testing.T) {
})
require.NoError(t, err)
// Invalid
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
//nolint:gocritic // unit test
_, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{
UploadedAt: database.Now(),
Exp: database.Now().AddDate(1, 0, 0),
JWT: "invalid",

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net/http"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
@ -22,6 +24,9 @@ func TestNew(t *testing.T) {
}
func TestAuthorizeAllEndpoints(t *testing.T) {
if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) {
t.Skip("Skipping TestAuthorizeAllEndpoints for authz_querier experiment")
}
t.Parallel()
client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{

View File

@ -14,6 +14,7 @@ import (
agpl "github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
@ -155,7 +156,8 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
return
}
user, _, err := api.AGPL.CreateUser(ctx, api.Database, agpl.CreateUserRequest{
//nolint:gocritic // needed for SCIM
user, _, err := api.AGPL.CreateUser(dbauthz.AsSystem(ctx), api.Database, agpl.CreateUserRequest{
CreateUserRequest: codersdk.CreateUserRequest{
Username: sUser.UserName,
Email: email,
@ -207,7 +209,8 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
return
}
dbUser, err := api.Database.GetUserByID(ctx, uid)
//nolint:gocritic // needed for SCIM
dbUser, err := api.Database.GetUserByID(dbauthz.AsSystem(ctx), uid)
if err != nil {
_ = handlerutil.WriteError(rw, err)
return
@ -220,7 +223,8 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
status = database.UserStatusSuspended
}
_, err = api.Database.UpdateUserStatus(r.Context(), database.UpdateUserStatusParams{
//nolint:gocritic // needed for SCIM
_, err = api.Database.UpdateUserStatus(dbauthz.AsSystem(r.Context()), database.UpdateUserStatusParams{
ID: dbUser.ID,
Status: status,
UpdatedAt: database.Now(),

View File

@ -921,6 +921,10 @@ func TestTemplateAccess(t *testing.T) {
testTemplateRead := func(t *testing.T, org orgSetup, usr *codersdk.Client, read []codersdk.Template) {
found, err := usr.TemplatesByOrganization(ctx, org.Org.ID)
if len(read) == 0 && err != nil {
require.ErrorContains(t, err, "Resource not found")
return
}
require.NoError(t, err, "failed to get templates")
exp := make(map[uuid.UUID]codersdk.Template)

View File

@ -24,6 +24,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/provisionerd/proto"
sdkproto "github.com/coder/coder/provisionersdk/proto"
@ -886,7 +887,8 @@ func (r *Runner) commitQuota(ctx context.Context, resources []*sdkproto.Resource
const stage = "Commit quota"
resp, err := r.quotaCommitter.CommitQuota(ctx, &proto.CommitQuotaRequest{
//nolint:gocritic // TODO: make a provisionerd role
resp, err := r.quotaCommitter.CommitQuota(dbauthz.AsSystem(ctx), &proto.CommitQuotaRequest{
JobId: r.job.JobId,
DailyCost: int32(cost),
})

View File

@ -20,6 +20,29 @@ import (
"github.com/quasilyte/go-ruleguard/dsl/types"
)
// dbauthzAuthorizationContext is a lint rule that protects the usage of
// system contexts. This is a dangerous pattern that can lead to
// leaking database information as a system context can be essentially
// "sudo".
//
// Anytime a function like "AsSystem" is used, it should be accompanied by a comment
// explaining why it's ok and a nolint.
func dbauthzAuthorizationContext(m dsl.Matcher) {
m.Import("context")
m.Import("github.com/coder/coder/coderd/database/dbauthz")
m.Match(
`dbauthz.$f($c)`,
).
Where(
m["c"].Type.Implements("context.Context") &&
// Only report on functions that start with "As".
m["f"].Text.Matches("^As"),
).
// Instructions for fixing the lint error should be included on the dangerous function.
Report("Using '$f' is dangerous and should be accompanied by a comment explaining why it's ok and a nolint.")
}
// Use xerrors everywhere! It provides additional stacktrace info!
//
//nolint:unused,deadcode,varnamelen