mirror of https://github.com/coder/coder.git
feat: Add initial AuthzQuerier implementation (#5919)
feat: Add initial AuthzQuerier implementation - Adds package database/dbauthz that adds a database.Store implementation where each method goes through AuthZ checks - Implements all database.Store methods on AuthzQuerier - Updates and fixes unit tests where required - Updates coderd initialization to use AuthzQuerier if codersdk.ExperimentAuthzQuerier is enabled
This commit is contained in:
parent
ebdfdc749d
commit
6fb8aff6d0
|
@ -15,10 +15,10 @@ import (
|
|||
|
||||
// activityBumpWorkspace automatically bumps the workspace's auto-off timer
|
||||
// if it is set to expire soon.
|
||||
func activityBumpWorkspace(log slog.Logger, db database.Store, workspaceID uuid.UUID) {
|
||||
func activityBumpWorkspace(ctx context.Context, log slog.Logger, db database.Store, workspaceID uuid.UUID) {
|
||||
// We set a short timeout so if the app is under load, these
|
||||
// low priority operations fail first.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*15)
|
||||
defer cancel()
|
||||
|
||||
err := db.InTx(func(s database.Store) error {
|
||||
|
@ -82,9 +82,12 @@ func activityBumpWorkspace(log slog.Logger, db database.Store, workspaceID uuid.
|
|||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
log.Error(ctx, "bump failed", slog.Error(err),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
)
|
||||
if !xerrors.Is(err, context.Canceled) {
|
||||
// Bump will fail if the context is cancelled, but this is ok.
|
||||
log.Error(ctx, "bump failed", slog.Error(err),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -51,6 +51,28 @@ type HTTPAuthorizer struct {
|
|||
// return
|
||||
// }
|
||||
func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
|
||||
// The experiment does not replace ALL rbac checks, but does replace most.
|
||||
// This statement aborts early on the checks that will be removed in the
|
||||
// future when this experiment is default.
|
||||
if api.Experiments.Enabled(codersdk.ExperimentAuthzQuerier) {
|
||||
// Some resource types do not interact with the persistent layer and
|
||||
// we need to keep these checks happening in the API layer.
|
||||
switch object.RBACObject().Type {
|
||||
case rbac.ResourceWorkspaceExecution.Type:
|
||||
// This is not a db resource, always in API layer
|
||||
case rbac.ResourceDeploymentConfig.Type:
|
||||
// For metric cache items like DAU, we do not hit the DB.
|
||||
// Some db actions are in asserted in the authz layer.
|
||||
case rbac.ResourceReplicas.Type:
|
||||
// Replica rbac is checked for adding and removing replicas.
|
||||
case rbac.ResourceProvisionerDaemon.Type:
|
||||
// Provisioner rbac is checked for adding and removing provisioners.
|
||||
case rbac.ResourceDebugInfo.Type:
|
||||
// This is not a db resource, always in API layer.
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
return api.HTTPAuth.Authorize(r, action, object)
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/autobuild/schedule"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
)
|
||||
|
||||
// Executor automatically starts or stops workspaces.
|
||||
|
@ -33,7 +34,8 @@ type Stats struct {
|
|||
// New returns a new autobuild executor.
|
||||
func New(ctx context.Context, db database.Store, log slog.Logger, tick <-chan time.Time) *Executor {
|
||||
le := &Executor{
|
||||
ctx: ctx,
|
||||
//nolint:gocritic // TODO: make an autostart role instead of using System
|
||||
ctx: dbauthz.AsSystem(ctx),
|
||||
db: db,
|
||||
tick: tick,
|
||||
log: log,
|
||||
|
|
|
@ -42,6 +42,7 @@ import (
|
|||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/awsidentity"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/dbtype"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
|
@ -157,13 +158,6 @@ func New(options *Options) *API {
|
|||
options = &Options{}
|
||||
}
|
||||
experiments := initExperiments(options.Logger, options.DeploymentConfig.Experiments.Value, options.DeploymentConfig.Experimental.Value)
|
||||
// TODO: remove this once we promote authz_querier out of experiments.
|
||||
if experiments.Enabled(codersdk.ExperimentAuthzQuerier) {
|
||||
panic("Coming soon!")
|
||||
// if _, ok := (options.Database).(*authzquery.AuthzQuerier); !ok {
|
||||
// options.Database = authzquery.NewAuthzQuerier(options.Database, options.Authorizer)
|
||||
// }
|
||||
}
|
||||
if options.AppHostname != "" && options.AppHostnameRegex == nil || options.AppHostname == "" && options.AppHostnameRegex != nil {
|
||||
panic("coderd: both AppHostname and AppHostnameRegex must be set or unset")
|
||||
}
|
||||
|
@ -204,6 +198,14 @@ func New(options *Options) *API {
|
|||
if options.Auditor == nil {
|
||||
options.Auditor = audit.NewNop()
|
||||
}
|
||||
// TODO: remove this once we promote authz_querier out of experiments.
|
||||
if experiments.Enabled(codersdk.ExperimentAuthzQuerier) {
|
||||
options.Database = dbauthz.New(
|
||||
options.Database,
|
||||
options.Authorizer,
|
||||
options.Logger.Named("authz_querier"),
|
||||
)
|
||||
}
|
||||
if options.SetUserGroups == nil {
|
||||
options.SetUserGroups = func(context.Context, database.Store, uuid.UUID, []string) error { return nil }
|
||||
}
|
||||
|
@ -304,8 +306,10 @@ func New(options *Options) *API {
|
|||
DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value,
|
||||
Optional: true,
|
||||
}),
|
||||
httpmw.ExtractUserParam(api.Database, false),
|
||||
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
|
||||
httpmw.AsAuthzSystem(
|
||||
httpmw.ExtractUserParam(api.Database, false),
|
||||
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
|
||||
),
|
||||
),
|
||||
// Build-Version is helpful for debugging.
|
||||
func(next http.Handler) http.Handler {
|
||||
|
@ -332,11 +336,13 @@ func New(options *Options) *API {
|
|||
DisableSessionExpiryRefresh: options.DeploymentConfig.DisableSessionExpiryRefresh.Value,
|
||||
Optional: true,
|
||||
}),
|
||||
// Redirect to the login page if the user tries to open an app with
|
||||
// "me" as the username and they are not logged in.
|
||||
httpmw.ExtractUserParam(api.Database, true),
|
||||
// Extracts the <workspace.agent> from the url
|
||||
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
|
||||
httpmw.AsAuthzSystem(
|
||||
// Redirect to the login page if the user tries to open an app with
|
||||
// "me" as the username and they are not logged in.
|
||||
httpmw.ExtractUserParam(api.Database, true),
|
||||
// Extracts the <workspace.agent> from the url
|
||||
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
|
||||
),
|
||||
)
|
||||
r.HandleFunc("/*", api.workspaceAppsProxyPath)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
|
@ -20,8 +19,9 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
@ -30,12 +30,6 @@ import (
|
|||
)
|
||||
|
||||
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
// For any route using SQL filters, we need to know if the database is an
|
||||
// in memory fake. This is because the in memory fake does not use SQL, and
|
||||
// still uses rego. So this boolean indicates how to assert the expected
|
||||
// behavior.
|
||||
_, isMemoryDB := a.api.Database.(dbfake.FakeDatabase)
|
||||
|
||||
// Some quick reused objects
|
||||
workspaceRBACObj := rbac.ResourceWorkspace.WithID(a.Workspace.ID).InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||
workspaceExecObj := rbac.ResourceWorkspaceExecution.WithID(a.Workspace.ID).InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||
|
@ -269,16 +263,17 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
"POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
"POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
|
||||
// Endpoints that use the SQLQuery filter.
|
||||
// For any route using SQL filters, we do not check authorization.
|
||||
// This is because the in memory fake does not use SQL.
|
||||
"GET:/api/v2/workspaces/": {
|
||||
StatusCode: http.StatusOK,
|
||||
NoAuthorize: !isMemoryDB,
|
||||
NoAuthorize: true,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceWorkspace,
|
||||
},
|
||||
"GET:/api/v2/organizations/{organization}/templates": {
|
||||
StatusCode: http.StatusOK,
|
||||
NoAuthorize: !isMemoryDB,
|
||||
NoAuthorize: true,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceTemplate,
|
||||
},
|
||||
|
|
|
@ -2,15 +2,21 @@ package coderdtest_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) {
|
||||
t.Skip("Skipping TestAuthorizeAllEndpoints for authz_querier experiment")
|
||||
}
|
||||
t.Parallel()
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
// Required for any subdomain-based proxy tests to pass.
|
||||
|
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -58,6 +59,7 @@ import (
|
|||
"github.com/coder/coder/coderd/autobuild/executor"
|
||||
"github.com/coder/coder/coderd/awsidentity"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
|
@ -179,12 +181,13 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can
|
|||
options.Database, options.Pubsub = dbtestutil.NewDB(t)
|
||||
}
|
||||
// TODO: remove this once we're ready to enable authz querier by default.
|
||||
if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), "authz_querier") {
|
||||
panic("Coming soon!")
|
||||
// if options.Authorizer != nil {
|
||||
// options.Authorizer = &RecordingAuthorizer{}
|
||||
// }
|
||||
// options.Database = authzquery.NewAuthzQuerier(options.Database, options.Authorizer)
|
||||
if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) {
|
||||
if options.Authorizer == nil {
|
||||
options.Authorizer = &RecordingAuthorizer{
|
||||
Wrapped: rbac.NewAuthorizer(prometheus.NewRegistry()),
|
||||
}
|
||||
}
|
||||
options.Database = dbauthz.New(options.Database, options.Authorizer, slogtest.Make(t, nil).Leveled(slog.LevelDebug))
|
||||
}
|
||||
if options.DeploymentConfig == nil {
|
||||
options.DeploymentConfig = DeploymentConfig(t)
|
||||
|
|
|
@ -0,0 +1,387 @@
|
|||
package dbauthz
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
)
|
||||
|
||||
var _ database.Store = (*querier)(nil)
|
||||
|
||||
var (
|
||||
// NoActorError wraps ErrNoRows for the api to return a 404. This is the correct
|
||||
// response when the user is not authorized.
|
||||
NoActorError = xerrors.Errorf("no authorization actor in context: %w", sql.ErrNoRows)
|
||||
)
|
||||
|
||||
// NotAuthorizedError is a sentinel error that unwraps to sql.ErrNoRows.
|
||||
// This allows the internal error to be read by the caller if needed. Otherwise
|
||||
// it will be handled as a 404.
|
||||
type NotAuthorizedError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e NotAuthorizedError) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s", e.Err.Error())
|
||||
}
|
||||
|
||||
// Unwrap will always unwrap to a sql.ErrNoRows so the API returns a 404.
|
||||
// So 'errors.Is(err, sql.ErrNoRows)' will always be true.
|
||||
func (NotAuthorizedError) Unwrap() error {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func logNotAuthorizedError(ctx context.Context, logger slog.Logger, err error) error {
|
||||
// Only log the errors if it is an UnauthorizedError error.
|
||||
internalError := new(rbac.UnauthorizedError)
|
||||
if err != nil && xerrors.As(err, &internalError) {
|
||||
e := new(topdown.Error)
|
||||
if xerrors.As(err, &e) || e.Code == topdown.CancelErr {
|
||||
// For some reason rego changes a cancelled context to a topdown.CancelErr. We
|
||||
// expect to check for cancelled context errors if the user cancels the request,
|
||||
// so we should change the error to a context.Canceled error.
|
||||
//
|
||||
// NotAuthorizedError is == to sql.ErrNoRows, which is not correct
|
||||
// if it's actually a cancelled context.
|
||||
internalError.SetInternal(context.Canceled)
|
||||
return internalError
|
||||
}
|
||||
logger.Debug(ctx, "unauthorized",
|
||||
slog.F("internal", internalError.Internal()),
|
||||
slog.F("input", internalError.Input()),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
return NotAuthorizedError{
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// querier is a wrapper around the database store that performs authorization
|
||||
// checks before returning data. All querier methods expect an authorization
|
||||
// subject present in the context. If no subject is present, most methods will
|
||||
// fail.
|
||||
//
|
||||
// Use WithAuthorizeContext to set the authorization subject in the context for
|
||||
// the common user case.
|
||||
type querier struct {
|
||||
db database.Store
|
||||
auth rbac.Authorizer
|
||||
log slog.Logger
|
||||
}
|
||||
|
||||
func New(db database.Store, authorizer rbac.Authorizer, logger slog.Logger) database.Store {
|
||||
// If the underlying db store is already a querier, return it.
|
||||
// Do not double wrap.
|
||||
if _, ok := db.(*querier); ok {
|
||||
return db
|
||||
}
|
||||
return &querier{
|
||||
db: db,
|
||||
auth: authorizer,
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeContext is a helper function to authorize an action on an object.
|
||||
func (q *querier) authorizeContext(ctx context.Context, action rbac.Action, object rbac.Objecter) error {
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return NoActorError
|
||||
}
|
||||
|
||||
err := q.auth.Authorize(ctx, act, action, object.RBACObject())
|
||||
if err != nil {
|
||||
return logNotAuthorizedError(ctx, q.log, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type authContextKey struct{}
|
||||
|
||||
// ActorFromContext returns the authorization subject from the context.
|
||||
// All authentication flows should set the authorization subject in the context.
|
||||
// If no actor is present, the function returns false.
|
||||
func ActorFromContext(ctx context.Context) (rbac.Subject, bool) {
|
||||
a, ok := ctx.Value(authContextKey{}).(rbac.Subject)
|
||||
return a, ok
|
||||
}
|
||||
|
||||
// AsSystem returns a context with a system actor. This is used for internal
|
||||
// system operations that are not tied to any particular actor.
|
||||
// When you use this function, be sure to add a //nolint comment
|
||||
// explaining why it is necessary.
|
||||
//
|
||||
// We trust you have received the usual lecture from the local System
|
||||
// Administrator. It usually boils down to these three things:
|
||||
// #1) Respect the privacy of others.
|
||||
// #2) Think before you type.
|
||||
// #3) With great power comes great responsibility.
|
||||
func AsSystem(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, authContextKey{}, rbac.Subject{
|
||||
ID: uuid.Nil.String(),
|
||||
Roles: rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Name: "system",
|
||||
DisplayName: "System",
|
||||
Site: []rbac.Permission{
|
||||
{
|
||||
ResourceType: rbac.ResourceWildcard.Type,
|
||||
Action: rbac.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
Org: map[string][]rbac.Permission{},
|
||||
User: []rbac.Permission{},
|
||||
},
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
var AsRemoveActor = rbac.Subject{
|
||||
ID: "remove-actor",
|
||||
}
|
||||
|
||||
// As returns a context with the given actor stored in the context.
|
||||
// This is used for cases where the actor touching the database is not the
|
||||
// actor stored in the context.
|
||||
// When you use this function, be sure to add a //nolint comment
|
||||
// explaining why it is necessary.
|
||||
func As(ctx context.Context, actor rbac.Subject) context.Context {
|
||||
if actor.Equal(AsRemoveActor) {
|
||||
// AsRemoveActor is a special case that is used to indicate that the actor
|
||||
// should be removed from the context.
|
||||
return context.WithValue(ctx, authContextKey{}, nil)
|
||||
}
|
||||
return context.WithValue(ctx, authContextKey{}, actor)
|
||||
}
|
||||
|
||||
//
|
||||
// Generic functions used to implement the database.Store methods.
|
||||
//
|
||||
|
||||
// insert runs an rbac.ActionCreate on the rbac object argument before
|
||||
// running the insertFunc. The insertFunc is expected to return the object that
|
||||
// was inserted.
|
||||
func insert[
|
||||
ObjectType any,
|
||||
ArgumentType any,
|
||||
Insert func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
](
|
||||
logger slog.Logger,
|
||||
authorizer rbac.Authorizer,
|
||||
object rbac.Objecter,
|
||||
insertFunc Insert,
|
||||
) Insert {
|
||||
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
|
||||
// Fetch the rbac subject
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return empty, NoActorError
|
||||
}
|
||||
|
||||
// Authorize the action
|
||||
err = authorizer.Authorize(ctx, act, rbac.ActionCreate, object.RBACObject())
|
||||
if err != nil {
|
||||
return empty, logNotAuthorizedError(ctx, logger, err)
|
||||
}
|
||||
|
||||
// Insert the database object
|
||||
return insertFunc(ctx, arg)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteQ[
|
||||
ObjectType rbac.Objecter,
|
||||
ArgumentType any,
|
||||
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
Delete func(ctx context.Context, arg ArgumentType) error,
|
||||
](
|
||||
logger slog.Logger,
|
||||
authorizer rbac.Authorizer,
|
||||
fetchFunc Fetch,
|
||||
deleteFunc Delete,
|
||||
) Delete {
|
||||
return fetchAndExec(logger, authorizer,
|
||||
rbac.ActionDelete, fetchFunc, deleteFunc)
|
||||
}
|
||||
|
||||
func updateWithReturn[
|
||||
ObjectType rbac.Objecter,
|
||||
ArgumentType any,
|
||||
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
UpdateQuery func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
](
|
||||
logger slog.Logger,
|
||||
authorizer rbac.Authorizer,
|
||||
fetchFunc Fetch,
|
||||
updateQuery UpdateQuery,
|
||||
) UpdateQuery {
|
||||
return fetchAndQuery(logger, authorizer, rbac.ActionUpdate, fetchFunc, updateQuery)
|
||||
}
|
||||
|
||||
func update[
|
||||
ObjectType rbac.Objecter,
|
||||
ArgumentType any,
|
||||
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
Exec func(ctx context.Context, arg ArgumentType) error,
|
||||
](
|
||||
logger slog.Logger,
|
||||
authorizer rbac.Authorizer,
|
||||
fetchFunc Fetch,
|
||||
updateExec Exec,
|
||||
) Exec {
|
||||
return fetchAndExec(logger, authorizer, rbac.ActionUpdate, fetchFunc, updateExec)
|
||||
}
|
||||
|
||||
// fetch is a generic function that wraps a database
|
||||
// query function (returns an object and an error) with authorization. The
|
||||
// returned function has the same arguments as the database function.
|
||||
//
|
||||
// The database query function will **ALWAYS** hit the database, even if the
|
||||
// user cannot read the resource. This is because the resource details are
|
||||
// required to run a proper authorization check.
|
||||
func fetch[
|
||||
ArgumentType any,
|
||||
ObjectType rbac.Objecter,
|
||||
DatabaseFunc func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
](
|
||||
logger slog.Logger,
|
||||
authorizer rbac.Authorizer,
|
||||
f DatabaseFunc,
|
||||
) DatabaseFunc {
|
||||
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
|
||||
// Fetch the rbac subject
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return empty, NoActorError
|
||||
}
|
||||
|
||||
// Fetch the database object
|
||||
object, err := f(ctx, arg)
|
||||
if err != nil {
|
||||
return empty, xerrors.Errorf("fetch object: %w", err)
|
||||
}
|
||||
|
||||
// Authorize the action
|
||||
err = authorizer.Authorize(ctx, act, rbac.ActionRead, object.RBACObject())
|
||||
if err != nil {
|
||||
return empty, logNotAuthorizedError(ctx, logger, err)
|
||||
}
|
||||
|
||||
return object, nil
|
||||
}
|
||||
}
|
||||
|
||||
// fetchAndExec uses fetchAndQuery but only returns the error. The naming comes
|
||||
// from SQL 'exec' functions which only return an error.
|
||||
// See fetchAndQuery for more information.
|
||||
func fetchAndExec[
|
||||
ObjectType rbac.Objecter,
|
||||
ArgumentType any,
|
||||
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
Exec func(ctx context.Context, arg ArgumentType) error,
|
||||
](
|
||||
logger slog.Logger,
|
||||
authorizer rbac.Authorizer,
|
||||
action rbac.Action,
|
||||
fetchFunc Fetch,
|
||||
execFunc Exec,
|
||||
) Exec {
|
||||
f := fetchAndQuery(logger, authorizer, action, fetchFunc, func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
|
||||
return empty, execFunc(ctx, arg)
|
||||
})
|
||||
return func(ctx context.Context, arg ArgumentType) error {
|
||||
_, err := f(ctx, arg)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// fetchAndQuery is a generic function that wraps a database fetch and query.
|
||||
// A query has potential side effects in the database (update, delete, etc).
|
||||
// The fetch is used to know which rbac object the action should be asserted on
|
||||
// **before** the query runs. The returns from the fetch are only used to
|
||||
// assert rbac. The final return of this function comes from the Query function.
|
||||
func fetchAndQuery[
|
||||
ObjectType rbac.Objecter,
|
||||
ArgumentType any,
|
||||
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
Query func(ctx context.Context, arg ArgumentType) (ObjectType, error),
|
||||
](
|
||||
logger slog.Logger,
|
||||
authorizer rbac.Authorizer,
|
||||
action rbac.Action,
|
||||
fetchFunc Fetch,
|
||||
queryFunc Query,
|
||||
) Query {
|
||||
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
|
||||
// Fetch the rbac subject
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return empty, NoActorError
|
||||
}
|
||||
|
||||
// Fetch the database object
|
||||
object, err := fetchFunc(ctx, arg)
|
||||
if err != nil {
|
||||
return empty, xerrors.Errorf("fetch object: %w", err)
|
||||
}
|
||||
|
||||
// Authorize the action
|
||||
err = authorizer.Authorize(ctx, act, action, object.RBACObject())
|
||||
if err != nil {
|
||||
return empty, logNotAuthorizedError(ctx, logger, err)
|
||||
}
|
||||
|
||||
return queryFunc(ctx, arg)
|
||||
}
|
||||
}
|
||||
|
||||
// fetchWithPostFilter is like fetch, but works with lists of objects.
|
||||
// SQL filters are much more optimal.
|
||||
func fetchWithPostFilter[
|
||||
ArgumentType any,
|
||||
ObjectType rbac.Objecter,
|
||||
DatabaseFunc func(ctx context.Context, arg ArgumentType) ([]ObjectType, error),
|
||||
](
|
||||
authorizer rbac.Authorizer,
|
||||
f DatabaseFunc,
|
||||
) DatabaseFunc {
|
||||
return func(ctx context.Context, arg ArgumentType) (empty []ObjectType, err error) {
|
||||
// Fetch the rbac subject
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return empty, NoActorError
|
||||
}
|
||||
|
||||
// Fetch the database object
|
||||
objects, err := f(ctx, arg)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("fetch object: %w", err)
|
||||
}
|
||||
|
||||
// Authorize the action
|
||||
return rbac.Filter(ctx, authorizer, act, rbac.ActionRead, objects)
|
||||
}
|
||||
}
|
||||
|
||||
// prepareSQLFilter is a helper function that prepares a SQL filter using the
|
||||
// given authorization context.
|
||||
func prepareSQLFilter(ctx context.Context, authorizer rbac.Authorizer, action rbac.Action, resourceType string) (rbac.PreparedAuthorized, error) {
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, NoActorError
|
||||
}
|
||||
|
||||
return authorizer.Prepare(ctx, act, action, resourceType)
|
||||
}
|
|
@ -0,0 +1,151 @@
|
|||
package dbauthz_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/database/dbgen"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
)
|
||||
|
||||
func TestAsNoActor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AsRemoveActor", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ok := dbauthz.ActorFromContext(context.Background())
|
||||
require.False(t, ok, "no actor should be present")
|
||||
})
|
||||
|
||||
t.Run("AsActor", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := dbauthz.As(context.Background(), coderdtest.RandomRBACSubject())
|
||||
_, ok := dbauthz.ActorFromContext(ctx)
|
||||
require.True(t, ok, "actor present")
|
||||
})
|
||||
|
||||
t.Run("DeleteActor", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// First set an actor
|
||||
ctx := dbauthz.As(context.Background(), coderdtest.RandomRBACSubject())
|
||||
_, ok := dbauthz.ActorFromContext(ctx)
|
||||
require.True(t, ok, "actor present")
|
||||
|
||||
// Delete the actor
|
||||
ctx = dbauthz.As(ctx, dbauthz.AsRemoveActor)
|
||||
_, ok = dbauthz.ActorFromContext(ctx)
|
||||
require.False(t, ok, "actor should be deleted")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{}, slog.Make())
|
||||
_, err := q.Ping(context.Background())
|
||||
require.NoError(t, err, "must not error")
|
||||
}
|
||||
|
||||
// TestInTX is not perfect, just checks that it properly checks auth.
|
||||
func TestInTX(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbfake.New()
|
||||
q := dbauthz.New(db, &coderdtest.RecordingAuthorizer{
|
||||
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: xerrors.New("custom error")},
|
||||
}, slog.Make())
|
||||
actor := rbac.Subject{
|
||||
ID: uuid.NewString(),
|
||||
Roles: rbac.RoleNames{rbac.RoleOwner()},
|
||||
Groups: []string{},
|
||||
Scope: rbac.ScopeAll,
|
||||
}
|
||||
|
||||
w := dbgen.Workspace(t, db, database.Workspace{})
|
||||
ctx := dbauthz.As(context.Background(), actor)
|
||||
err := q.InTx(func(tx database.Store) error {
|
||||
// The inner tx should use the parent's authz
|
||||
_, err := tx.GetWorkspaceByID(ctx, w.ID)
|
||||
return err
|
||||
}, nil)
|
||||
require.Error(t, err, "must error")
|
||||
require.ErrorAs(t, err, &dbauthz.NotAuthorizedError{}, "must be an authorized error")
|
||||
}
|
||||
|
||||
// TestNew should not double wrap a querier.
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db = dbfake.New()
|
||||
exp = dbgen.Workspace(t, db, database.Workspace{})
|
||||
rec = &coderdtest.RecordingAuthorizer{
|
||||
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
|
||||
}
|
||||
subj = rbac.Subject{}
|
||||
ctx = dbauthz.As(context.Background(), rbac.Subject{})
|
||||
)
|
||||
|
||||
// Double wrap should not cause an actual double wrap. So only 1 rbac call
|
||||
// should be made.
|
||||
az := dbauthz.New(db, rec, slog.Make())
|
||||
az = dbauthz.New(az, rec, slog.Make())
|
||||
|
||||
w, err := az.GetWorkspaceByID(ctx, exp.ID)
|
||||
require.NoError(t, err, "must not error")
|
||||
require.Equal(t, exp, w, "must be equal")
|
||||
|
||||
rec.AssertActor(t, subj, rec.Pair(rbac.ActionRead, exp))
|
||||
require.NoError(t, rec.AllAsserted(), "should only be 1 rbac call")
|
||||
}
|
||||
|
||||
// TestDBAuthzRecursive is a simple test to search for infinite recursion
|
||||
// bugs. It isn't perfect, and only catches a subset of the possible bugs
|
||||
// as only the first db call will be made. But it is better than nothing.
|
||||
func TestDBAuthzRecursive(t *testing.T) {
|
||||
t.Parallel()
|
||||
q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{
|
||||
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
|
||||
}, slog.Make())
|
||||
actor := rbac.Subject{
|
||||
ID: uuid.NewString(),
|
||||
Roles: rbac.RoleNames{rbac.RoleOwner()},
|
||||
Groups: []string{},
|
||||
Scope: rbac.ScopeAll,
|
||||
}
|
||||
for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ {
|
||||
var ins []reflect.Value
|
||||
ctx := dbauthz.As(context.Background(), actor)
|
||||
|
||||
ins = append(ins, reflect.ValueOf(ctx))
|
||||
method := reflect.TypeOf(q).Method(i)
|
||||
for i := 2; i < method.Type.NumIn(); i++ {
|
||||
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
|
||||
}
|
||||
if method.Name == "InTx" || method.Name == "Ping" {
|
||||
continue
|
||||
}
|
||||
// Log the name of the last method, so if there is a panic, it is
|
||||
// easy to know which method failed.
|
||||
// t.Log(method.Name)
|
||||
// Call the function. Any infinite recursion will stack overflow.
|
||||
reflect.ValueOf(q).Method(i).Call(ins)
|
||||
}
|
||||
}
|
||||
|
||||
func must[T any](value T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return value
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
// Package dbauthz provides an authorization layer on top of the database. This
|
||||
// package exposes an interface that is currently a 1:1 mapping with
|
||||
// database.Store.
|
||||
//
|
||||
// The same cultural rules apply to this package as they do to database.Store.
|
||||
// Meaning that each method implemented should keep the number of database
|
||||
// queries as close to 1 as possible. Each method should do 1 thing, with no
|
||||
// unexpected side effects (eg: updating multiple tables in a single method).
|
||||
//
|
||||
// Do not implement business logic in this package. Only authorization related
|
||||
// logic should be implemented here. In most cases, this should only be a call to
|
||||
// the rbac authorizer.
|
||||
//
|
||||
// When a new database method is added to database.Store, it should be added to
|
||||
// this package as well. The unit test "Accounting" will ensure all methods are
|
||||
// tested. See other unit tests for examples on how to write these.
|
||||
package dbauthz
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,377 @@
|
|||
package dbauthz_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/coderd/util/slice"
|
||||
)
|
||||
|
||||
var (
|
||||
skipMethods = map[string]string{
|
||||
"InTx": "Not relevant",
|
||||
"Ping": "Not relevant",
|
||||
}
|
||||
)
|
||||
|
||||
// TestMethodTestSuite runs MethodTestSuite.
|
||||
// In order for 'go test' to run this suite, we need to create
|
||||
// a normal test function and pass our suite to suite.Run
|
||||
// nolint: paralleltest
|
||||
func TestMethodTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(MethodTestSuite))
|
||||
}
|
||||
|
||||
// MethodTestSuite runs all methods tests for querier. We use
|
||||
// a test suite so we can account for all functions tested on the querier.
|
||||
// We can then assert all methods were tested and asserted for proper RBAC
|
||||
// checks. This forces RBAC checks to be written for all methods.
|
||||
// Additionally, the way unit tests are written allows for easily executing
|
||||
// a single test for debugging.
|
||||
type MethodTestSuite struct {
|
||||
suite.Suite
|
||||
// methodAccounting counts all methods called by a 'RunMethodTest'
|
||||
methodAccounting map[string]int
|
||||
}
|
||||
|
||||
// SetupSuite sets up the suite by creating a map of all methods on querier
|
||||
// and setting their count to 0.
|
||||
func (s *MethodTestSuite) SetupSuite() {
|
||||
az := dbauthz.New(nil, nil, slog.Make())
|
||||
// Take the underlying type of the interface.
|
||||
azt := reflect.TypeOf(az).Elem()
|
||||
s.methodAccounting = make(map[string]int)
|
||||
for i := 0; i < azt.NumMethod(); i++ {
|
||||
method := azt.Method(i)
|
||||
if _, ok := skipMethods[method.Name]; ok {
|
||||
// We can't use s.T().Skip as this will skip the entire suite.
|
||||
s.T().Logf("Skipping method %q: %s", method.Name, skipMethods[method.Name])
|
||||
continue
|
||||
}
|
||||
s.methodAccounting[method.Name] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// TearDownSuite asserts that all methods were called at least once.
|
||||
func (s *MethodTestSuite) TearDownSuite() {
|
||||
s.Run("Accounting", func() {
|
||||
t := s.T()
|
||||
notCalled := []string{}
|
||||
for m, c := range s.methodAccounting {
|
||||
if c <= 0 {
|
||||
notCalled = append(notCalled, m)
|
||||
}
|
||||
}
|
||||
sort.Strings(notCalled)
|
||||
for _, m := range notCalled {
|
||||
t.Errorf("Method never called: %q", m)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Subtest is a helper function that returns a function that can be passed to
|
||||
// s.Run(). This function will run the test case for the method that is being
|
||||
// tested. The check parameter is used to assert the results of the method.
|
||||
// If the caller does not use the `check` parameter, the test will fail.
|
||||
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
|
||||
return func() {
|
||||
t := s.T()
|
||||
testName := s.T().Name()
|
||||
names := strings.Split(testName, "/")
|
||||
methodName := names[len(names)-1]
|
||||
s.methodAccounting[methodName]++
|
||||
|
||||
db := dbfake.New()
|
||||
fakeAuthorizer := &coderdtest.FakeAuthorizer{
|
||||
AlwaysReturn: nil,
|
||||
}
|
||||
rec := &coderdtest.RecordingAuthorizer{
|
||||
Wrapped: fakeAuthorizer,
|
||||
}
|
||||
az := dbauthz.New(db, rec, slog.Make())
|
||||
actor := rbac.Subject{
|
||||
ID: uuid.NewString(),
|
||||
Roles: rbac.RoleNames{rbac.RoleOwner()},
|
||||
Groups: []string{},
|
||||
Scope: rbac.ScopeAll,
|
||||
}
|
||||
ctx := dbauthz.As(context.Background(), actor)
|
||||
|
||||
var testCase expects
|
||||
testCaseF(db, &testCase)
|
||||
// Check the developer added assertions. If there are no assertions,
|
||||
// an empty list should be passed.
|
||||
s.Require().False(testCase.assertions == nil, "rbac assertions not set, use the 'check' parameter")
|
||||
|
||||
// Find the method with the name of the test.
|
||||
var callMethod func(ctx context.Context) ([]reflect.Value, error)
|
||||
azt := reflect.TypeOf(az)
|
||||
MethodLoop:
|
||||
for i := 0; i < azt.NumMethod(); i++ {
|
||||
method := azt.Method(i)
|
||||
if method.Name == methodName {
|
||||
methodF := reflect.ValueOf(az).Method(i)
|
||||
|
||||
callMethod = func(ctx context.Context) ([]reflect.Value, error) {
|
||||
resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.inputs...))
|
||||
return splitResp(t, resp)
|
||||
}
|
||||
break MethodLoop
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, callMethod, "method %q does not exist", methodName)
|
||||
|
||||
if len(testCase.assertions) > 0 {
|
||||
// Only run these tests if we know the underlying call makes
|
||||
// rbac assertions.
|
||||
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
|
||||
}
|
||||
|
||||
if len(testCase.assertions) > 0 ||
|
||||
slice.Contains([]string{
|
||||
"GetAuthorizedWorkspaces",
|
||||
"GetAuthorizedTemplates",
|
||||
}, methodName) {
|
||||
// Some methods do not make RBAC assertions because they use
|
||||
// SQL. We still want to test that they return an error if the
|
||||
// actor is not set.
|
||||
s.NoActorErrorTest(callMethod)
|
||||
}
|
||||
|
||||
// Always run
|
||||
s.Run("Success", func() {
|
||||
rec.Reset()
|
||||
fakeAuthorizer.AlwaysReturn = nil
|
||||
|
||||
outputs, err := callMethod(ctx)
|
||||
s.NoError(err, "method %q returned an error", methodName)
|
||||
|
||||
// Some tests may not care about the outputs, so we only assert if
|
||||
// they are provided.
|
||||
if testCase.outputs != nil {
|
||||
// Assert the required outputs
|
||||
s.Equal(len(testCase.outputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
|
||||
for i := range outputs {
|
||||
a, b := testCase.outputs[i].Interface(), outputs[i].Interface()
|
||||
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
|
||||
// Order does not matter
|
||||
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
|
||||
} else {
|
||||
s.Equal(a, b, "method %q returned unexpected output %d", methodName, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var pairs []coderdtest.ActionObjectPair
|
||||
for _, assrt := range testCase.assertions {
|
||||
for _, action := range assrt.Actions {
|
||||
pairs = append(pairs, coderdtest.ActionObjectPair{
|
||||
Action: action,
|
||||
Object: assrt.Object,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
rec.AssertActor(s.T(), actor, pairs...)
|
||||
s.NoError(rec.AllAsserted(), "all rbac calls must be asserted")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context) ([]reflect.Value, error)) {
|
||||
s.Run("AsRemoveActor", func() {
|
||||
// Call without any actor
|
||||
_, err := callMethod(context.Background())
|
||||
s.ErrorIs(err, dbauthz.NoActorError, "method should return NoActorError error when no actor is provided")
|
||||
})
|
||||
}
|
||||
|
||||
// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz.
|
||||
// Asserts that the error returned is a NotAuthorizedError.
|
||||
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
|
||||
s.Run("NotAuthorized", func() {
|
||||
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
|
||||
|
||||
// If we have assertions, that means the method should FAIL
|
||||
// if RBAC will disallow the request. The returned error should
|
||||
// be expected to be a NotAuthorizedError.
|
||||
resp, err := callMethod(ctx)
|
||||
|
||||
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
|
||||
// any case where the error is nil and the response is an empty slice.
|
||||
if err != nil || !hasEmptySliceResponse(resp) {
|
||||
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
|
||||
s.Errorf(err, "method should an error with disallow authz")
|
||||
s.ErrorIsf(err, sql.ErrNoRows, "error should match sql.ErrNoRows")
|
||||
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func hasEmptySliceResponse(values []reflect.Value) bool {
|
||||
for _, r := range values {
|
||||
if r.Kind() == reflect.Slice || r.Kind() == reflect.Array {
|
||||
if r.Len() == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func splitResp(t *testing.T, values []reflect.Value) ([]reflect.Value, error) {
|
||||
outputs := []reflect.Value{}
|
||||
for _, r := range values {
|
||||
if r.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
if r.IsNil() {
|
||||
// Error is found, but it's nil!
|
||||
return outputs, nil
|
||||
}
|
||||
err, ok := r.Interface().(error)
|
||||
if !ok {
|
||||
t.Fatal("error is not an error?!")
|
||||
}
|
||||
return outputs, err
|
||||
}
|
||||
outputs = append(outputs, r)
|
||||
} //nolint: unreachable
|
||||
t.Fatal("no expected error value found in responses (error can be nil)")
|
||||
return nil, nil // unreachable, required to compile
|
||||
}
|
||||
|
||||
// expects is used to build a test case for a method.
|
||||
// It includes the expected inputs, rbac assertions, and expected outputs.
|
||||
type expects struct {
|
||||
inputs []reflect.Value
|
||||
assertions []AssertRBAC
|
||||
// outputs is optional. Can assert non-error return values.
|
||||
outputs []reflect.Value
|
||||
}
|
||||
|
||||
// Asserts is required. Asserts the RBAC authorize calls that should be made.
|
||||
// If no RBAC calls are expected, pass an empty list: 'm.Asserts()'
|
||||
func (m *expects) Asserts(pairs ...any) *expects {
|
||||
m.assertions = asserts(pairs...)
|
||||
return m
|
||||
}
|
||||
|
||||
// Args is required. The arguments to be provided to the method.
|
||||
// If there are no arguments, pass an empty list: 'm.Args()'
|
||||
// The first context argument should not be included, as the test suite
|
||||
// will provide it.
|
||||
func (m *expects) Args(args ...any) *expects {
|
||||
m.inputs = values(args...)
|
||||
return m
|
||||
}
|
||||
|
||||
// Returns is optional. If it is never called, it will not be asserted.
|
||||
func (m *expects) Returns(rets ...any) *expects {
|
||||
m.outputs = values(rets...)
|
||||
return m
|
||||
}
|
||||
|
||||
// AssertRBAC contains the object and actions to be asserted.
|
||||
type AssertRBAC struct {
|
||||
Object rbac.Object
|
||||
Actions []rbac.Action
|
||||
}
|
||||
|
||||
// values is a convenience method for creating []reflect.Value.
|
||||
//
|
||||
// values(workspace, template, ...)
|
||||
//
|
||||
// is equivalent to
|
||||
//
|
||||
// []reflect.Value{
|
||||
// reflect.ValueOf(workspace),
|
||||
// reflect.ValueOf(template),
|
||||
// ...
|
||||
// }
|
||||
func values(ins ...any) []reflect.Value {
|
||||
out := make([]reflect.Value, 0)
|
||||
for _, input := range ins {
|
||||
input := input
|
||||
out = append(out, reflect.ValueOf(input))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// asserts is a convenience method for creating AssertRBACs.
|
||||
//
|
||||
// The number of inputs must be an even number.
|
||||
// asserts() will panic if this is not the case.
|
||||
//
|
||||
// Even-numbered inputs are the objects, and odd-numbered inputs are the actions.
|
||||
// Objects must implement rbac.Objecter.
|
||||
// Inputs can be a single rbac.Action, or a slice of rbac.Action.
|
||||
//
|
||||
// asserts(workspace, rbac.ActionRead, template, slice(rbac.ActionRead, rbac.ActionWrite), ...)
|
||||
//
|
||||
// is equivalent to
|
||||
//
|
||||
// []AssertRBAC{
|
||||
// {Object: workspace, Actions: []rbac.Action{rbac.ActionRead}},
|
||||
// {Object: template, Actions: []rbac.Action{rbac.ActionRead, rbac.ActionWrite)}},
|
||||
// ...
|
||||
// }
|
||||
func asserts(inputs ...any) []AssertRBAC {
|
||||
if len(inputs)%2 != 0 {
|
||||
panic(fmt.Sprintf("Must be an even length number of args, found %d", len(inputs)))
|
||||
}
|
||||
|
||||
out := make([]AssertRBAC, 0)
|
||||
for i := 0; i < len(inputs); i += 2 {
|
||||
obj, ok := inputs[i].(rbac.Objecter)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("object type '%T' does not implement rbac.Objecter", inputs[i]))
|
||||
}
|
||||
rbacObj := obj.RBACObject()
|
||||
|
||||
var actions []rbac.Action
|
||||
actions, ok = inputs[i+1].([]rbac.Action)
|
||||
if !ok {
|
||||
action, ok := inputs[i+1].(rbac.Action)
|
||||
if !ok {
|
||||
// Could be the string type.
|
||||
actionAsString, ok := inputs[i+1].(string)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("action '%q' not a supported action", actionAsString))
|
||||
}
|
||||
action = rbac.Action(actionAsString)
|
||||
}
|
||||
actions = []rbac.Action{action}
|
||||
}
|
||||
|
||||
out = append(out, AssertRBAC{
|
||||
Object: rbacObj,
|
||||
Actions: actions,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type emptyPreparedAuthorized struct{}
|
||||
|
||||
func (emptyPreparedAuthorized) Authorize(_ context.Context, _ rbac.Object) error { return nil }
|
||||
func (emptyPreparedAuthorized) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
|
@ -0,0 +1,194 @@
|
|||
package dbauthz
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
)
|
||||
|
||||
// TODO: All these system functions should have rbac objects created to allow
|
||||
// only system roles to call them. No user roles should ever have the permission
|
||||
// to these objects. Might need a negative permission on the `Owner` role to
|
||||
// prevent owners.
|
||||
|
||||
func (q *querier) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) {
|
||||
return q.db.UpdateUserLinkedID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
|
||||
return q.db.GetUserLinkByLinkedID(ctx, linkedID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
|
||||
return q.db.GetUserLinkByUserIDLoginType(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetLatestWorkspaceBuilds(ctx context.Context) ([]database.WorkspaceBuild, error) {
|
||||
// This function is a system function until we implement a join for workspace builds.
|
||||
// This is because we need to query for all related workspaces to the returned builds.
|
||||
// This is a very inefficient method of fetching the latest workspace builds.
|
||||
// We should just join the rbac properties.
|
||||
return q.db.GetLatestWorkspaceBuilds(ctx)
|
||||
}
|
||||
|
||||
// GetWorkspaceAgentByAuthToken is used in http middleware to get the workspace agent.
|
||||
// This should only be used by a system user in that middleware.
|
||||
func (q *querier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) {
|
||||
return q.db.GetWorkspaceAgentByAuthToken(ctx, authToken)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) {
|
||||
return q.db.GetActiveUserCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, error) {
|
||||
return q.db.GetUnexpiredLicenses(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) {
|
||||
return q.db.GetAuthorizationUserRoles(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
// TODO Implement authz check for system user.
|
||||
return q.db.GetDERPMeshKey(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error {
|
||||
// TODO Implement authz check for system user.
|
||||
return q.db.InsertDERPMeshKey(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) InsertDeploymentID(ctx context.Context, value string) error {
|
||||
// TODO Implement authz check for system user.
|
||||
return q.db.InsertDeploymentID(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) InsertReplica(ctx context.Context, arg database.InsertReplicaParams) (database.Replica, error) {
|
||||
// TODO Implement authz check for system user.
|
||||
return q.db.InsertReplica(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaParams) (database.Replica, error) {
|
||||
// TODO Implement authz check for system user.
|
||||
return q.db.UpdateReplica(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error {
|
||||
// TODO Implement authz check for system user.
|
||||
return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt)
|
||||
}
|
||||
|
||||
func (q *querier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.Replica, error) {
|
||||
// TODO Implement authz check for system user.
|
||||
return q.db.GetReplicasUpdatedAfter(ctx, updatedAt)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserCount(ctx context.Context) (int64, error) {
|
||||
return q.db.GetUserCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetTemplates(ctx context.Context) ([]database.Template, error) {
|
||||
// TODO Implement authz check for system user.
|
||||
return q.db.GetTemplates(ctx)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceBuildCostByID is used by the provisioning system to update the cost of a workspace build.
|
||||
func (q *querier) UpdateWorkspaceBuildCostByID(ctx context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) (database.WorkspaceBuild, error) {
|
||||
return q.db.UpdateWorkspaceBuildCostByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertOrUpdateLastUpdateCheck(ctx context.Context, value string) error {
|
||||
return q.db.InsertOrUpdateLastUpdateCheck(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) {
|
||||
return q.db.GetLastUpdateCheck(ctx)
|
||||
}
|
||||
|
||||
// Telemetry related functions. These functions are system functions for returning
|
||||
// telemetry data. Never called by a user.
|
||||
|
||||
func (q *querier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceBuild, error) {
|
||||
return q.db.GetWorkspaceBuildsCreatedAfter(ctx, createdAt)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
|
||||
return q.db.GetWorkspaceAgentsCreatedAfter(ctx, createdAt)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceAppsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceApp, error) {
|
||||
return q.db.GetWorkspaceAppsCreatedAfter(ctx, createdAt)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceResourcesCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResource, error) {
|
||||
return q.db.GetWorkspaceResourcesCreatedAfter(ctx, createdAt)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResourceMetadatum, error) {
|
||||
return q.db.GetWorkspaceResourceMetadataCreatedAfter(ctx, createdAt)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldAgentStats(ctx context.Context) error {
|
||||
return q.db.DeleteOldAgentStats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetParameterSchemasCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ParameterSchema, error) {
|
||||
return q.db.GetParameterSchemasCreatedAfter(ctx, createdAt)
|
||||
}
|
||||
func (q *querier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ProvisionerJob, error) {
|
||||
return q.db.GetProvisionerJobsCreatedAfter(ctx, createdAt)
|
||||
}
|
||||
|
||||
// Provisionerd server functions
|
||||
|
||||
func (q *querier) InsertWorkspaceAgent(ctx context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) {
|
||||
return q.db.InsertWorkspaceAgent(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertWorkspaceApp(ctx context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) {
|
||||
return q.db.InsertWorkspaceApp(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) {
|
||||
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) {
|
||||
return q.db.AcquireProvisionerJob(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error {
|
||||
return q.db.UpdateProvisionerJobWithCompleteByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
|
||||
return q.db.UpdateProvisionerJobByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertProvisionerJob(ctx context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) {
|
||||
return q.db.InsertProvisionerJob(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertProvisionerJobLogs(ctx context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) {
|
||||
return q.db.InsertProvisionerJobLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertProvisionerDaemon(ctx context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) {
|
||||
return q.db.InsertProvisionerDaemon(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertTemplateVersionParameter(ctx context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) {
|
||||
return q.db.InsertTemplateVersionParameter(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertWorkspaceResource(ctx context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) {
|
||||
return q.db.InsertWorkspaceResource(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertParameterSchema(ctx context.Context, arg database.InsertParameterSchemaParams) (database.ParameterSchema, error) {
|
||||
return q.db.InsertParameterSchema(ctx, arg)
|
||||
}
|
|
@ -0,0 +1,219 @@
|
|||
package dbauthz_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbgen"
|
||||
)
|
||||
|
||||
func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
s.Run("UpdateUserLinkedID", s.Subtest(func(db database.Store, check *expects) {
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
l := dbgen.UserLink(s.T(), db, database.UserLink{UserID: u.ID})
|
||||
check.Args(database.UpdateUserLinkedIDParams{
|
||||
UserID: u.ID,
|
||||
LinkedID: l.LinkedID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
}).Asserts().Returns(l)
|
||||
}))
|
||||
s.Run("GetUserLinkByLinkedID", s.Subtest(func(db database.Store, check *expects) {
|
||||
l := dbgen.UserLink(s.T(), db, database.UserLink{})
|
||||
check.Args(l.LinkedID).Asserts().Returns(l)
|
||||
}))
|
||||
s.Run("GetUserLinkByUserIDLoginType", s.Subtest(func(db database.Store, check *expects) {
|
||||
l := dbgen.UserLink(s.T(), db, database.UserLink{})
|
||||
check.Args(database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: l.UserID,
|
||||
LoginType: l.LoginType,
|
||||
}).Asserts().Returns(l)
|
||||
}))
|
||||
s.Run("GetLatestWorkspaceBuilds", s.Subtest(func(db database.Store, check *expects) {
|
||||
dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
|
||||
dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetWorkspaceAgentByAuthToken", s.Subtest(func(db database.Store, check *expects) {
|
||||
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{})
|
||||
check.Args(agt.AuthToken).Asserts().Returns(agt)
|
||||
}))
|
||||
s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args().Asserts().Returns(int64(0))
|
||||
}))
|
||||
s.Run("GetUnexpiredLicenses", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetAuthorizationUserRoles", s.Subtest(func(db database.Store, check *expects) {
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
check.Args(u.ID).Asserts()
|
||||
}))
|
||||
s.Run("GetDERPMeshKey", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("InsertDERPMeshKey", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args("value").Asserts().Returns()
|
||||
}))
|
||||
s.Run("InsertDeploymentID", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args("value").Asserts().Returns()
|
||||
}))
|
||||
s.Run("InsertReplica", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertReplicaParams{
|
||||
ID: uuid.New(),
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("UpdateReplica", s.Subtest(func(db database.Store, check *expects) {
|
||||
replica, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New()})
|
||||
require.NoError(s.T(), err)
|
||||
check.Args(database.UpdateReplicaParams{
|
||||
ID: replica.ID,
|
||||
DatabaseLatency: 100,
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("DeleteReplicasUpdatedBefore", s.Subtest(func(db database.Store, check *expects) {
|
||||
_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()})
|
||||
require.NoError(s.T(), err)
|
||||
check.Args(time.Now().Add(time.Hour)).Asserts()
|
||||
}))
|
||||
s.Run("GetReplicasUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
|
||||
_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()})
|
||||
require.NoError(s.T(), err)
|
||||
check.Args(time.Now().Add(time.Hour * -1)).Asserts()
|
||||
}))
|
||||
s.Run("GetUserCount", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args().Asserts().Returns(int64(0))
|
||||
}))
|
||||
s.Run("GetTemplates", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.Template(s.T(), db, database.Template{})
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceBuildCostByID", s.Subtest(func(db database.Store, check *expects) {
|
||||
b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
|
||||
o := b
|
||||
o.DailyCost = 10
|
||||
check.Args(database.UpdateWorkspaceBuildCostByIDParams{
|
||||
ID: b.ID,
|
||||
DailyCost: 10,
|
||||
}).Asserts().Returns(o)
|
||||
}))
|
||||
s.Run("InsertOrUpdateLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args("value").Asserts()
|
||||
}))
|
||||
s.Run("GetLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) {
|
||||
err := db.InsertOrUpdateLastUpdateCheck(context.Background(), "value")
|
||||
require.NoError(s.T(), err)
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetWorkspaceBuildsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{CreatedAt: time.Now().Add(-time.Hour)})
|
||||
check.Args(time.Now()).Asserts()
|
||||
}))
|
||||
s.Run("GetWorkspaceAgentsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{CreatedAt: time.Now().Add(-time.Hour)})
|
||||
check.Args(time.Now()).Asserts()
|
||||
}))
|
||||
s.Run("GetWorkspaceAppsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{CreatedAt: time.Now().Add(-time.Hour)})
|
||||
check.Args(time.Now()).Asserts()
|
||||
}))
|
||||
s.Run("GetWorkspaceResourcesCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{CreatedAt: time.Now().Add(-time.Hour)})
|
||||
check.Args(time.Now()).Asserts()
|
||||
}))
|
||||
s.Run("GetWorkspaceResourceMetadataCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.WorkspaceResourceMetadatums(s.T(), db, database.WorkspaceResourceMetadatum{})
|
||||
check.Args(time.Now()).Asserts()
|
||||
}))
|
||||
s.Run("DeleteOldAgentStats", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetParameterSchemasCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.ParameterSchema(s.T(), db, database.ParameterSchema{CreatedAt: time.Now().Add(-time.Hour)})
|
||||
check.Args(time.Now()).Asserts()
|
||||
}))
|
||||
s.Run("GetProvisionerJobsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)})
|
||||
check.Args(time.Now()).Asserts()
|
||||
}))
|
||||
s.Run("InsertWorkspaceAgent", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertWorkspaceAgentParams{
|
||||
ID: uuid.New(),
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("InsertWorkspaceApp", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertWorkspaceAppParams{
|
||||
ID: uuid.New(),
|
||||
Health: database.WorkspaceAppHealthDisabled,
|
||||
SharingLevel: database.AppSharingLevelOwner,
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("InsertWorkspaceResourceMetadata", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertWorkspaceResourceMetadataParams{
|
||||
WorkspaceResourceID: uuid.New(),
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("AcquireProvisionerJob", s.Subtest(func(db database.Store, check *expects) {
|
||||
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
|
||||
StartedAt: sql.NullTime{Valid: false},
|
||||
})
|
||||
check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}}).
|
||||
Asserts()
|
||||
}))
|
||||
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {
|
||||
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
|
||||
check.Args(database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: j.ID,
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("UpdateProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) {
|
||||
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
|
||||
check.Args(database.UpdateProvisionerJobByIDParams{
|
||||
ID: j.ID,
|
||||
UpdatedAt: time.Now(),
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("InsertProvisionerJob", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("InsertProvisionerJobLogs", s.Subtest(func(db database.Store, check *expects) {
|
||||
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
|
||||
check.Args(database.InsertProvisionerJobLogsParams{
|
||||
JobID: j.ID,
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("InsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertProvisionerDaemonParams{
|
||||
ID: uuid.New(),
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("InsertTemplateVersionParameter", s.Subtest(func(db database.Store, check *expects) {
|
||||
v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{})
|
||||
check.Args(database.InsertTemplateVersionParameterParams{
|
||||
TemplateVersionID: v.ID,
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("InsertWorkspaceResource", s.Subtest(func(db database.Store, check *expects) {
|
||||
r := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{})
|
||||
check.Args(database.InsertWorkspaceResourceParams{
|
||||
ID: r.ID,
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).Asserts()
|
||||
}))
|
||||
s.Run("InsertParameterSchema", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertParameterSchemaParams{
|
||||
ID: uuid.New(),
|
||||
DefaultSourceScheme: database.ParameterSourceSchemeNone,
|
||||
DefaultDestinationScheme: database.ParameterDestinationSchemeNone,
|
||||
ValidationTypeSystem: database.ParameterTypeSystemNone,
|
||||
}).Asserts()
|
||||
}))
|
||||
}
|
|
@ -614,6 +614,14 @@ func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params databas
|
|||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
if prepared != nil {
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
users := make([]database.User, 0, len(q.users))
|
||||
|
||||
for _, user := range q.users {
|
||||
|
@ -892,6 +900,14 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
|||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
if prepared != nil {
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
workspaces := make([]database.Workspace, 0)
|
||||
for _, workspace := range q.workspaces {
|
||||
if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID {
|
||||
|
@ -1230,6 +1246,23 @@ func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg databa
|
|||
return database.Workspace{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) {
|
||||
if err := validateDatabaseType(workspaceAppID); err != nil {
|
||||
return database.Workspace{}, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, workspaceApp := range q.workspaceApps {
|
||||
workspaceApp := workspaceApp
|
||||
if workspaceApp.ID == workspaceAppID {
|
||||
return q.GetWorkspaceByAgentID(context.Background(), workspaceApp.AgentID)
|
||||
}
|
||||
}
|
||||
return database.Workspace{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
@ -1646,6 +1679,14 @@ func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.G
|
|||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
if prepared != nil {
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var templates []database.Template
|
||||
for _, template := range q.templates {
|
||||
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
|
||||
|
@ -3819,6 +3860,18 @@ func (q *fakeQuerier) InsertLicense(
|
|||
return l, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetLicenseByID(_ context.Context, id int32) (database.License, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, license := range q.licenses {
|
||||
if license.ID == id {
|
||||
return license, nil
|
||||
}
|
||||
}
|
||||
return database.License{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
|
|
@ -66,7 +66,7 @@ func Template(t testing.TB, db database.Store, seed database.Template) database.
|
|||
UserACL: seed.UserACL,
|
||||
GroupACL: seed.GroupACL,
|
||||
DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)),
|
||||
AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true),
|
||||
AllowUserCancelWorkspaceJobs: seed.AllowUserCancelWorkspaceJobs,
|
||||
})
|
||||
require.NoError(t, err, "insert template")
|
||||
return template
|
||||
|
@ -369,11 +369,8 @@ func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) dat
|
|||
|
||||
func TemplateVersion(t testing.TB, db database.Store, orig database.TemplateVersion) database.TemplateVersion {
|
||||
version, err := db.InsertTemplateVersion(context.Background(), database.InsertTemplateVersionParams{
|
||||
ID: takeFirst(orig.ID, uuid.New()),
|
||||
TemplateID: uuid.NullUUID{
|
||||
UUID: takeFirst(orig.TemplateID.UUID, uuid.New()),
|
||||
Valid: takeFirst(orig.TemplateID.Valid, true),
|
||||
},
|
||||
ID: takeFirst(orig.ID, uuid.New()),
|
||||
TemplateID: orig.TemplateID,
|
||||
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
|
||||
CreatedAt: takeFirst(orig.CreatedAt, database.Now()),
|
||||
UpdatedAt: takeFirst(orig.UpdatedAt, database.Now()),
|
||||
|
|
|
@ -68,7 +68,7 @@ func TestGenerator(t *testing.T) {
|
|||
require.Equal(t, exp, must(db.GetWorkspaceAppsByAgentID(context.Background(), exp.AgentID))[0])
|
||||
})
|
||||
|
||||
t.Run("WorkspaceResourceMetadatum", func(t *testing.T) {
|
||||
t.Run("WorkspaceResourceMetadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := dbfake.New()
|
||||
exp := dbgen.WorkspaceResourceMetadatums(t, db, database.WorkspaceResourceMetadatum{})
|
||||
|
|
|
@ -2,6 +2,7 @@ package database
|
|||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
)
|
||||
|
@ -63,6 +64,11 @@ func (TemplateVersion) RBACObject(template Template) rbac.Object {
|
|||
return template.RBACObject()
|
||||
}
|
||||
|
||||
// RBACObjectNoTemplate is for orphaned template versions.
|
||||
func (v TemplateVersion) RBACObjectNoTemplate() rbac.Object {
|
||||
return rbac.ResourceTemplate.InOrg(v.OrganizationID)
|
||||
}
|
||||
|
||||
func (g Group) RBACObject() rbac.Object {
|
||||
return rbac.ResourceGroup.WithID(g.ID).
|
||||
InOrg(g.OrganizationID)
|
||||
|
@ -94,6 +100,13 @@ func (m OrganizationMember) RBACObject() rbac.Object {
|
|||
InOrg(m.OrganizationID)
|
||||
}
|
||||
|
||||
func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object {
|
||||
// TODO: This feels incorrect as we are really returning a list of orgmembers.
|
||||
// This return type should be refactored to return a list of orgmembers, not this
|
||||
// special type.
|
||||
return rbac.ResourceUser.WithID(m.UserID)
|
||||
}
|
||||
|
||||
func (o Organization) RBACObject() rbac.Object {
|
||||
return rbac.ResourceOrganization.
|
||||
WithID(o.ID).
|
||||
|
@ -118,11 +131,29 @@ func (u User) RBACObject() rbac.Object {
|
|||
}
|
||||
|
||||
func (u User) UserDataRBACObject() rbac.Object {
|
||||
return rbac.ResourceUser.WithID(u.ID).WithOwner(u.ID.String())
|
||||
return rbac.ResourceUserData.WithID(u.ID).WithOwner(u.ID.String())
|
||||
}
|
||||
|
||||
func (License) RBACObject() rbac.Object {
|
||||
return rbac.ResourceLicense
|
||||
func (u GetUsersRow) RBACObject() rbac.Object {
|
||||
return rbac.ResourceUser.WithID(u.ID)
|
||||
}
|
||||
|
||||
func (u GitSSHKey) RBACObject() rbac.Object {
|
||||
return rbac.ResourceUserData.WithID(u.UserID).WithOwner(u.UserID.String())
|
||||
}
|
||||
|
||||
func (u GitAuthLink) RBACObject() rbac.Object {
|
||||
// I assume UserData is ok?
|
||||
return rbac.ResourceUserData.WithID(u.UserID).WithOwner(u.UserID.String())
|
||||
}
|
||||
|
||||
func (u UserLink) RBACObject() rbac.Object {
|
||||
// I assume UserData is ok?
|
||||
return rbac.ResourceUserData.WithOwner(u.UserID.String()).WithID(u.UserID)
|
||||
}
|
||||
|
||||
func (l License) RBACObject() rbac.Object {
|
||||
return rbac.ResourceLicense.WithIDString(strconv.FormatInt(int64(l.ID), 10))
|
||||
}
|
||||
|
||||
func ConvertUserRows(rows []GetUsersRow) []User {
|
||||
|
|
|
@ -56,6 +56,7 @@ type sqlcQuerier interface {
|
|||
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
|
||||
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
|
||||
GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error)
|
||||
GetLicenseByID(ctx context.Context, id int32) (License, error)
|
||||
GetLicenses(ctx context.Context) ([]License, error)
|
||||
GetLogoURL(ctx context.Context) (string, error)
|
||||
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
|
||||
|
@ -121,6 +122,7 @@ type sqlcQuerier interface {
|
|||
GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (Workspace, error)
|
||||
GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Workspace, error)
|
||||
GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWorkspaceByOwnerIDAndNameParams) (Workspace, error)
|
||||
GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (Workspace, error)
|
||||
GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (WorkspaceResource, error)
|
||||
GetWorkspaceResourceMetadataByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceResourceMetadatum, error)
|
||||
GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceResourceMetadatum, error)
|
||||
|
|
|
@ -1343,6 +1343,30 @@ func (q *sqlQuerier) DeleteLicense(ctx context.Context, id int32) (int32, error)
|
|||
return id, err
|
||||
}
|
||||
|
||||
const getLicenseByID = `-- name: GetLicenseByID :one
|
||||
SELECT
|
||||
id, uploaded_at, jwt, exp, uuid
|
||||
FROM
|
||||
licenses
|
||||
WHERE
|
||||
id = $1
|
||||
LIMIT
|
||||
1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetLicenseByID(ctx context.Context, id int32) (License, error) {
|
||||
row := q.db.QueryRowContext(ctx, getLicenseByID, id)
|
||||
var i License
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.UploadedAt,
|
||||
&i.JWT,
|
||||
&i.Exp,
|
||||
&i.UUID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getLicenses = `-- name: GetLicenses :many
|
||||
SELECT id, uploaded_at, jwt, exp, uuid
|
||||
FROM licenses
|
||||
|
@ -6513,6 +6537,62 @@ func (q *sqlQuerier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWo
|
|||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceByWorkspaceAppID = `-- name: GetWorkspaceByWorkspaceAppID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at
|
||||
FROM
|
||||
workspaces
|
||||
WHERE
|
||||
workspaces.id = (
|
||||
SELECT
|
||||
workspace_id
|
||||
FROM
|
||||
workspace_builds
|
||||
WHERE
|
||||
workspace_builds.job_id = (
|
||||
SELECT
|
||||
job_id
|
||||
FROM
|
||||
workspace_resources
|
||||
WHERE
|
||||
workspace_resources.id = (
|
||||
SELECT
|
||||
resource_id
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
workspace_agents.id = (
|
||||
SELECT
|
||||
agent_id
|
||||
FROM
|
||||
workspace_apps
|
||||
WHERE
|
||||
workspace_apps.id = $1
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (Workspace, error) {
|
||||
row := q.db.QueryRowContext(ctx, getWorkspaceByWorkspaceAppID, workspaceAppID)
|
||||
var i Workspace
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.TemplateID,
|
||||
&i.Deleted,
|
||||
&i.Name,
|
||||
&i.AutostartSchedule,
|
||||
&i.Ttl,
|
||||
&i.LastUsedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaces = `-- name: GetWorkspaces :many
|
||||
SELECT
|
||||
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, COUNT(*) OVER () as count
|
||||
|
|
|
@ -14,6 +14,16 @@ SELECT *
|
|||
FROM licenses
|
||||
ORDER BY (id);
|
||||
|
||||
-- name: GetLicenseByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
licenses
|
||||
WHERE
|
||||
id = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetUnexpiredLicenses :many
|
||||
SELECT *
|
||||
FROM licenses
|
||||
|
|
|
@ -8,6 +8,42 @@ WHERE
|
|||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetWorkspaceByWorkspaceAppID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspaces
|
||||
WHERE
|
||||
workspaces.id = (
|
||||
SELECT
|
||||
workspace_id
|
||||
FROM
|
||||
workspace_builds
|
||||
WHERE
|
||||
workspace_builds.job_id = (
|
||||
SELECT
|
||||
job_id
|
||||
FROM
|
||||
workspace_resources
|
||||
WHERE
|
||||
workspace_resources.id = (
|
||||
SELECT
|
||||
resource_id
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
workspace_agents.id = (
|
||||
SELECT
|
||||
agent_id
|
||||
FROM
|
||||
workspace_apps
|
||||
WHERE
|
||||
workspace_apps.id = @workspace_app_id
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
-- name: GetWorkspaceByAgentID :one
|
||||
SELECT
|
||||
*
|
||||
|
|
|
@ -76,7 +76,14 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) {
|
|||
ID: file.ID,
|
||||
})
|
||||
return
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error getting file.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
id := uuid.New()
|
||||
file, err = api.Database.InsertFile(ctx, database.InsertFileParams{
|
||||
ID: id,
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
@ -159,7 +160,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
key, err := cfg.DB.GetAPIKeyByID(r.Context(), keyID)
|
||||
//nolint:gocritic // System needs to fetch API key to check if it's valid.
|
||||
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystem(ctx), keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
|
@ -192,7 +194,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
|||
changed = false
|
||||
)
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
link, err = cfg.DB.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystem(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
|
@ -275,7 +278,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
|||
}
|
||||
}
|
||||
if changed {
|
||||
err := cfg.DB.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
|
||||
//nolint:gocritic // System needs to update API Key LastUsed
|
||||
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystem(ctx), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
|
@ -291,7 +295,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
|||
// If the API Key is associated with a user_link (e.g. Github/OIDC)
|
||||
// then we want to update the relevant oauth fields.
|
||||
if link.UserID != uuid.Nil {
|
||||
link, err = cfg.DB.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
|
||||
// nolint:gocritic
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystem(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
|
@ -310,7 +315,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
|||
// We only want to update this occasionally to reduce DB write
|
||||
// load. We update alongside the UserLink and APIKey since it's
|
||||
// easier on the DB to colocate writes.
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(ctx, database.UpdateUserLastSeenAtParams{
|
||||
// nolint:gocritic
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystem(ctx), database.UpdateUserLastSeenAtParams{
|
||||
ID: key.UserID,
|
||||
LastSeenAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
|
@ -327,7 +333,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
|||
// If the key is valid, we also fetch the user roles and status.
|
||||
// The roles are used for RBAC authorize checks, and the status
|
||||
// is to block 'suspended' users from accessing the platform.
|
||||
roles, err := cfg.DB.GetAuthorizationUserRoles(r.Context(), key.UserID)
|
||||
// nolint:gocritic
|
||||
roles, err := cfg.DB.GetAuthorizationUserRoles(dbauthz.AsSystem(ctx), key.UserID)
|
||||
if err != nil {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
|
@ -343,16 +350,20 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
// Actor is the user's authorization context.
|
||||
actor := rbac.Subject{
|
||||
ID: key.UserID.String(),
|
||||
Roles: rbac.RoleNames(roles.Roles),
|
||||
Groups: roles.Groups,
|
||||
Scope: rbac.ScopeName(key.Scope),
|
||||
}
|
||||
ctx = context.WithValue(ctx, apiKeyContextKey{}, key)
|
||||
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
|
||||
Username: roles.Username,
|
||||
Actor: rbac.Subject{
|
||||
ID: key.UserID.String(),
|
||||
Roles: rbac.RoleNames(roles.Roles),
|
||||
Groups: roles.Groups,
|
||||
Scope: rbac.ScopeName(key.Scope),
|
||||
},
|
||||
Actor: actor,
|
||||
})
|
||||
// Set the auth context for the authzquerier as well.
|
||||
ctx = dbauthz.As(ctx, actor)
|
||||
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
package httpmw
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// AsAuthzSystem is a chained handler that temporarily sets the dbauthz context
|
||||
// to System for the inner handlers, and resets the context afterwards.
|
||||
//
|
||||
// TODO: Refactor the middleware functions to not require this.
|
||||
// This is a bit of a kludge for now as some middleware functions require
|
||||
// usage as a system user in some cases, but not all cases. To avoid large
|
||||
// refactors, we use this middleware to temporarily set the context to a system.
|
||||
func AsAuthzSystem(mws ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
|
||||
chain := chi.Chain(mws...)
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
before, beforeExists := dbauthz.ActorFromContext(r.Context())
|
||||
if !beforeExists {
|
||||
// AsRemoveActor will actually remove the actor from the context.
|
||||
before = dbauthz.AsRemoveActor
|
||||
}
|
||||
|
||||
// nolint:gocritic // AsAuthzSystem needs to do this.
|
||||
r = r.WithContext(dbauthz.AsSystem(ctx))
|
||||
chain.Handler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
r = r.WithContext(dbauthz.As(r.Context(), before))
|
||||
next.ServeHTTP(rw, r)
|
||||
})).ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package httpmw_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
func TestAsAuthzSystem(t *testing.T) {
|
||||
t.Parallel()
|
||||
userActor := coderdtest.RandomRBACSubject()
|
||||
|
||||
base := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
actor, ok := dbauthz.ActorFromContext(r.Context())
|
||||
assert.True(t, ok, "actor should exist")
|
||||
assert.True(t, userActor.Equal(actor), "actor should be the user actor")
|
||||
})
|
||||
|
||||
mwSetUser := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
r = r.WithContext(dbauthz.As(r.Context(), userActor))
|
||||
next.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
|
||||
mwAssertSystem := mwAssert(func(req *http.Request) {
|
||||
actor, ok := dbauthz.ActorFromContext(req.Context())
|
||||
assert.True(t, ok, "actor should exist")
|
||||
assert.False(t, userActor.Equal(actor), "systemActor should not be the user actor")
|
||||
assert.Contains(t, actor.Roles.Names(), "system", "should have system role")
|
||||
})
|
||||
|
||||
mwAssertUser := mwAssert(func(req *http.Request) {
|
||||
actor, ok := dbauthz.ActorFromContext(req.Context())
|
||||
assert.True(t, ok, "actor should exist")
|
||||
assert.True(t, userActor.Equal(actor), "should be the useractor")
|
||||
})
|
||||
|
||||
mwAssertNoUser := mwAssert(func(req *http.Request) {
|
||||
_, ok := dbauthz.ActorFromContext(req.Context())
|
||||
assert.False(t, ok, "actor should not exist")
|
||||
})
|
||||
|
||||
// Request as the user actor
|
||||
const pattern = "/"
|
||||
req := httptest.NewRequest("GET", pattern, nil)
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler := chi.NewRouter()
|
||||
handler.Route(pattern, func(r chi.Router) {
|
||||
r.Use(
|
||||
// First assert there is no actor context
|
||||
mwAssertNoUser,
|
||||
httpmw.AsAuthzSystem(
|
||||
// Assert the system actor
|
||||
mwAssertSystem,
|
||||
mwAssertSystem,
|
||||
),
|
||||
// Assert no user present outside of the AsAuthzSystem chain
|
||||
mwAssertNoUser,
|
||||
// ----
|
||||
// Set to the user actor
|
||||
mwSetUser,
|
||||
// Assert the user actor
|
||||
mwAssertUser,
|
||||
httpmw.AsAuthzSystem(
|
||||
// Assert the system actor
|
||||
mwAssertSystem,
|
||||
mwAssertSystem,
|
||||
),
|
||||
// Check the user actor was returned to the context
|
||||
mwAssertUser,
|
||||
)
|
||||
r.Handle("/", base)
|
||||
r.NotFound(func(writer http.ResponseWriter, request *http.Request) {
|
||||
assert.Fail(t, "should not hit not found, the route should be correct")
|
||||
})
|
||||
})
|
||||
|
||||
handler.ServeHTTP(res, req)
|
||||
}
|
||||
|
||||
func mwAssert(assertF func(req *http.Request)) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
assertF(r)
|
||||
next.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
@ -68,7 +69,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
|
|||
})
|
||||
return
|
||||
}
|
||||
user, err = db.GetUserByID(ctx, apiKey.UserID)
|
||||
//nolint:gocritic // System needs to be able to get user from param.
|
||||
user, err = db.GetUserByID(dbauthz.AsSystem(ctx), apiKey.UserID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
|
@ -81,8 +83,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
|
|||
return
|
||||
}
|
||||
} else if userID, err := uuid.Parse(userQuery); err == nil {
|
||||
// If the userQuery is a valid uuid
|
||||
user, err = db.GetUserByID(ctx, userID)
|
||||
//nolint:gocritic // If the userQuery is a valid uuid
|
||||
user, err = db.GetUserByID(dbauthz.AsSystem(ctx), userID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: userErrorMessage,
|
||||
|
@ -90,8 +92,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
|
|||
return
|
||||
}
|
||||
} else {
|
||||
// Try as a username last
|
||||
user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
|
||||
// nolint:gocritic // Try as a username last
|
||||
user, err = db.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
|
||||
Username: userQuery,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -10,7 +10,9 @@ import (
|
|||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
|
@ -45,7 +47,8 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
|
|||
})
|
||||
return
|
||||
}
|
||||
agent, err := db.GetWorkspaceAgentByAuthToken(ctx, token)
|
||||
//nolint:gocritic // System needs to be able to get workspace agents.
|
||||
agent, err := db.GetWorkspaceAgentByAuthToken(dbauthz.AsSystem(ctx), token)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
|
||||
|
@ -62,8 +65,50 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic // System needs to be able to get workspace agents.
|
||||
subject, err := getAgentSubject(dbauthz.AsSystem(ctx), db, agent)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace agent.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, workspaceAgentContextKey{}, agent)
|
||||
// Also set the dbauthz actor for the request.
|
||||
ctx = dbauthz.As(ctx, subject)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getAgentSubject(ctx context.Context, db database.Store, agent database.WorkspaceAgent) (rbac.Subject, error) {
|
||||
// TODO: make a different query that gets the workspace owner and roles along with the agent.
|
||||
workspace, err := db.GetWorkspaceByAgentID(ctx, agent.ID)
|
||||
if err != nil {
|
||||
return rbac.Subject{}, err
|
||||
}
|
||||
|
||||
user, err := db.GetUserByID(ctx, workspace.OwnerID)
|
||||
if err != nil {
|
||||
return rbac.Subject{}, err
|
||||
}
|
||||
|
||||
roles, err := db.GetAuthorizationUserRoles(ctx, user.ID)
|
||||
if err != nil {
|
||||
return rbac.Subject{}, err
|
||||
}
|
||||
|
||||
// A user that creates a workspace can use this agent auth token and
|
||||
// impersonate the workspace. So to prevent privilege escalation, the
|
||||
// subject inherits the roles of the user that owns the workspace.
|
||||
// We then add a workspace-agent scope to limit the permissions
|
||||
// to only what the workspace agent needs.
|
||||
return rbac.Subject{
|
||||
ID: user.ID.String(),
|
||||
Roles: rbac.RoleNames(roles.Roles),
|
||||
Groups: roles.Groups,
|
||||
Scope: rbac.WorkspaceAgentScope(workspace.ID, user.ID),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -19,11 +19,10 @@ import (
|
|||
func TestWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setup := func(db database.Store) (*http.Request, uuid.UUID) {
|
||||
token := uuid.New()
|
||||
setup := func(db database.Store, token uuid.UUID) *http.Request {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.Header.Set(codersdk.SessionTokenHeader, token.String())
|
||||
return r, token
|
||||
return r
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
|
@ -34,7 +33,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
|||
httpmw.ExtractWorkspaceAgent(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setup(db)
|
||||
r := setup(db, uuid.New())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
|
@ -46,6 +45,24 @@ func TestWorkspaceAgent(t *testing.T) {
|
|||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := dbfake.New()
|
||||
var (
|
||||
user = dbgen.User(t, db, database.User{})
|
||||
workspace = dbgen.Workspace(t, db, database.Workspace{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{})
|
||||
resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
JobID: job.ID,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: resource.ID,
|
||||
})
|
||||
)
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractWorkspaceAgent(db),
|
||||
|
@ -54,10 +71,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
|||
_ = httpmw.WorkspaceAgent(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
r, token := setup(db)
|
||||
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
AuthToken: token,
|
||||
})
|
||||
r := setup(db, agent.AuthToken)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
|
|
|
@ -55,20 +55,20 @@ func (api *API) putMemberRoles(rw http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Assigning a role requires the create permission.
|
||||
if len(added) > 0 && !api.Authorize(r, rbac.ActionCreate, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) {
|
||||
httpapi.Forbidden(rw)
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
// Removing a role requires the delete permission.
|
||||
if len(removed) > 0 && !api.Authorize(r, rbac.ActionDelete, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) {
|
||||
httpapi.Forbidden(rw)
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
// Just treat adding & removing as "assigning" for now.
|
||||
for _, roleName := range append(added, removed...) {
|
||||
if !rbac.CanAssignRole(actorRoles.Actor.Roles, roleName) {
|
||||
httpapi.Forbidden(rw)
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
@ -142,6 +143,8 @@ func countUniqueUsers(rows []database.GetTemplateDAUsRow) int {
|
|||
}
|
||||
|
||||
func (c *Cache) refresh(ctx context.Context) error {
|
||||
//nolint:gocritic // This is a system service.
|
||||
ctx = dbauthz.AsSystem(ctx)
|
||||
err := c.database.DeleteOldAgentStats(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete old stats: %w", err)
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/parameter"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
@ -56,6 +57,8 @@ type Server struct {
|
|||
|
||||
// AcquireJob queries the database to lock a job.
|
||||
func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
//nolint:gocritic //TODO: make a provisionerd role
|
||||
ctx = dbauthz.AsSystem(ctx)
|
||||
// This prevents loads of provisioner daemons from consistently
|
||||
// querying the database when no jobs are available.
|
||||
//
|
||||
|
@ -270,6 +273,8 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
}
|
||||
|
||||
func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
|
||||
//nolint:gocritic //TODO: make a provisionerd role
|
||||
ctx = dbauthz.AsSystem(ctx)
|
||||
jobID, err := uuid.Parse(request.JobId)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse job id: %w", err)
|
||||
|
@ -299,6 +304,8 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot
|
|||
}
|
||||
|
||||
func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
|
||||
//nolint:gocritic //TODO: make a provisionerd role
|
||||
ctx = dbauthz.AsSystem(ctx)
|
||||
parsedID, err := uuid.Parse(request.JobId)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse job id: %w", err)
|
||||
|
@ -345,7 +352,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
|
|||
slog.F("stage", log.Stage),
|
||||
slog.F("output", log.Output))
|
||||
}
|
||||
logs, err := server.Database.InsertProvisionerJobLogs(context.Background(), insertParams)
|
||||
//nolint:gocritic //TODO: make a provisionerd role
|
||||
logs, err := server.Database.InsertProvisionerJobLogs(dbauthz.AsSystem(context.Background()), insertParams)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("insert job logs: %w", err)
|
||||
|
@ -470,6 +478,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
|
|||
}
|
||||
|
||||
func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) {
|
||||
//nolint:gocritic // TODO: make a provisionerd role
|
||||
ctx = dbauthz.AsSystem(ctx)
|
||||
jobID, err := uuid.Parse(failJob.JobId)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse job id: %w", err)
|
||||
|
@ -596,6 +606,8 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
|
||||
// CompleteJob is triggered by a provision daemon to mark a provisioner job as completed.
|
||||
func (server *Server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
|
||||
//nolint:gocritic // TODO: make a provisionerd role
|
||||
ctx = dbauthz.AsSystem(ctx)
|
||||
jobID, err := uuid.Parse(completed.JobId)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse job id: %w", err)
|
||||
|
|
|
@ -16,9 +16,10 @@ import (
|
|||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
|
@ -32,6 +33,7 @@ import (
|
|||
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
actor, _ = dbauthz.ActorFromContext(ctx)
|
||||
logger = api.Logger.With(slog.F("job_id", job.ID))
|
||||
follow = r.URL.Query().Has("follow")
|
||||
afterRaw = r.URL.Query().Get("after")
|
||||
|
@ -49,7 +51,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
|
|||
// of processed IDs.
|
||||
var bufferedLogs <-chan database.ProvisionerJobLog
|
||||
if follow {
|
||||
bl, closeFollow, err := api.followLogs(job.ID)
|
||||
bl, closeFollow, err := api.followLogs(actor, job.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error watching provisioner logs.",
|
||||
|
@ -367,7 +369,7 @@ type provisionerJobLogsMessage struct {
|
|||
EndOfLogs bool `json:"end_of_logs,omitempty"`
|
||||
}
|
||||
|
||||
func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
|
||||
func (api *API) followLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
|
||||
logger := api.Logger.With(slog.F("job_id", jobID))
|
||||
|
||||
var (
|
||||
|
@ -392,7 +394,7 @@ func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog,
|
|||
}
|
||||
|
||||
if jlMsg.CreatedAfter != 0 {
|
||||
logs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{
|
||||
logs, err := api.Database.GetProvisionerLogsByIDBetween(dbauthz.As(ctx, actor), database.GetProvisionerLogsByIDBetweenParams{
|
||||
JobID: jobID,
|
||||
CreatedAfter: jlMsg.CreatedAfter,
|
||||
})
|
||||
|
|
|
@ -1039,7 +1039,6 @@ func testAuthorize(t *testing.T, name string, subject Subject, sets ...[]authTes
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func must[T any](value T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
|
@ -133,6 +133,8 @@ var (
|
|||
ResourceWorkspace.Type: {ActionRead},
|
||||
// CRUD to provisioner daemons for now.
|
||||
ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
|
||||
// Needs to read all organizations since
|
||||
ResourceOrganization.Type: {ActionRead},
|
||||
}),
|
||||
Org: map[string][]Permission{},
|
||||
User: []Permission{},
|
||||
|
@ -217,6 +219,12 @@ var (
|
|||
// The first key is the actor role, the second is the roles they can assign.
|
||||
// map[actor_role][assign_role]<can_assign>
|
||||
assignRoles = map[string]map[string]bool{
|
||||
"system": {
|
||||
owner: true,
|
||||
member: true,
|
||||
orgAdmin: true,
|
||||
orgMember: true,
|
||||
},
|
||||
owner: {
|
||||
owner: true,
|
||||
auditor: true,
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
// BenchmarkRBACValueAllocation benchmarks the cost of allocating a rego input
|
||||
// value. By default, `ast.InterfaceToValue` is used to convert the input,
|
||||
// which uses json marshalling under the hood.
|
||||
// which uses json marshaling under the hood.
|
||||
//
|
||||
// Currently ast.Object.insert() is the slowest part of the process and allocates
|
||||
// the most amount of bytes. This general approach copies all of our struct
|
||||
|
|
|
@ -19,6 +19,7 @@ type authSubject struct {
|
|||
Actor rbac.Subject
|
||||
}
|
||||
|
||||
// TODO: add the SYSTEM to the MATRIX
|
||||
func TestRolePermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -183,8 +184,8 @@ func TestRolePermissions(t *testing.T) {
|
|||
Actions: []rbac.Action{rbac.ActionRead},
|
||||
Resource: rbac.ResourceOrganization.WithID(orgID).InOrg(orgID),
|
||||
AuthorizeMap: map[bool][]authSubject{
|
||||
true: {owner, orgAdmin, orgMemberMe},
|
||||
false: {otherOrgAdmin, otherOrgMember, memberMe, templateAdmin, userAdmin},
|
||||
true: {owner, orgAdmin, orgMemberMe, templateAdmin},
|
||||
false: {otherOrgAdmin, otherOrgMember, memberMe, userAdmin},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
package rbac
|
||||
|
||||
import "github.com/open-policy-agent/opa/rego"
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
)
|
||||
|
||||
const (
|
||||
// errUnauthorized is the error message that should be returned to
|
||||
|
@ -24,6 +28,12 @@ type UnauthorizedError struct {
|
|||
output rego.ResultSet
|
||||
}
|
||||
|
||||
// IsUnauthorizedError is a convenience function to check if err is UnauthorizedError.
|
||||
// It is equivalent to errors.As(err, &UnauthorizedError{}).
|
||||
func IsUnauthorizedError(err error) bool {
|
||||
return errors.As(err, &UnauthorizedError{})
|
||||
}
|
||||
|
||||
// ForbiddenWithInternal creates a new error that will return a simple
|
||||
// "forbidden" to the client, logging internally the more detailed message
|
||||
// provided.
|
||||
|
@ -37,6 +47,10 @@ func ForbiddenWithInternal(internal error, subject Subject, action Action, objec
|
|||
}
|
||||
}
|
||||
|
||||
func (e UnauthorizedError) Unwrap() error {
|
||||
return e.internal
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (UnauthorizedError) Error() string {
|
||||
return errUnauthorized
|
||||
|
@ -47,6 +61,10 @@ func (e *UnauthorizedError) Internal() error {
|
|||
return e.internal
|
||||
}
|
||||
|
||||
func (e *UnauthorizedError) SetInternal(err error) {
|
||||
e.internal = err
|
||||
}
|
||||
|
||||
func (e *UnauthorizedError) Input() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"subject": e.subject,
|
||||
|
@ -59,3 +77,11 @@ func (e *UnauthorizedError) Input() map[string]interface{} {
|
|||
func (e *UnauthorizedError) Output() rego.ResultSet {
|
||||
return e.output
|
||||
}
|
||||
|
||||
// As implements the errors.As interface.
|
||||
func (*UnauthorizedError) As(target interface{}) bool {
|
||||
if _, ok := target.(*UnauthorizedError); ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
package rbac_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func TestIsUnauthorizedError(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NotWrapped", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
errFunc := func() error {
|
||||
return rbac.UnauthorizedError{}
|
||||
}
|
||||
|
||||
err := errFunc()
|
||||
require.True(t, rbac.IsUnauthorizedError(err))
|
||||
})
|
||||
|
||||
t.Run("Wrapped", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
errFunc := func() error {
|
||||
return xerrors.Errorf("test error: %w", rbac.UnauthorizedError{})
|
||||
}
|
||||
err := errFunc()
|
||||
require.True(t, rbac.IsUnauthorizedError(err))
|
||||
})
|
||||
}
|
|
@ -3,6 +3,8 @@ package rbac
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
|
@ -41,6 +43,29 @@ func (s Scope) Name() string {
|
|||
return s.Role.Name
|
||||
}
|
||||
|
||||
// WorkspaceAgentScope returns a scope that is the same as ScopeAll but can only
|
||||
// affect resources in the allow list. Only a scope is returned as the roles
|
||||
// should come from the workspace owner.
|
||||
func WorkspaceAgentScope(workspaceID, ownerID uuid.UUID) Scope {
|
||||
allScope, err := ScopeAll.Expand()
|
||||
if err != nil {
|
||||
panic("failed to expand scope all, this should never happen")
|
||||
}
|
||||
return Scope{
|
||||
// TODO: We want to limit the role too to be extra safe.
|
||||
// Even though the allowlist blocks anything else, it is still good
|
||||
// incase we change the behavior of the allowlist. The allowlist is new
|
||||
// and evolving.
|
||||
Role: allScope.Role,
|
||||
// This prevents the agent from being able to access any other resource.
|
||||
AllowIDList: []string{
|
||||
workspaceID.String(),
|
||||
ownerID.String(),
|
||||
// TODO: Might want to include the template the workspace uses too?
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ScopeAll ScopeName = "all"
|
||||
ScopeApplicationConnect ScopeName = "application_connect"
|
||||
|
|
|
@ -47,7 +47,7 @@ func (api *API) assignableOrgRoles(rw http.ResponseWriter, r *http.Request) {
|
|||
actorRoles := httpmw.UserAuthorization(r)
|
||||
|
||||
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) {
|
||||
httpapi.Forbidden(rw)
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ func TestListRoles(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err, "create org")
|
||||
|
||||
const forbidden = "Forbidden"
|
||||
const notFound = "Resource not found"
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Client *codersdk.Client
|
||||
|
@ -66,7 +66,7 @@ func TestListRoles(t *testing.T) {
|
|||
APICall: func(ctx context.Context) ([]codersdk.AssignableRoles, error) {
|
||||
return member.ListOrganizationRoles(ctx, otherOrg.ID)
|
||||
},
|
||||
AuthorizedError: forbidden,
|
||||
AuthorizedError: notFound,
|
||||
},
|
||||
// Org admin
|
||||
{
|
||||
|
@ -95,7 +95,7 @@ func TestListRoles(t *testing.T) {
|
|||
APICall: func(ctx context.Context) ([]codersdk.AssignableRoles, error) {
|
||||
return orgAdmin.ListOrganizationRoles(ctx, otherOrg.ID)
|
||||
},
|
||||
AuthorizedError: forbidden,
|
||||
AuthorizedError: notFound,
|
||||
},
|
||||
// Admin
|
||||
{
|
||||
|
@ -133,7 +133,7 @@ func TestListRoles(t *testing.T) {
|
|||
if c.AuthorizedError != "" {
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
|
||||
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, c.AuthorizedError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -82,6 +82,10 @@ func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// TODO: This just returns the workspaces a user can view. We should use
|
||||
// a system function to get all workspaces that use this template.
|
||||
// This data should never be exposed to the user aside from a non-zero count.
|
||||
// Or we move this into a postgres constraint.
|
||||
workspaces, err := api.Database.GetWorkspaces(ctx, database.GetWorkspacesParams{
|
||||
TemplateIds: []uuid.UUID{template.ID},
|
||||
})
|
||||
|
|
|
@ -18,9 +18,9 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
|
@ -57,7 +57,8 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
user, err := api.Database.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
|
||||
//nolint:gocritic // In order to login, we need to get the user first!
|
||||
user, err := api.Database.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
|
||||
Email: loginWithPassword.Email,
|
||||
})
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -111,15 +112,32 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic // System needs to fetch user roles in order to login user.
|
||||
roles, err := api.Database.GetAuthorizationUserRoles(dbauthz.AsSystem(ctx), user.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// If the user logged into a suspended account, reject the login request.
|
||||
if user.Status != database.UserStatusActive {
|
||||
if roles.Status != database.UserStatusActive {
|
||||
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: "Your account is suspended. Contact an admin to reactivate your account.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cookie, key, err := api.createAPIKey(ctx, createAPIKeyParams{
|
||||
userSubj := rbac.Subject{
|
||||
ID: user.ID.String(),
|
||||
Roles: rbac.RoleNames(roles.Roles),
|
||||
Groups: roles.Groups,
|
||||
Scope: rbac.ScopeAll,
|
||||
}
|
||||
|
||||
//nolint:gocritic // Creating the API key as the user instead of as system.
|
||||
cookie, key, err := api.createAPIKey(dbauthz.As(ctx, userSubj), createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
|
@ -765,7 +783,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
// with OIDC for the first time.
|
||||
if user.ID == uuid.Nil {
|
||||
var organizationID uuid.UUID
|
||||
organizations, _ := tx.GetOrganizations(ctx)
|
||||
//nolint:gocritic
|
||||
organizations, _ := tx.GetOrganizations(dbauthz.AsSystem(ctx))
|
||||
if len(organizations) > 0 {
|
||||
// Add the user to the first organization. Once multi-organization
|
||||
// support is added, we should enable a configuration map of user
|
||||
|
@ -773,7 +792,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
organizationID = organizations[0].ID
|
||||
}
|
||||
|
||||
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
|
||||
//nolint:gocritic
|
||||
_, err := tx.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
|
||||
Username: params.Username,
|
||||
})
|
||||
if err == nil {
|
||||
|
@ -786,7 +806,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
|
||||
params.Username = httpapi.UsernameFrom(alternate)
|
||||
|
||||
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
|
||||
//nolint:gocritic
|
||||
_, err := tx.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
|
||||
Username: params.Username,
|
||||
})
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -805,7 +826,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
}
|
||||
}
|
||||
|
||||
user, _, err = api.CreateUser(ctx, tx, CreateUserRequest{
|
||||
//nolint:gocritic
|
||||
user, _, err = api.CreateUser(dbauthz.AsSystem(ctx), tx, CreateUserRequest{
|
||||
CreateUserRequest: codersdk.CreateUserRequest{
|
||||
Email: params.Email,
|
||||
Username: params.Username,
|
||||
|
@ -819,7 +841,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
}
|
||||
|
||||
if link.UserID == uuid.Nil {
|
||||
link, err = tx.InsertUserLink(ctx, database.InsertUserLinkParams{
|
||||
//nolint:gocritic
|
||||
link, err = tx.InsertUserLink(dbauthz.AsSystem(ctx), database.InsertUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: params.LoginType,
|
||||
LinkedID: params.LinkedID,
|
||||
|
@ -833,7 +856,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
}
|
||||
|
||||
if link.UserID != uuid.Nil {
|
||||
link, err = tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
|
||||
//nolint:gocritic
|
||||
link, err = tx.UpdateUserLink(dbauthz.AsSystem(ctx), database.UpdateUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: params.LoginType,
|
||||
OAuthAccessToken: params.State.Token.AccessToken,
|
||||
|
@ -847,7 +871,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
|
||||
// Ensure groups are correct.
|
||||
if len(params.Groups) > 0 {
|
||||
err := api.Options.SetUserGroups(ctx, tx, user.ID, params.Groups)
|
||||
//nolint:gocritic
|
||||
err := api.Options.SetUserGroups(dbauthz.AsSystem(ctx), tx, user.ID, params.Groups)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set user groups: %w", err)
|
||||
}
|
||||
|
@ -880,7 +905,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
// In such cases in the current implementation this user can now no
|
||||
// longer sign in until an administrator finds the offending built-in
|
||||
// user and changes their username.
|
||||
user, err = tx.UpdateUserProfile(ctx, database.UpdateUserProfileParams{
|
||||
//nolint:gocritic
|
||||
user, err = tx.UpdateUserProfile(dbauthz.AsSystem(ctx), database.UpdateUserProfileParams{
|
||||
ID: user.ID,
|
||||
Email: user.Email,
|
||||
Username: user.Username,
|
||||
|
@ -898,7 +924,8 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
return nil, database.APIKey{}, xerrors.Errorf("in tx: %w", err)
|
||||
}
|
||||
|
||||
cookie, key, err := api.createAPIKey(ctx, createAPIKeyParams{
|
||||
//nolint:gocritic
|
||||
cookie, key, err := api.createAPIKey(dbauthz.AsSystem(ctx), createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: params.LoginType,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
|
@ -37,7 +38,8 @@ import (
|
|||
// @Router /users/first [get]
|
||||
func (api *API) firstUser(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userCount, err := api.Database.GetUserCount(ctx)
|
||||
//nolint:gocritic // needed for first user check
|
||||
userCount, err := api.Database.GetUserCount(dbauthz.AsSystem(ctx))
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching user count.",
|
||||
|
@ -70,6 +72,7 @@ func (api *API) firstUser(rw http.ResponseWriter, r *http.Request) {
|
|||
// @Success 201 {object} codersdk.CreateFirstUserResponse
|
||||
// @Router /users/first [post]
|
||||
func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Should this admin system context be in a middleware?
|
||||
ctx := r.Context()
|
||||
var createUser codersdk.CreateFirstUserRequest
|
||||
if !httpapi.Read(ctx, rw, r, &createUser) {
|
||||
|
@ -77,7 +80,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// This should only function for the first user.
|
||||
userCount, err := api.Database.GetUserCount(ctx)
|
||||
//nolint:gocritic // needed to create first user
|
||||
userCount, err := api.Database.GetUserCount(dbauthz.AsSystem(ctx))
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching user count.",
|
||||
|
@ -117,7 +121,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
user, organizationID, err := api.CreateUser(ctx, api.Database, CreateUserRequest{
|
||||
//nolint:gocritic // needed to create first user
|
||||
user, organizationID, err := api.CreateUser(dbauthz.AsSystem(ctx), api.Database, CreateUserRequest{
|
||||
CreateUserRequest: codersdk.CreateUserRequest{
|
||||
Email: createUser.Email,
|
||||
Username: createUser.Username,
|
||||
|
@ -146,7 +151,8 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
|||
// the user. Maybe I add this ability to grant roles in the createUser api
|
||||
// and add some rbac bypass when calling api functions this way??
|
||||
// Add the admin role to this first user.
|
||||
_, err = api.Database.UpdateUserRoles(ctx, database.UpdateUserRolesParams{
|
||||
//nolint:gocritic // needed to create first user
|
||||
_, err = api.Database.UpdateUserRoles(dbauthz.AsSystem(ctx), database.UpdateUserRolesParams{
|
||||
GrantedRoles: []string{rbac.RoleOwner()},
|
||||
ID: user.ID,
|
||||
})
|
||||
|
@ -987,7 +993,7 @@ func (api *API) organizationByUserAndName(rw http.ResponseWriter, r *http.Reques
|
|||
ctx := r.Context()
|
||||
organizationName := chi.URLParam(r, "organizationname")
|
||||
organization, err := api.Database.GetOrganizationByName(ctx, organizationName)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -854,7 +854,7 @@ func TestGrantSiteRoles(t *testing.T) {
|
|||
AssignToUser: randOrgUser.ID.String(),
|
||||
Roles: []string{rbac.RoleOrgMember(randOrg.ID)},
|
||||
Error: true,
|
||||
StatusCode: http.StatusForbidden,
|
||||
StatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
Name: "AdminUpdateOrgSelf",
|
||||
|
|
|
@ -1,10 +1,5 @@
|
|||
package slice
|
||||
|
||||
// New is a convenience method for creating []T.
|
||||
func New[T any](items ...T) []T {
|
||||
return items
|
||||
}
|
||||
|
||||
// SameElements returns true if the 2 lists have the same elements in any
|
||||
// order.
|
||||
func SameElements[T comparable](a []T, b []T) bool {
|
||||
|
@ -67,3 +62,8 @@ func OverlapCompare[T any](a []T, b []T, equal func(a, b T) bool) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// New is a convenience method for creating []T.
|
||||
func New[T any](items ...T) []T {
|
||||
return items
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
|
@ -625,14 +626,29 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
// inactive disconnect timeout we ensure that we don't block but
|
||||
// also guarantee that the agent will be considered disconnected
|
||||
// by normal status check.
|
||||
ctx, cancel := context.WithTimeout(api.ctx, api.AgentInactiveDisconnectTimeout)
|
||||
//
|
||||
// Use a system context as the agent has disconnected and that token
|
||||
// may no longer be valid.
|
||||
//nolint:gocritic
|
||||
ctx, cancel := context.WithTimeout(dbauthz.AsSystem(api.ctx), api.AgentInactiveDisconnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
disconnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
_ = updateConnectionTimes(ctx)
|
||||
err := updateConnectionTimes(ctx)
|
||||
if err != nil {
|
||||
// This is a bug with unit tests that cancel the app context and
|
||||
// cause this error log to be generated. We should fix the unit tests
|
||||
// as this is a valid log.
|
||||
if !xerrors.Is(err, context.Canceled) {
|
||||
api.Logger.Error(ctx, "failed to update agent disconnect time",
|
||||
slog.Error(err),
|
||||
slog.F("workspace", build.WorkspaceID),
|
||||
)
|
||||
}
|
||||
}
|
||||
api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
|
||||
}()
|
||||
|
||||
|
@ -907,7 +923,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
|
|||
slog.F("payload", req),
|
||||
)
|
||||
|
||||
activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace.ID)
|
||||
activityBumpWorkspace(ctx, api.Logger.Named("activity_bump"), api.Database, workspace.ID)
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
|
@ -330,7 +331,8 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request
|
|||
// different auth formats, and tricks this endpoint into deleting an
|
||||
// unchecked API key, we validate that the secret matches the secret
|
||||
// we store in the database.
|
||||
apiKey, err := api.Database.GetAPIKeyByID(ctx, id)
|
||||
//nolint:gocritic // needed for workspace app logout
|
||||
apiKey, err := api.Database.GetAPIKeyByID(dbauthz.AsSystem(ctx), id)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to lookup API key.",
|
||||
|
@ -349,7 +351,8 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request
|
|||
})
|
||||
return
|
||||
}
|
||||
err = api.Database.DeleteAPIKeyByID(ctx, id)
|
||||
//nolint:gocritic // needed for workspace app logout
|
||||
err = api.Database.DeleteAPIKeyByID(dbauthz.AsSystem(ctx), id)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to delete API key.",
|
||||
|
@ -409,7 +412,10 @@ func (api *API) handleWorkspaceAppLogout(rw http.ResponseWriter, r *http.Request
|
|||
// error while looking it up, an HTML error page is returned and false is
|
||||
// returned so the caller can return early.
|
||||
func (api *API) lookupWorkspaceApp(rw http.ResponseWriter, r *http.Request, agentID uuid.UUID, appSlug string) (database.WorkspaceApp, bool) {
|
||||
app, err := api.Database.GetWorkspaceAppByAgentIDAndSlug(r.Context(), database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
// dbauthz.AsSystem is allowed here as the app authz is checked later.
|
||||
// The app authz is determined by the sharing level.
|
||||
//nolint:gocritic
|
||||
app, err := api.Database.GetWorkspaceAppByAgentIDAndSlug(dbauthz.AsSystem(r.Context()), database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: agentID,
|
||||
Slug: appSlug,
|
||||
})
|
||||
|
@ -1019,7 +1025,8 @@ func decryptAPIKey(ctx context.Context, db database.Store, encryptedAPIKey strin
|
|||
|
||||
// Lookup the API key so we can decrypt it.
|
||||
keyID := object.Header.KeyID
|
||||
key, err := db.GetAPIKeyByID(ctx, keyID)
|
||||
//nolint:gocritic // needed to check API key
|
||||
key, err := db.GetAPIKeyByID(dbauthz.AsSystem(ctx), keyID)
|
||||
if err != nil {
|
||||
return database.APIKey{}, "", xerrors.Errorf("get API key by key ID: %w", err)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/coder/coder/coderd/awsidentity"
|
||||
"github.com/coder/coder/coderd/azureidentity"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/provisionerdserver"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
@ -126,7 +127,8 @@ func (api *API) postWorkspaceAuthGoogleInstanceIdentity(rw http.ResponseWriter,
|
|||
|
||||
func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, instanceID string) {
|
||||
ctx := r.Context()
|
||||
agent, err := api.Database.GetWorkspaceAgentByInstanceID(ctx, instanceID)
|
||||
//nolint:gocritic // needed for auth instance id
|
||||
agent, err := api.Database.GetWorkspaceAgentByInstanceID(dbauthz.AsSystem(ctx), instanceID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Instance with id %q not found.", instanceID),
|
||||
|
@ -140,7 +142,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
|
|||
})
|
||||
return
|
||||
}
|
||||
resource, err := api.Database.GetWorkspaceResourceByID(ctx, agent.ResourceID)
|
||||
//nolint:gocritic // needed for auth instance id
|
||||
resource, err := api.Database.GetWorkspaceResourceByID(dbauthz.AsSystem(ctx), agent.ResourceID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner job resource.",
|
||||
|
@ -148,7 +151,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
|
|||
})
|
||||
return
|
||||
}
|
||||
job, err := api.Database.GetProvisionerJobByID(ctx, resource.JobID)
|
||||
//nolint:gocritic // needed for auth instance id
|
||||
job, err := api.Database.GetProvisionerJobByID(dbauthz.AsSystem(ctx), resource.JobID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner job.",
|
||||
|
@ -171,7 +175,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
|
|||
})
|
||||
return
|
||||
}
|
||||
resourceHistory, err := api.Database.GetWorkspaceBuildByID(ctx, jobData.WorkspaceBuildID)
|
||||
//nolint:gocritic // needed for auth instance id
|
||||
resourceHistory, err := api.Database.GetWorkspaceBuildByID(dbauthz.AsSystem(ctx), jobData.WorkspaceBuildID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace build.",
|
||||
|
@ -182,7 +187,8 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
|
|||
// This token should only be exchanged if the instance ID is valid
|
||||
// for the latest history. If an instance ID is recycled by a cloud,
|
||||
// we'd hate to leak access to a user's workspace.
|
||||
latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, resourceHistory.WorkspaceID)
|
||||
//nolint:gocritic // needed for auth instance id
|
||||
latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(dbauthz.AsSystem(ctx), resourceHistory.WorkspaceID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching the latest workspace build.",
|
||||
|
|
|
@ -371,6 +371,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
|
|||
return
|
||||
}
|
||||
|
||||
// TODO: This should be a system call as the actor might not be able to
|
||||
// read other workspaces. Ideally we check the error on create and look for
|
||||
// a postgres conflict error.
|
||||
workspace, err := api.Database.GetWorkspaceByOwnerIDAndName(ctx, database.GetWorkspaceByOwnerIDAndNameParams{
|
||||
OwnerID: user.ID,
|
||||
Name: createWorkspace.Name,
|
||||
|
|
|
@ -144,7 +144,9 @@ func New(ctx context.Context, options *Options) (*API, error) {
|
|||
|
||||
if len(options.SCIMAPIKey) != 0 {
|
||||
api.AGPL.RootHandler.Route("/scim/v2", func(r chi.Router) {
|
||||
r.Use(api.scimEnabledMW)
|
||||
r.Use(
|
||||
api.scimEnabledMW,
|
||||
)
|
||||
r.Post("/Users", api.scimPostUser)
|
||||
r.Route("/Users", func(r chi.Router) {
|
||||
r.Get("/", api.scimGetUsers)
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
@ -100,7 +102,9 @@ func TestEntitlements(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.False(t, entitlements.HasLicense)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
//nolint:gocritic // unit test
|
||||
ctx := dbauthz.AsSystem(context.Background())
|
||||
_, err = api.Database.InsertLicense(ctx, database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
Exp: database.Now().AddDate(1, 0, 0),
|
||||
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
|
||||
|
@ -128,7 +132,9 @@ func TestEntitlements(t *testing.T) {
|
|||
require.False(t, entitlements.HasLicense)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
// Valid
|
||||
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
ctx := context.Background()
|
||||
//nolint:gocritic // unit test
|
||||
_, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
Exp: database.Now().AddDate(1, 0, 0),
|
||||
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
|
||||
|
@ -139,7 +145,8 @@ func TestEntitlements(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
// Expired
|
||||
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
//nolint:gocritic // unit test
|
||||
_, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
Exp: database.Now().AddDate(-1, 0, 0),
|
||||
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
|
||||
|
@ -148,7 +155,8 @@ func TestEntitlements(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
// Invalid
|
||||
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
//nolint:gocritic // unit test
|
||||
_, err = api.Database.InsertLicense(dbauthz.AsSystem(ctx), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
Exp: database.Now().AddDate(1, 0, 0),
|
||||
JWT: "invalid",
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -22,6 +24,9 @@ func TestNew(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
if strings.Contains(os.Getenv("CODER_EXPERIMENTS_TEST"), string(codersdk.ExperimentAuthzQuerier)) {
|
||||
t.Skip("Skipping TestAuthorizeAllEndpoints for authz_querier experiment")
|
||||
}
|
||||
t.Parallel()
|
||||
client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
agpl "github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
@ -155,7 +156,8 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
user, _, err := api.AGPL.CreateUser(ctx, api.Database, agpl.CreateUserRequest{
|
||||
//nolint:gocritic // needed for SCIM
|
||||
user, _, err := api.AGPL.CreateUser(dbauthz.AsSystem(ctx), api.Database, agpl.CreateUserRequest{
|
||||
CreateUserRequest: codersdk.CreateUserRequest{
|
||||
Username: sUser.UserName,
|
||||
Email: email,
|
||||
|
@ -207,7 +209,8 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
dbUser, err := api.Database.GetUserByID(ctx, uid)
|
||||
//nolint:gocritic // needed for SCIM
|
||||
dbUser, err := api.Database.GetUserByID(dbauthz.AsSystem(ctx), uid)
|
||||
if err != nil {
|
||||
_ = handlerutil.WriteError(rw, err)
|
||||
return
|
||||
|
@ -220,7 +223,8 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
|
|||
status = database.UserStatusSuspended
|
||||
}
|
||||
|
||||
_, err = api.Database.UpdateUserStatus(r.Context(), database.UpdateUserStatusParams{
|
||||
//nolint:gocritic // needed for SCIM
|
||||
_, err = api.Database.UpdateUserStatus(dbauthz.AsSystem(r.Context()), database.UpdateUserStatusParams{
|
||||
ID: dbUser.ID,
|
||||
Status: status,
|
||||
UpdatedAt: database.Now(),
|
||||
|
|
|
@ -921,6 +921,10 @@ func TestTemplateAccess(t *testing.T) {
|
|||
|
||||
testTemplateRead := func(t *testing.T, org orgSetup, usr *codersdk.Client, read []codersdk.Template) {
|
||||
found, err := usr.TemplatesByOrganization(ctx, org.Org.ID)
|
||||
if len(read) == 0 && err != nil {
|
||||
require.ErrorContains(t, err, "Resource not found")
|
||||
return
|
||||
}
|
||||
require.NoError(t, err, "failed to get templates")
|
||||
|
||||
exp := make(map[uuid.UUID]codersdk.Template)
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/provisionerd/proto"
|
||||
sdkproto "github.com/coder/coder/provisionersdk/proto"
|
||||
|
@ -886,7 +887,8 @@ func (r *Runner) commitQuota(ctx context.Context, resources []*sdkproto.Resource
|
|||
|
||||
const stage = "Commit quota"
|
||||
|
||||
resp, err := r.quotaCommitter.CommitQuota(ctx, &proto.CommitQuotaRequest{
|
||||
//nolint:gocritic // TODO: make a provisionerd role
|
||||
resp, err := r.quotaCommitter.CommitQuota(dbauthz.AsSystem(ctx), &proto.CommitQuotaRequest{
|
||||
JobId: r.job.JobId,
|
||||
DailyCost: int32(cost),
|
||||
})
|
||||
|
|
|
@ -20,6 +20,29 @@ import (
|
|||
"github.com/quasilyte/go-ruleguard/dsl/types"
|
||||
)
|
||||
|
||||
// dbauthzAuthorizationContext is a lint rule that protects the usage of
|
||||
// system contexts. This is a dangerous pattern that can lead to
|
||||
// leaking database information as a system context can be essentially
|
||||
// "sudo".
|
||||
//
|
||||
// Anytime a function like "AsSystem" is used, it should be accompanied by a comment
|
||||
// explaining why it's ok and a nolint.
|
||||
func dbauthzAuthorizationContext(m dsl.Matcher) {
|
||||
m.Import("context")
|
||||
m.Import("github.com/coder/coder/coderd/database/dbauthz")
|
||||
|
||||
m.Match(
|
||||
`dbauthz.$f($c)`,
|
||||
).
|
||||
Where(
|
||||
m["c"].Type.Implements("context.Context") &&
|
||||
// Only report on functions that start with "As".
|
||||
m["f"].Text.Matches("^As"),
|
||||
).
|
||||
// Instructions for fixing the lint error should be included on the dangerous function.
|
||||
Report("Using '$f' is dangerous and should be accompanied by a comment explaining why it's ok and a nolint.")
|
||||
}
|
||||
|
||||
// Use xerrors everywhere! It provides additional stacktrace info!
|
||||
//
|
||||
//nolint:unused,deadcode,varnamelen
|
||||
|
|
Loading…
Reference in New Issue