diff --git a/coderd/authorize.go b/coderd/authorize.go index 53fa6e553a..577447d306 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/google/uuid" - "golang.org/x/xerrors" "cdr.dev/slog" @@ -95,19 +94,14 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r // from postgres are already authorized, and the caller does not need to // call 'Authorize()' on the returned objects. // Note the authorization is only for the given action and object type. -func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) { +func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) { roles := httpmw.UserAuthorization(r) prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType) if err != nil { return nil, xerrors.Errorf("prepare filter: %w", err) } - filter, err := prepared.Compile() - if err != nil { - return nil, xerrors.Errorf("compile filter: %w", err) - } - - return filter, nil + return prepared, nil } // checkAuthorization returns if the current API key can use the given diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 127c5037bd..b72db5fa05 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -9,6 +9,8 @@ import ( "strings" "testing" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,12 +18,19 @@ import ( "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/rbac/regosql" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" ) func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { + // For any route using SQL filters, we need to know if the database is an + // in memory fake. This is because the in memory fake does not use SQL, and + // still uses rego. So this boolean indicates how to assert the expected + // behavior. + _, isMemoryDB := a.api.Database.(databasefake.FakeDatabase) + // Some quick reused objects workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) @@ -125,11 +134,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { AssertAction: rbac.ActionCreate, AssertObject: workspaceExecObj, }, - "GET:/api/v2/organizations/{organization}/templates": { - StatusCode: http.StatusOK, - AssertAction: rbac.ActionRead, - AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID), - }, "POST:/api/v2/organizations/{organization}/templates": { AssertAction: rbac.ActionCreate, AssertObject: rbac.ResourceTemplate.InOrg(a.Organization.ID), @@ -240,7 +244,18 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "GET:/api/v2/organizations/{organization}/templateversions/{templateversionname}": {StatusCode: http.StatusBadRequest, NoAuthorize: true}, // Endpoints that use the SQLQuery filter. - "GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true}, + "GET:/api/v2/workspaces/": { + StatusCode: http.StatusOK, + NoAuthorize: !isMemoryDB, + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceWorkspace, + }, + "GET:/api/v2/organizations/{organization}/templates": { + StatusCode: http.StatusOK, + NoAuthorize: !isMemoryDB, + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceTemplate, + }, } // Routes like proxy routes support all HTTP methods. A helper func to expand @@ -549,10 +564,10 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object) } -// Compile returns a compiled version of the authorizer that will work for +// CompileToSQL returns a compiled version of the authorizer that will work for // in memory databases. This fake version will not work against a SQL database. -func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) { - return f, nil +func (fakePreparedAuthorizer) CompileToSQL(_ regosql.ConvertConfig) (string, error) { + return "", xerrors.New("not implemented") } func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool { @@ -565,10 +580,3 @@ func (f fakePreparedAuthorizer) RegoString() string { } panic("not implemented") } - -func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string { - if f.HardCodedSQLString != "" { - return f.HardCodedSQLString - } - panic("not implemented") -} diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 64f8506832..72441d83e8 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -20,6 +20,13 @@ import ( "github.com/coder/coder/coderd/util/slice" ) +// FakeDatabase is helpful for knowing if the underlying db is an in memory fake +// database. This is only in the databasefake package, so will only be used +// by unit tests. +type FakeDatabase interface { + IsFakeDB() +} + var errDuplicateKey = &pq.Error{ Code: "23505", Message: "duplicate key value violates unique constraint", @@ -117,6 +124,7 @@ type data struct { lastLicenseID int32 } +func (fakeQuerier) IsFakeDB() {} func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) { return 0, nil } @@ -488,11 +496,20 @@ func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.Get return count, err } -func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) { +func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { q.mutex.RLock() defer q.mutex.RUnlock() - users := append([]database.User{}, q.users...) + users := make([]database.User, 0, len(q.users)) + + for _, user := range q.users { + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { + continue + } + + users = append(users, user) + } if params.Deleted { tmp := make([]database.User, 0, len(users)) @@ -539,13 +556,6 @@ func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database. users = usersFilteredByRole } - for _, user := range q.workspaces { - // If the filter exists, ensure the object is authorized. - if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) { - continue - } - } - return int64(len(users)), nil } @@ -750,7 +760,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa } //nolint:gocyclo -func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.GetWorkspacesRow, error) { +func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -923,7 +933,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database. } // If the filter exists, ensure the object is authorized. - if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) { + if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { continue } workspaces = append(workspaces, workspace) @@ -1505,12 +1515,20 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd return database.Template{}, sql.ErrNoRows } -func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { +func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) { + return q.GetAuthorizedTemplates(ctx, arg, nil) +} + +func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { q.mutex.RLock() defer q.mutex.RUnlock() var templates []database.Template for _, template := range q.templates { + if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { + continue + } + if template.Deleted != arg.Deleted { continue } diff --git a/coderd/database/databasefake/databasefake_test.go b/coderd/database/databasefake/databasefake_test.go index f3f53de7d5..b0963cd911 100644 --- a/coderd/database/databasefake/databasefake_test.go +++ b/coderd/database/databasefake/databasefake_test.go @@ -74,6 +74,7 @@ func TestExactMethods(t *testing.T) { extraFakeMethods := map[string]string{ // Example // "SortFakeLists": "Helper function used", + "IsFakeDB": "Helper function used for unit testing", } fake := reflect.TypeOf(databasefake.New()) diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 5b4fc7f5f3..46a4b74d69 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -5,12 +5,16 @@ import ( "fmt" "strings" + "github.com/google/uuid" "github.com/lib/pq" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/coderd/rbac/regosql" +) - "github.com/google/uuid" - "golang.org/x/xerrors" +const ( + authorizedQueryPlaceholder = "-- @authorize_filter" ) // customQuerier encompasses all non-generated queries. @@ -23,10 +27,70 @@ type customQuerier interface { } type templateQuerier interface { + GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error) } +func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) { + authorizedFilter, err := prepared.CompileToSQL(regosql.ConvertConfig{ + VariableConverter: regosql.TemplateConverter(), + }) + if err != nil { + return nil, xerrors.Errorf("compile authorized filter: %w", err) + } + + filtered, err := insertAuthorizedFilter(getTemplatesWithFilter, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return nil, xerrors.Errorf("insert authorized filter: %w", err) + } + + // The name comment is for metric tracking + query := fmt.Sprintf("-- name: GetAuthorizedTemplates :many\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.Deleted, + arg.OrganizationID, + arg.ExactName, + pq.Array(arg.IDs), + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Template + for rows.Next() { + var i Template + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OrganizationID, + &i.Deleted, + &i.Name, + &i.Provisioner, + &i.ActiveVersionID, + &i.Description, + &i.DefaultTTL, + &i.CreatedBy, + &i.Icon, + &i.UserACL, + &i.GroupACL, + &i.DisplayName, + &i.AllowUserCancelWorkspaceJobs, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + type TemplateUser struct { User Actions Actions `db:"actions"` @@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([ } type workspaceQuerier interface { - GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) + GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) } // GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access. // This code is copied from `GetWorkspaces` and adds the authorized filter WHERE // clause. -func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) { +func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) { + authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL()) + if err != nil { + return nil, xerrors.Errorf("compile authorized filter: %w", err) + } + // In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the // authorizedFilter between the end of the where clause and those statements. - filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1) + filtered, err := insertAuthorizedFilter(getWorkspaces, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return nil, xerrors.Errorf("insert authorized filter: %w", err) + } + // The name comment is for metric tracking - query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter) + query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered) rows, err := q.db.QueryContext(ctx, query, arg.Deleted, arg.Status, @@ -172,12 +245,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa } type userQuerier interface { - GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) + GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) } -func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) { - filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1) - query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter) +func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL()) + if err != nil { + return -1, xerrors.Errorf("compile authorized filter: %w", err) + } + + filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return -1, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered) row := q.db.QueryRowContext(ctx, query, arg.Deleted, arg.Search, @@ -185,6 +267,14 @@ func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFiltered pq.Array(arg.RbacRole), ) var count int64 - err := row.Scan(&count) + err = row.Scan(&count) return count, err } + +func insertAuthorizedFilter(query string, replaceWith string) (string, error) { + if !strings.Contains(query, authorizedQueryPlaceholder) { + return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") + } + filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1) + return filtered, nil +} diff --git a/coderd/database/modelqueries_internal_test.go b/coderd/database/modelqueries_internal_test.go new file mode 100644 index 0000000000..4977120e88 --- /dev/null +++ b/coderd/database/modelqueries_internal_test.go @@ -0,0 +1,15 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsAuthorizedQuery(t *testing.T) { + t.Parallel() + + query := `SELECT true;` + _, err := insertAuthorizedFilter(query, "") + require.ErrorContains(t, err, "does not contain authorized replace string", "ensure replace string") +} diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index bd154369a5..2d4656becc 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3197,6 +3197,8 @@ WHERE id = ANY($4) ELSE true END + -- Authorize Filter clause will be injected below in GetAuthorizedTemplates + -- @authorize_filter ORDER BY (name, id) ASC ` diff --git a/coderd/database/queries/templates.sql b/coderd/database/queries/templates.sql index 86781c414f..5fa8270b5a 100644 --- a/coderd/database/queries/templates.sql +++ b/coderd/database/queries/templates.sql @@ -34,6 +34,8 @@ WHERE id = ANY(@ids) ELSE true END + -- Authorize Filter clause will be injected below in GetAuthorizedTemplates + -- @authorize_filter ORDER BY (name, id) ASC ; diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index 58d942f363..eaac305c71 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -10,6 +10,7 @@ import ( "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" + "github.com/coder/coder/coderd/rbac/regosql" "github.com/coder/coder/coderd/tracing" ) @@ -20,7 +21,7 @@ type Authorizer interface { type PreparedAuthorized interface { Authorize(ctx context.Context, object Object) error - Compile() (AuthorizeFilter, error) + CompileToSQL(cfg regosql.ConvertConfig) (string, error) } // Filter takes in a list of objects, and will filter the list removing all diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index 1df2e17f82..e6ce53b12a 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -847,7 +847,7 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes // Ensure the partial can compile to a SQL clause. // This does not guarantee that the clause is valid SQL. - _, err = Compile(partialAuthz) + _, err = Compile(ConfigWithACL(), partialAuthz) require.NoError(t, err, "compile prepared authorizer") // Also check the rego policy can form a valid partial query result. diff --git a/coderd/rbac/partial.go b/coderd/rbac/partial.go index 6049b6754f..1bf3155033 100644 --- a/coderd/rbac/partial.go +++ b/coderd/rbac/partial.go @@ -3,11 +3,11 @@ package rbac import ( "context" - "golang.org/x/xerrors" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" + "golang.org/x/xerrors" + "github.com/coder/coder/coderd/rbac/regosql" "github.com/coder/coder/coderd/tracing" ) @@ -28,12 +28,12 @@ type PartialAuthorizer struct { var _ PreparedAuthorized = (*PartialAuthorizer)(nil) -func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) { - filter, err := Compile(pa) +func (pa *PartialAuthorizer) CompileToSQL(cfg regosql.ConvertConfig) (string, error) { + filter, err := Compile(cfg, pa) if err != nil { - return nil, xerrors.Errorf("compile: %w", err) + return "", xerrors.Errorf("compile: %w", err) } - return filter, nil + return filter.SQLString(), nil } func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error { diff --git a/coderd/rbac/partial.rego b/coderd/rbac/partial.rego new file mode 100644 index 0000000000..0ea94a137a --- /dev/null +++ b/coderd/rbac/partial.rego @@ -0,0 +1 @@ +package partial diff --git a/coderd/rbac/query.go b/coderd/rbac/query.go index 8a046eb8ad..6d535753d2 100644 --- a/coderd/rbac/query.go +++ b/coderd/rbac/query.go @@ -2,635 +2,63 @@ package rbac import ( "context" - "fmt" - "regexp" - "strconv" "strings" - "github.com/open-policy-agent/opa/ast" + "github.com/coder/coder/coderd/rbac/regosql" + + "github.com/coder/coder/coderd/rbac/regosql/sqltypes" + "golang.org/x/xerrors" ) -type TermType string - -const ( - VarTypeJsonbTextArray TermType = "jsonb-text-array" - VarTypeText TermType = "text" - VarTypeBoolean TermType = "boolean" - // VarTypeSkip means this variable does not exist to use. - VarTypeSkip TermType = "skip" -) - -type SQLColumn struct { - // RegoMatch matches the original variable string. - // If it is a match, then this variable config will apply. - RegoMatch *regexp.Regexp - // ColumnSelect is the name of the postgres column to select. - // Can use capture groups from RegoMatch with $1, $2, etc. - ColumnSelect string - - // Type indicates the postgres type of the column. Some expressions will - // need to know this in order to determine what SQL to produce. - // An example is if the variable is a jsonb array, the "contains" SQL - // query is `variable ? 'value'` instead of `'value' = ANY(variable)`. - // This type is only needed to be provided - Type TermType -} - -type SQLConfig struct { - // Variables is a map of rego variable names to SQL columns. - // Example: - // "input\.object\.org_owner": SQLColumn{ - // ColumnSelect: "organization_id", - // Type: VarTypeUUID - // } - // "input\.object\.owner": SQLColumn{ - // ColumnSelect: "owner_id", - // Type: VarTypeUUID - // } - // "input\.object\.group_acl\.(.*)": SQLColumn{ - // ColumnSelect: "group_acl->$1", - // Type: VarTypeJsonbTextArray - // } - Variables []SQLColumn -} - -func DefaultConfig() SQLConfig { - return SQLConfig{ - Variables: []SQLColumn{ - { - RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`), - ColumnSelect: "group_acl->$1", - Type: VarTypeJsonbTextArray, - }, - { - RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`), - ColumnSelect: "user_acl->$1", - Type: VarTypeJsonbTextArray, - }, - { - RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`), - ColumnSelect: "organization_id :: text", - Type: VarTypeText, - }, - { - RegoMatch: regexp.MustCompile(`^input\.object\.owner$`), - ColumnSelect: "owner_id :: text", - Type: VarTypeText, - }, - }, - } -} - -func NoACLConfig() SQLConfig { - return SQLConfig{ - Variables: []SQLColumn{ - { - RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`), - ColumnSelect: "", - Type: VarTypeSkip, - }, - { - RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`), - ColumnSelect: "", - Type: VarTypeSkip, - }, - { - RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`), - ColumnSelect: "organization_id :: text", - Type: VarTypeText, - }, - { - RegoMatch: regexp.MustCompile(`^input\.object\.owner$`), - ColumnSelect: "owner_id :: text", - Type: VarTypeText, - }, - }, - } -} - type AuthorizeFilter interface { - Expression - // Eval is required for the fake in memory database to work. The in memory - // database can use this function to filter the results. - Eval(object Object) bool + SQLString() string } -// expressionTop handles Eval(object Object) for in memory expressions -type expressionTop struct { - Expression - Auth *PartialAuthorizer +type authorizedSQLFilter struct { + sqlString string + auth *PartialAuthorizer } -func (e expressionTop) Eval(object Object) bool { - return e.Auth.Authorize(context.Background(), object) == nil +func ConfigWithACL() regosql.ConvertConfig { + return regosql.ConvertConfig{ + VariableConverter: regosql.DefaultVariableConverter(), + } } -// Compile will convert a rego query AST into our custom types. The output is -// an AST that can be used to generate SQL. -func Compile(pa *PartialAuthorizer) (AuthorizeFilter, error) { - partialQueries := pa.partialQueries - if len(partialQueries.Support) > 0 { - return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support)) +func ConfigWithoutACL() regosql.ConvertConfig { + return regosql.ConvertConfig{ + VariableConverter: regosql.NoACLConverter(), + } +} + +func Compile(cfg regosql.ConvertConfig, pa *PartialAuthorizer) (AuthorizeFilter, error) { + root, err := regosql.ConvertRegoAst(cfg, pa.partialQueries) + if err != nil { + return nil, xerrors.Errorf("convert rego ast: %w", err) } - // 0 queries means the result is "undefined". This is the same as "false". - if len(partialQueries.Queries) == 0 { - return &termBoolean{ - base: base{Rego: "false"}, - Value: false, - }, nil + // Generate the SQL + gen := sqltypes.NewSQLGenerator() + sqlString := root.SQLString(gen) + if len(gen.Errors()) > 0 { + var errStrings []string + for _, err := range gen.Errors() { + errStrings = append(errStrings, err.Error()) + } + return nil, xerrors.Errorf("sql generation errors: %v", strings.Join(errStrings, ", ")) } - // Abort early if any of the "OR"'d expressions are the empty string. - // This is the same as "true". - for _, query := range partialQueries.Queries { - if query.String() == "" { - return &termBoolean{ - base: base{Rego: "true"}, - Value: true, - }, nil - } - } - - result := make([]Expression, 0, len(partialQueries.Queries)) - var builder strings.Builder - for i := range partialQueries.Queries { - query, err := processQuery(partialQueries.Queries[i]) - if err != nil { - return nil, err - } - result = append(result, query) - if i != 0 { - builder.WriteString("\n") - } - builder.WriteString(partialQueries.Queries[i].String()) - } - exp := expOr{ - base: base{ - Rego: builder.String(), - }, - Expressions: result, - } - return expressionTop{ - Expression: &exp, - Auth: pa, + return &authorizedSQLFilter{ + sqlString: sqlString, + auth: pa, }, nil } -// processQuery processes an entire set of expressions and joins them with -// "AND". -func processQuery(query ast.Body) (Expression, error) { - expressions := make([]Expression, 0, len(query)) - for _, astExpr := range query { - expr, err := processExpression(astExpr) - if err != nil { - return nil, err - } - expressions = append(expressions, expr) - } - - return expAnd{ - base: base{ - Rego: query.String(), - }, - Expressions: expressions, - }, nil +func (a *authorizedSQLFilter) Eval(object Object) bool { + return a.auth.Authorize(context.Background(), object) == nil } -func processExpression(expr *ast.Expr) (Expression, error) { - if !expr.IsCall() { - // This could be a single term that is a valid expression. - if term, ok := expr.Terms.(*ast.Term); ok { - value, err := processTerm(term) - if err != nil { - return nil, xerrors.Errorf("single term expression: %w", err) - } - if boolExp, ok := value.(Expression); ok { - return boolExp, nil - } - // Default to error. - } - return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported") - } - - op := expr.Operator().String() - base := base{Rego: op} - switch op { - case "neq", "eq", "equal": - terms, err := processTerms(2, expr.Operands()) - if err != nil { - return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) - } - return &opEqual{ - base: base, - Terms: [2]Term{terms[0], terms[1]}, - Not: op == "neq", - }, nil - case "internal.member_2": - terms, err := processTerms(2, expr.Operands()) - if err != nil { - return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err) - } - return &opInternalMember2{ - base: base, - Needle: terms[0], - Haystack: terms[1], - }, nil - default: - return nil, xerrors.Errorf("invalid expression: operator %s not supported", op) - } -} - -func processTerms(expected int, terms []*ast.Term) ([]Term, error) { - if len(terms) != expected { - return nil, xerrors.Errorf("too many arguments, expect %d found %d", expected, len(terms)) - } - - result := make([]Term, 0, len(terms)) - for _, term := range terms { - processed, err := processTerm(term) - if err != nil { - return nil, xerrors.Errorf("invalid term: %w", err) - } - result = append(result, processed) - } - - return result, nil -} - -func processTerm(term *ast.Term) (Term, error) { - termBase := base{Rego: term.String()} - switch v := term.Value.(type) { - case ast.Boolean: - return &termBoolean{ - base: termBase, - Value: bool(v), - }, nil - case ast.Ref: - obj := &termObject{ - base: termBase, - Path: []Term{}, - } - var idx int - // A ref is a set of terms. If the first term is a var, then the - // following terms are the path to the value. - isRef := true - var builder strings.Builder - for _, term := range v { - if idx == 0 { - if _, ok := v[0].Value.(ast.Var); !ok { - return nil, xerrors.Errorf("invalid term (%s): ref must start with a var, started with %T", v[0].String(), v[0]) - } - } - - _, newRef := term.Value.(ast.Ref) - if newRef || - // This is an unfortunate hack. To fix this, we need to rewrite - // our SQL config as a path ([]string{"input", "object", "acl_group"}). - // In the rego AST, there is no difference between selecting - // a field by a variable, and selecting a field by a literal (string). - // This was a misunderstanding. - // Example (these are equivalent by AST): - // input.object.acl_group_list['4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75'] - // input.object.acl_group_list.organization_id - // - // This is not equivalent - // input.object.acl_group_list[input.object.organization_id] - // - // If this becomes even more hairy, we should fix the sql config. - builder.String() == "input.object.acl_group_list" || - builder.String() == "input.object.acl_user_list" { - if !newRef { - isRef = false - } - // New obj - obj.Path = append(obj.Path, termVariable{ - base: base{ - Rego: builder.String(), - }, - Name: builder.String(), - }) - builder.Reset() - idx = 0 - } - - if builder.Len() != 0 { - builder.WriteString(".") - } - builder.WriteString(trimQuotes(term.String())) - idx++ - } - - if isRef { - obj.Path = append(obj.Path, termVariable{ - base: base{ - Rego: builder.String(), - }, - Name: builder.String(), - }) - } else { - obj.Path = append(obj.Path, termString{ - base: base{ - Rego: fmt.Sprintf("%q", builder.String()), - }, - Value: builder.String(), - }) - } - return obj, nil - case ast.Var: - return &termVariable{ - Name: trimQuotes(v.String()), - base: termBase, - }, nil - case ast.String: - return &termString{ - Value: trimQuotes(v.String()), - base: termBase, - }, nil - case ast.Set: - slice := v.Slice() - set := make([]Term, 0, len(slice)) - for _, elem := range slice { - processed, err := processTerm(elem) - if err != nil { - return nil, xerrors.Errorf("invalid set term %s: %w", elem.String(), err) - } - set = append(set, processed) - } - - return &termSet{ - Value: set, - base: termBase, - }, nil - default: - return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String()) - } -} - -type base struct { - // Rego is the original rego string - Rego string -} - -func (b base) RegoString() string { - return b.Rego -} - -// Expression comprises a set of terms, operators, and functions. All -// expressions return a boolean value. -// -// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo" -type Expression interface { - // RegoString is used in debugging to see the original rego expression. - RegoString() string - // SQLString returns the SQL expression that can be used in a WHERE clause. - SQLString(cfg SQLConfig) string -} - -type expAnd struct { - base - Expressions []Expression -} - -func (t expAnd) SQLString(cfg SQLConfig) string { - if len(t.Expressions) == 1 { - return t.Expressions[0].SQLString(cfg) - } - - exprs := make([]string, 0, len(t.Expressions)) - for _, expr := range t.Expressions { - exprs = append(exprs, expr.SQLString(cfg)) - } - return "(" + strings.Join(exprs, " AND ") + ")" -} - -type expOr struct { - base - Expressions []Expression -} - -func (t expOr) SQLString(cfg SQLConfig) string { - if len(t.Expressions) == 1 { - return t.Expressions[0].SQLString(cfg) - } - - exprs := make([]string, 0, len(t.Expressions)) - for _, expr := range t.Expressions { - exprs = append(exprs, expr.SQLString(cfg)) - } - return "(" + strings.Join(exprs, " OR ") + ")" -} - -// Operator joins terms together to form an expression. -// Operators are also expressions. -// -// Eg: "=", "neq", "internal.member_2", etc. -type Operator interface { - Expression -} - -type opEqual struct { - base - Terms [2]Term - // For NotEqual - Not bool -} - -func (t opEqual) SQLString(cfg SQLConfig) string { - op := "=" - if t.Not { - op = "!=" - } - return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg)) -} - -// opInternalMember2 is checking if the first term is a member of the second term. -// The second term is a set or list. -type opInternalMember2 struct { - base - Needle Term - Haystack Term -} - -func (t opInternalMember2) SQLString(cfg SQLConfig) string { - if haystack, ok := t.Haystack.(*termObject); ok { - // This is a special case where the haystack is a jsonb array. - // There is a more general way to solve this, but that requires a lot - // more code to cover a lot more cases that we do not care about. - // To handle this more generally we should implement "Array" as a type. - // Then have the `contains` function on the Array type. This would defer - // knowing the element type to the Array and cover more cases without - // having to add more "if" branches here. - // But until we need more cases, our basic type system is ok, and - // this is the only case we need to handle. - sqlType := haystack.SQLType(cfg) - if sqlType == VarTypeJsonbTextArray { - return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg)) - } - - if sqlType == VarTypeSkip { - return "false" - } - } - - return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg)) -} - -// Term is a single value in an expression. Terms can be variables or constants. -// -// Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner", -// "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}" -type Term interface { - RegoString() string - SQLString(cfg SQLConfig) string - SQLType(cfg SQLConfig) TermType -} - -type termString struct { - base - Value string -} - -func (t termString) SQLString(_ SQLConfig) string { - return "'" + t.Value + "'" -} - -func (termString) SQLType(_ SQLConfig) TermType { - return VarTypeText -} - -// termObject is a variable that can be dereferenced. We count some rego objects -// as single variables, eg: input.object.org_owner. In reality, it is a nested -// object. -// In rego, we can dereference the object with the "." operator, which we can -// handle with regex. -// Or we can dereference the object with the "[]", which we can handle with this -// term type. -type termObject struct { - base - Path []Term -} - -func (t termObject) SQLType(cfg SQLConfig) TermType { - // Without a full type system, let's just assume the type of the first var - // is the resulting type. This is correct for our use case. - // Solving this more generally requires a full type system, which is - // excessive for our mostly static policy. - return t.Path[0].SQLType(cfg) -} - -func (t termObject) SQLString(cfg SQLConfig) string { - if len(t.Path) == 1 { - return t.Path[0].SQLString(cfg) - } - // Combine the last 2 variables into 1 variable. - end := t.Path[len(t.Path)-1] - before := t.Path[len(t.Path)-2] - - // Recursively solve the SQLString by removing the last nested reference. - // This continues until we have a single variable. - return termObject{ - base: t.base, - Path: append( - t.Path[:len(t.Path)-2], - termVariable{ - base: base{ - Rego: before.RegoString() + "[" + end.RegoString() + "]", - }, - // Convert the end to SQL string. We evaluate each term - // one at a time. - Name: before.RegoString() + "." + end.SQLString(cfg), - }, - ), - }.SQLString(cfg) -} - -type termVariable struct { - base - Name string -} - -func (t termVariable) SQLType(cfg SQLConfig) TermType { - if col := t.ColumnConfig(cfg); col != nil { - return col.Type - } - return VarTypeText -} - -func (t termVariable) SQLString(cfg SQLConfig) string { - if col := t.ColumnConfig(cfg); col != nil { - matches := col.RegoMatch.FindStringSubmatch(t.Name) - if len(matches) > 0 { - // This config matches this variable. - replace := make([]string, 0, len(matches)*2) - for i, m := range matches { - replace = append(replace, fmt.Sprintf("$%d", i)) - replace = append(replace, m) - } - replacer := strings.NewReplacer(replace...) - return replacer.Replace(col.ColumnSelect) - } - } - - return t.Name -} - -// ColumnConfig returns the correct SQLColumn settings for the -// term. If there is no configured column, it will return nil. -func (t termVariable) ColumnConfig(cfg SQLConfig) *SQLColumn { - for _, col := range cfg.Variables { - matches := col.RegoMatch.MatchString(t.Name) - if matches { - return &col - } - } - return nil -} - -// termSet is a set of unique terms. -type termSet struct { - base - Value []Term -} - -func (t termSet) SQLType(cfg SQLConfig) TermType { - if len(t.Value) == 0 { - return VarTypeText - } - // Without a full type system, let's just assume the type of the first var - // is the resulting type. This is correct for our use case. - // Solving this more generally requires a full type system, which is - // excessive for our mostly static policy. - return t.Value[0].SQLType(cfg) -} - -func (t termSet) SQLString(cfg SQLConfig) string { - elems := make([]string, 0, len(t.Value)) - for _, v := range t.Value { - elems = append(elems, v.SQLString(cfg)) - } - - return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ",")) -} - -type termBoolean struct { - base - Value bool -} - -func (termBoolean) SQLType(SQLConfig) TermType { - return VarTypeBoolean -} - -func (t termBoolean) Eval(_ Object) bool { - return t.Value -} - -func (t termBoolean) SQLString(_ SQLConfig) string { - return strconv.FormatBool(t.Value) -} - -func trimQuotes(s string) string { - return strings.Trim(s, "\"") +func (a *authorizedSQLFilter) SQLString() string { + return a.sqlString } diff --git a/coderd/rbac/query_internal_test.go b/coderd/rbac/query_internal_test.go deleted file mode 100644 index f5c2a57715..0000000000 --- a/coderd/rbac/query_internal_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package rbac - -import ( - "context" - "testing" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/rego" - - "github.com/stretchr/testify/require" -) - -func TestCompileQuery(t *testing.T) { - t.Parallel() - - t.Run("EmptyQuery", func(t *testing.T) { - t.Parallel() - expression, err := Compile(partialQueries(t, "")) - require.NoError(t, err, "compile empty") - - require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'") - require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'") - }) - - t.Run("TrueQuery", func(t *testing.T) { - t.Parallel() - expression, err := Compile(partialQueries(t, "true")) - require.NoError(t, err, "compile") - - require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'") - require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'") - }) - - t.Run("ACLIn", func(t *testing.T) { - t.Parallel() - expression, err := Compile(partialQueries(t, `"*" in input.object.acl_group_list.allUsers`)) - require.NoError(t, err, "compile") - - require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member") - require.Equal(t, `group_acl->'allUsers' ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in") - }) - - t.Run("Complex", func(t *testing.T) { - t.Parallel() - expression, err := Compile(partialQueries(t, - `input.object.org_owner != ""`, - `input.object.org_owner in {"a", "b", "c"}`, - `input.object.org_owner != ""`, - `"read" in input.object.acl_group_list.allUsers`, - `"read" in input.object.acl_user_list.me`, - )) - require.NoError(t, err, "compile") - require.Equal(t, `(organization_id :: text != '' OR `+ - `organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+ - `organization_id :: text != '' OR `+ - `group_acl->'allUsers' ? 'read' OR `+ - `user_acl->'me' ? 'read')`, - expression.SQLString(DefaultConfig()), "complex") - }) - - t.Run("SetDereference", func(t *testing.T) { - t.Parallel() - expression, err := Compile(partialQueries(t, - `"*" in input.object.acl_group_list[input.object.org_owner]`, - )) - require.NoError(t, err, "compile") - require.Equal(t, `group_acl->organization_id :: text ? '*'`, - expression.SQLString(DefaultConfig()), "set dereference") - }) - - t.Run("JsonbLiteralDereference", func(t *testing.T) { - t.Parallel() - expression, err := Compile(partialQueries(t, - `"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`, - )) - require.NoError(t, err, "compile") - require.Equal(t, `group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'`, - expression.SQLString(DefaultConfig()), "literal dereference") - }) - - t.Run("NoACLColumns", func(t *testing.T) { - t.Parallel() - expression, err := Compile(partialQueries(t, - `"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`, - )) - require.NoError(t, err, "compile") - require.Equal(t, `false`, - expression.SQLString(NoACLConfig()), "literal dereference") - }) -} - -func TestEvalQuery(t *testing.T) { - t.Parallel() - - t.Run("GroupACL", func(t *testing.T) { - t.Parallel() - expression, err := Compile(partialQueries(t, - `"read" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`, - )) - require.NoError(t, err, "compile") - - result := expression.Eval(Object{ - Owner: "not-me", - OrgID: "random", - Type: "workspace", - ACLUserList: map[string][]Action{}, - ACLGroupList: map[string][]Action{ - "4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75": {"read"}, - }, - }) - require.True(t, result, "eval") - }) -} - -func partialQueries(t *testing.T, queries ...string) *PartialAuthorizer { - opts := ast.ParserOptions{ - AllFutureKeywords: true, - } - - astQueries := make([]ast.Body, 0, len(queries)) - for _, q := range queries { - astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts)) - } - - prepareQueries := make([]rego.PreparedEvalQuery, 0, len(queries)) - for _, q := range astQueries { - var prepped rego.PreparedEvalQuery - var err error - if q.String() == "" { - prepped, err = rego.New( - rego.Query("true"), - ).PrepareForEval(context.Background()) - } else { - prepped, err = rego.New( - rego.ParsedQuery(q), - ).PrepareForEval(context.Background()) - } - require.NoError(t, err, "prepare query") - prepareQueries = append(prepareQueries, prepped) - } - return &PartialAuthorizer{ - partialQueries: ®o.PartialQueries{ - Queries: astQueries, - Support: []*ast.Module{}, - }, - preparedQueries: prepareQueries, - input: nil, - alwaysTrue: false, - } -} diff --git a/coderd/rbac/regosql/acl_group_var.go b/coderd/rbac/regosql/acl_group_var.go new file mode 100644 index 0000000000..f15a520b8d --- /dev/null +++ b/coderd/rbac/regosql/acl_group_var.go @@ -0,0 +1,102 @@ +package regosql + +import ( + "fmt" + + "golang.org/x/xerrors" + + "github.com/open-policy-agent/opa/ast" + + "github.com/coder/coder/coderd/rbac/regosql/sqltypes" +) + +var _ sqltypes.VariableMatcher = ACLGroupVar{} +var _ sqltypes.Node = ACLGroupVar{} + +// ACLGroupVar is a variable matcher that handles group_acl and user_acl. +// The sql type is a jsonb object with the following structure: +// +// "group_acl": { +// "": [""] +// } +// +// This is a custom variable matcher as json objects have arbitrary complexity. +type ACLGroupVar struct { + StructSQL string + // input.object.group_acl -> ["input", "object", "group_acl"] + StructPath []string + + // FieldReference handles referencing the subfields, which could be + // more variables. We pass one in as the global one might not be correctly + // scoped. + FieldReference sqltypes.VariableMatcher + + // Instance fields + Source sqltypes.RegoSource + GroupNode sqltypes.Node +} + +func ACLGroupMatcher(fieldReference sqltypes.VariableMatcher, structSQL string, structPath []string) ACLGroupVar { + return ACLGroupVar{StructSQL: structSQL, StructPath: structPath, FieldReference: fieldReference} +} + +func (ACLGroupVar) UseAs() sqltypes.Node { return ACLGroupVar{} } + +func (g ACLGroupVar) ConvertVariable(rego ast.Ref) (sqltypes.Node, bool) { + // "left" will be a map of group names to actions in rego. + // { + // "all_users": ["read"] + // } + left, err := sqltypes.RegoVarPath(g.StructPath, rego) + if err != nil { + return nil, false + } + + aclGrp := ACLGroupVar{ + StructSQL: g.StructSQL, + StructPath: g.StructPath, + FieldReference: g.FieldReference, + + Source: sqltypes.RegoSource(rego.String()), + } + + // We expect 1 more term. Either a ref or a string. + if len(left) != 1 { + return nil, false + } + + // If the remaining is a variable, then we need to convert it. + // Assuming we support variable fields. + ref, ok := left[0].Value.(ast.Ref) + if ok && g.FieldReference != nil { + groupNode, ok := g.FieldReference.ConvertVariable(ref) + if ok { + aclGrp.GroupNode = groupNode + return aclGrp, true + } + } + + // If it is a string, we assume it is a literal + groupName, ok := left[0].Value.(ast.String) + if ok { + aclGrp.GroupNode = sqltypes.String(string(groupName)) + return aclGrp, true + } + + // If we have not matched it yet, then it is something we do not recognize. + return nil, false +} + +func (g ACLGroupVar) SQLString(cfg *sqltypes.SQLGenerator) string { + return fmt.Sprintf("%s->%s", g.StructSQL, g.GroupNode.SQLString(cfg)) +} + +func (g ACLGroupVar) ContainsSQL(cfg *sqltypes.SQLGenerator, other sqltypes.Node) (string, error) { + switch other.UseAs().(type) { + // Only supports containing other strings. + case sqltypes.AstString: + return fmt.Sprintf("%s ? %s", g.SQLString(cfg), other.SQLString(cfg)), nil + default: + return "", xerrors.Errorf("unsupported acl group contains %T", other) + } +} diff --git a/coderd/rbac/regosql/compile.go b/coderd/rbac/regosql/compile.go new file mode 100644 index 0000000000..599348b0e4 --- /dev/null +++ b/coderd/rbac/regosql/compile.go @@ -0,0 +1,230 @@ +package regosql + +import ( + "encoding/json" + "strings" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/rbac/regosql/sqltypes" +) + +// ConvertConfig is required to generate SQL from the rego queries. +type ConvertConfig struct { + // VariableConverter is called each time a var is encountered. This creates + // the SQL ast for the variable. Without this, the SQL generator does not + // know how to convert rego variables into SQL columns. + VariableConverter sqltypes.VariableMatcher +} + +// ConvertRegoAst converts partial rego queries into a single SQL where +// clause. If the query equates to "true" then the user should have access. +func ConvertRegoAst(cfg ConvertConfig, partial *rego.PartialQueries) (sqltypes.BooleanNode, error) { + if len(partial.Queries) == 0 { + // Always deny if there are no queries. This means there is no possible + // way this user can access these resources. + return sqltypes.Bool(false), nil + } + + for _, q := range partial.Queries { + // An empty query in rego means "true". If any query in the set is + // empty, then the user should have access. + if len(q) == 0 { + // Always allow + return sqltypes.Bool(true), nil + } + } + + var queries []sqltypes.BooleanNode + var builder strings.Builder + for i, q := range partial.Queries { + converted, err := convertQuery(cfg, q) + if err != nil { + return nil, xerrors.Errorf("query %s: %w", q.String(), err) + } + + if i != 0 { + builder.WriteString("\n") + } + builder.WriteString(q.String()) + queries = append(queries, converted) + } + + // All queries are OR'd together. This means that if any query is true, + // then the user should have access. + sqlClause := sqltypes.Or(sqltypes.RegoSource(builder.String()), queries...) + // Always wrap in parens to ensure the correct precedence when combining with other + // SQL clauses. + return sqltypes.BoolParenthesis(sqlClause), nil +} + +func convertQuery(cfg ConvertConfig, q ast.Body) (sqltypes.BooleanNode, error) { + var expressions []sqltypes.BooleanNode + for _, e := range q { + exp, err := convertExpression(cfg, e) + if err != nil { + return nil, xerrors.Errorf("expression %s: %w", e.String(), err) + } + + expressions = append(expressions, exp) + } + + // All expressions in a single query are AND'd together. This means that + // all expressions must be true for the user to have access. + return sqltypes.And(sqltypes.RegoSource(q.String()), expressions...), nil +} + +func convertExpression(cfg ConvertConfig, e *ast.Expr) (sqltypes.BooleanNode, error) { + if e.IsCall() { + n, err := convertCall(cfg, e.Terms.([]*ast.Term)) + if err != nil { + return nil, xerrors.Errorf("call: %w", err) + } + + boolN, ok := n.(sqltypes.BooleanNode) + if !ok { + return nil, xerrors.Errorf("call %q: not a boolean expression", e.String()) + } + return boolN, nil + } + + // If it's not a call, it is a single term + if term, ok := e.Terms.(*ast.Term); ok { + ty, err := convertTerm(cfg, term) + if err != nil { + return nil, xerrors.Errorf("convert term %s: %w", term.String(), err) + } + + tyBool, ok := ty.(sqltypes.BooleanNode) + if !ok { + return nil, xerrors.Errorf("convert term %s is not a boolean: %w", term.String(), err) + } + + return tyBool, nil + } + + return nil, xerrors.Errorf("expression %s not supported", e.String()) +} + +// convertCall converts a function call to a SQL expression. +func convertCall(cfg ConvertConfig, call ast.Call) (sqltypes.Node, error) { + if len(call) == 0 { + return nil, xerrors.Errorf("empty call") + } + + // Operator is the first term + op := call[0] + var args []*ast.Term + if len(call) > 1 { + args = call[1:] + } + + opString := op.String() + // Supported operators. + switch op.String() { + case "neq", "eq", "equals", "equal": + args, err := convertTerms(cfg, args, 2) + if err != nil { + return nil, xerrors.Errorf("arguments: %w", err) + } + + not := false + if opString == "neq" || opString == "notequals" || opString == "notequal" { + not = true + } + + equality := sqltypes.Equality(not, args[0], args[1]) + return sqltypes.BoolParenthesis(equality), nil + case "internal.member_2": + args, err := convertTerms(cfg, args, 2) + if err != nil { + return nil, xerrors.Errorf("arguments: %w", err) + } + + member := sqltypes.MemberOf(args[0], args[1]) + return sqltypes.BoolParenthesis(member), nil + default: + return nil, xerrors.Errorf("operator %s not supported", op) + } +} + +func convertTerms(cfg ConvertConfig, terms []*ast.Term, expected int) ([]sqltypes.Node, error) { + if len(terms) != expected { + return nil, xerrors.Errorf("expected %d terms, got %d", expected, len(terms)) + } + + result := make([]sqltypes.Node, 0, len(terms)) + for _, t := range terms { + term, err := convertTerm(cfg, t) + if err != nil { + return nil, xerrors.Errorf("term: %w", err) + } + result = append(result, term) + } + + return result, nil +} + +func convertTerm(cfg ConvertConfig, term *ast.Term) (sqltypes.Node, error) { + source := sqltypes.RegoSource(term.String()) + switch t := term.Value.(type) { + case ast.Var: + // All vars should be contained in ast.Ref's. + return nil, xerrors.New("var not yet supported") + case ast.Ref: + if len(t) == 0 { + // A reference with no text is a variable with no name? + // This makes no sense. + return nil, xerrors.New("empty ref not supported") + } + + if cfg.VariableConverter == nil { + return nil, xerrors.New("no variable converter provided to handle variables") + } + + // The structure of references is as follows: + // 1. All variables start with a regoAst.Var as the first term. + // 2. The next term is either a regoAst.String or a regoAst.Var. + // - regoAst.String if a static field name or index. + // - regoAst.Var if the field reference is a variable itself. Such as + // the wildcard "[_]" + // 3. Repeat 1-2 until the end of the reference. + node, ok := cfg.VariableConverter.ConvertVariable(t) + if !ok { + return nil, xerrors.Errorf("variable %q cannot be converted", t.String()) + } + return node, nil + case ast.String: + return sqltypes.String(string(t)), nil + case ast.Number: + return sqltypes.Number(source, json.Number(t)), nil + case ast.Boolean: + return sqltypes.Bool(bool(t)), nil + case *ast.Array: + elems := make([]sqltypes.Node, 0, t.Len()) + for i := 0; i < t.Len(); i++ { + value, err := convertTerm(cfg, t.Elem(i)) + if err != nil { + return nil, xerrors.Errorf("array element %d in %q: %w", i, t.String(), err) + } + elems = append(elems, value) + } + return sqltypes.Array(source, elems...) + case ast.Object: + return nil, xerrors.New("object not yet supported") + case ast.Set: + // Just treat a set like an array for now. + arr := t.Sorted() + return convertTerm(cfg, &ast.Term{ + Value: arr, + Location: term.Location, + }) + case ast.Call: + // This is a function call + return convertCall(cfg, t) + default: + return nil, xerrors.Errorf("%T not yet supported", t) + } +} diff --git a/coderd/rbac/regosql/compile_test.go b/coderd/rbac/regosql/compile_test.go new file mode 100644 index 0000000000..0799816f2d --- /dev/null +++ b/coderd/rbac/regosql/compile_test.go @@ -0,0 +1,307 @@ +package regosql_test + +import ( + "testing" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/rbac/regosql" + "github.com/coder/coder/coderd/rbac/regosql/sqltypes" +) + +// TestRegoQueriesNoVariables handles cases without variables. These should be +// very simple and straight forward. +func TestRegoQueries(t *testing.T) { + t.Parallel() + + p := func(v string) string { + return "(" + v + ")" + } + + testCases := []struct { + Name string + Queries []string + ExpectedSQL string + ExpectError bool + ExpectedSQLGenError bool + + VariableConverter sqltypes.VariableMatcher + }{ + { + Name: "Empty", + Queries: []string{``}, + ExpectedSQL: "true", + }, + { + Name: "True", + Queries: []string{`true`}, + ExpectedSQL: "true", + }, + { + Name: "False", + Queries: []string{`false`}, + ExpectedSQL: "false", + }, + { + Name: "MultipleBool", + Queries: []string{"true", "false"}, + ExpectedSQL: "(true OR false)", + }, + { + Name: "Numbers", + Queries: []string{ + "(1 != 2) = true", + "5 == 5", + }, + ExpectedSQL: p("((1 != 2) = true) OR (5 = 5)"), + }, + // Variables + { + // Always return a constant string for all variables. + Name: "V_Basic", + Queries: []string{ + `input.x = "hello_world"`, + }, + ExpectedSQL: p("only_var = 'hello_world'"), + VariableConverter: sqltypes.NewVariableConverter().RegisterMatcher( + sqltypes.StringVarMatcher("only_var", []string{ + "input", "x", + }), + ), + }, + // Coder Variables + { + // Always return a constant string for all variables. + Name: "GroupACL", + Queries: []string{ + `"read" in input.object.acl_group_list.allUsers`, + }, + ExpectedSQL: "(group_acl->'allUsers' ? 'read')", + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "GroupWildcard", + Queries: []string{`"*" in input.object.acl_group_list.allUsers`}, + ExpectedSQL: "(group_acl->'allUsers' ? '*')", + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + // Always return a constant string for all variables. + Name: "GroupACLWithVarField", + Queries: []string{ + `"read" in input.object.acl_group_list[input.object.org_owner]`, + }, + ExpectedSQL: "(group_acl->organization_id :: text ? 'read')", + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "VarInArray", + Queries: []string{ + `input.object.org_owner in {"a", "b", "c"}`, + }, + ExpectedSQL: p("organization_id :: text = ANY(ARRAY ['a','b','c'])"), + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "SetDereference", + Queries: []string{`"*" in input.object.acl_group_list[input.object.org_owner]`}, + ExpectedSQL: p("group_acl->organization_id :: text ? '*'"), + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "JsonbLiteralDereference", + Queries: []string{`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`}, + ExpectedSQL: p("group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'"), + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "Complex", + Queries: []string{ + `input.object.org_owner != ""`, + `input.object.org_owner in {"a", "b", "c"}`, + `input.object.org_owner != ""`, + `"read" in input.object.acl_group_list.allUsers`, + `"read" in input.object.acl_user_list.me`, + }, + ExpectedSQL: `((organization_id :: text != '') OR ` + + `(organization_id :: text = ANY(ARRAY ['a','b','c'])) OR ` + + `(organization_id :: text != '') OR ` + + `(group_acl->'allUsers' ? 'read') OR ` + + `(user_acl->'me' ? 'read'))`, + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "NoACLs", + Queries: []string{ + `"read" in input.object.acl_group_list[input.object.org_owner]`, + `"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`, + }, + // Special case where the bool is wrapped + ExpectedSQL: p("(false) OR (false)"), + VariableConverter: regosql.NoACLConverter(), + }, + { + Name: "TwoExpressions", + Queries: []string{ + `true; true`, + }, + ExpectedSQL: p("true AND true"), + VariableConverter: regosql.DefaultVariableConverter(), + }, + + // Actual vectors from production + { + Name: "FromOwner", + Queries: []string{ + ``, + `"05f58202-4bfc-43ce-9ba4-5ff6e0174a71" = input.object.org_owner`, + `"read" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`, + }, + ExpectedSQL: "true", + VariableConverter: regosql.NoACLConverter(), + }, + { + Name: "OrgAdmin", + Queries: []string{ + `input.object.org_owner != ""; + input.object.org_owner in {"05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}; + input.object.owner != ""; + "d5389ccc-57a4-4b13-8c3f-31747bcdc9f1" = input.object.owner`, + }, + ExpectedSQL: "((organization_id :: text != '') AND " + + "(organization_id :: text = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'])) AND " + + "(owner_id :: text != '') AND " + + "('d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' = owner_id :: text))", + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "UserACLAllow", + Queries: []string{ + `"read" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`, + `"*" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`, + }, + ExpectedSQL: "((user_acl->'d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' ? 'read') OR " + + "(user_acl->'d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' ? '*'))", + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "NoACLConfig", + Queries: []string{ + `input.object.org_owner != ""; + input.object.org_owner in {"05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}; + "read" in input.object.acl_group_list[input.object.org_owner]`, + }, + ExpectedSQL: "((organization_id :: text != '') AND (organization_id :: text = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'])) AND (false))", + VariableConverter: regosql.NoACLConverter(), + }, + { + Name: "EmptyACLListNoACLs", + Queries: []string{ + `input.object.org_owner != ""; + input.object.org_owner in set(); + "create" in input.object.acl_group_list[input.object.org_owner]`, + + `input.object.org_owner != ""; + input.object.org_owner in set(); + "*" in input.object.acl_group_list[input.object.org_owner]`, + + `"create" in input.object.acl_user_list.me`, + + `"*" in input.object.acl_user_list.me`, + }, + ExpectedSQL: p(p("(organization_id :: text != '') AND (false) AND (group_acl->organization_id :: text ? 'create')") + " OR " + + p("(organization_id :: text != '') AND (false) AND (group_acl->organization_id :: text ? '*')") + " OR " + + p("user_acl->'me' ? 'create'") + " OR " + + p("user_acl->'me' ? '*'")), + VariableConverter: regosql.DefaultVariableConverter(), + }, + { + Name: "TemplateOwner", + Queries: []string{ + `neq(input.object.org_owner, ""); +internal.member_2(input.object.org_owner, {"3bf82434-e40b-44ae-b3d8-d0115bba9bad", "5630fda3-26ab-462c-9014-a88a62d7a415", "c304877a-bc0d-4e9b-9623-a38eae412929"}); +neq(input.object.owner, ""); +"806dd721-775f-4c85-9ce3-63fbbd975954" = input.object.owner`, + }, + ExpectedSQL: p(p("organization_id :: text != ''") + " AND " + + p("organization_id :: text = ANY(ARRAY ['3bf82434-e40b-44ae-b3d8-d0115bba9bad','5630fda3-26ab-462c-9014-a88a62d7a415','c304877a-bc0d-4e9b-9623-a38eae412929'])") + " AND " + + p("false") + " AND " + + p("false")), + VariableConverter: regosql.TemplateConverter(), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + part := partialQueries(tc.Queries...) + + cfg := regosql.ConvertConfig{ + VariableConverter: tc.VariableConverter, + } + + requireConvert(t, convertTestCase{ + part: part, + cfg: cfg, + expectSQL: tc.ExpectedSQL, + expectConvertError: tc.ExpectError, + expectSQLGenError: tc.ExpectedSQLGenError, + }) + }) + } +} + +type convertTestCase struct { + part *rego.PartialQueries + cfg regosql.ConvertConfig + + expectConvertError bool + expectSQL string + expectSQLGenError bool +} + +func requireConvert(t *testing.T, tc convertTestCase) { + t.Helper() + + for i, q := range tc.part.Queries { + t.Logf("Query %d: %s", i, q.String()) + } + for i, s := range tc.part.Support { + t.Logf("Support %d: %s", i, s.String()) + } + + root, err := regosql.ConvertRegoAst(tc.cfg, tc.part) + if tc.expectConvertError { + require.Error(t, err) + } else { + require.NoError(t, err, "compile") + + gen := sqltypes.NewSQLGenerator() + sqlString := root.SQLString(gen) + if tc.expectSQLGenError { + require.True(t, len(gen.Errors()) > 0, "expected SQL generation error") + } else { + require.NoError(t, err, "sql gen") + require.Equal(t, tc.expectSQL, sqlString, "sql match") + } + } +} + +func partialQueries(queries ...string) *rego.PartialQueries { + opts := ast.ParserOptions{ + AllFutureKeywords: true, + } + + astQueries := make([]ast.Body, 0, len(queries)) + for _, q := range queries { + astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts)) + } + + return ®o.PartialQueries{ + Queries: astQueries, + Support: []*ast.Module{}, + } +} diff --git a/coderd/rbac/regosql/configs.go b/coderd/rbac/regosql/configs.go new file mode 100644 index 0000000000..7064ceccb1 --- /dev/null +++ b/coderd/rbac/regosql/configs.go @@ -0,0 +1,60 @@ +package regosql + +import "github.com/coder/coder/coderd/rbac/regosql/sqltypes" + +func organizationOwnerMatcher() sqltypes.VariableMatcher { + return sqltypes.StringVarMatcher("organization_id :: text", []string{"input", "object", "org_owner"}) +} + +func userOwnerMatcher() sqltypes.VariableMatcher { + return sqltypes.StringVarMatcher("owner_id :: text", []string{"input", "object", "owner"}) +} + +func groupACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher { + return ACLGroupMatcher(m, "group_acl", []string{"input", "object", "acl_group_list"}) +} + +func userACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher { + return ACLGroupMatcher(m, "user_acl", []string{"input", "object", "acl_user_list"}) +} + +func TemplateConverter() *sqltypes.VariableConverter { + matcher := sqltypes.NewVariableConverter().RegisterMatcher( + organizationOwnerMatcher(), + // Templates have no user owner, only owner by an organization. + sqltypes.AlwaysFalse(userOwnerMatcher()), + ) + matcher.RegisterMatcher( + groupACLMatcher(matcher), + userACLMatcher(matcher), + ) + return matcher +} + +// NoACLConverter should be used when the target SQL table does not contain +// group or user ACL columns. +func NoACLConverter() *sqltypes.VariableConverter { + matcher := sqltypes.NewVariableConverter().RegisterMatcher( + organizationOwnerMatcher(), + userOwnerMatcher(), + ) + matcher.RegisterMatcher( + sqltypes.AlwaysFalse(groupACLMatcher(matcher)), + sqltypes.AlwaysFalse(userACLMatcher(matcher)), + ) + + return matcher +} + +func DefaultVariableConverter() *sqltypes.VariableConverter { + matcher := sqltypes.NewVariableConverter().RegisterMatcher( + organizationOwnerMatcher(), + userOwnerMatcher(), + ) + matcher.RegisterMatcher( + groupACLMatcher(matcher), + userACLMatcher(matcher), + ) + + return matcher +} diff --git a/coderd/rbac/regosql/doc.go b/coderd/rbac/regosql/doc.go new file mode 100644 index 0000000000..6be58573de --- /dev/null +++ b/coderd/rbac/regosql/doc.go @@ -0,0 +1,3 @@ +// Package regosql converts rego queries into SQL WHERE clauses. This is so +// the rego queries can be used to filter the results of a SQL query. +package regosql diff --git a/coderd/rbac/regosql/sqltypes/always_false.go b/coderd/rbac/regosql/sqltypes/always_false.go new file mode 100644 index 0000000000..7555805050 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/always_false.go @@ -0,0 +1,61 @@ +package sqltypes + +import ( + "github.com/open-policy-agent/opa/ast" +) + +var _ Node = alwaysFalse{} +var _ VariableMatcher = alwaysFalse{} + +type alwaysFalse struct { + Matcher VariableMatcher + + InnerNode Node +} + +// AlwaysFalse overrides the inner node with a constant "false". +func AlwaysFalse(m VariableMatcher) VariableMatcher { + return alwaysFalse{ + Matcher: m, + } +} + +// AlwaysFalseNode is mainly used for unit testing to make a Node immediately. +func AlwaysFalseNode(n Node) Node { + return alwaysFalse{ + InnerNode: n, + Matcher: nil, + } +} + +// UseAs uses a type no one supports to always override with false. +func (alwaysFalse) UseAs() Node { return alwaysFalse{} } +func (f alwaysFalse) ConvertVariable(rego ast.Ref) (Node, bool) { + if f.Matcher != nil { + n, ok := f.Matcher.ConvertVariable(rego) + if ok { + return alwaysFalse{ + Matcher: f.Matcher, + InnerNode: n, + }, true + } + } + + return nil, false +} + +func (alwaysFalse) SQLString(_ *SQLGenerator) string { + return "false" +} + +func (alwaysFalse) ContainsSQL(_ *SQLGenerator, _ Node) (string, error) { + return "false", nil +} + +func (alwaysFalse) ContainedInSQL(_ *SQLGenerator, _ Node) (string, error) { + return "false", nil +} + +func (alwaysFalse) EqualsSQLString(_ *SQLGenerator, _ bool, _ Node) (string, error) { + return "false", nil +} diff --git a/coderd/rbac/regosql/sqltypes/array.go b/coderd/rbac/regosql/sqltypes/array.go new file mode 100644 index 0000000000..8ea0138a1a --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/array.go @@ -0,0 +1,69 @@ +package sqltypes + +import ( + "fmt" + "reflect" + "strings" + + "golang.org/x/xerrors" +) + +type ASTArray struct { + Source RegoSource + Value []Node +} + +// Array is typed to whatever the first element is. If there is not first +// element, the array element type is invalid. +func Array(source RegoSource, nodes ...Node) (Node, error) { + for i := 1; i < len(nodes); i++ { + if reflect.TypeOf(nodes[0]) != reflect.TypeOf(nodes[i]) { + // Do not allow mixed types in arrays + return nil, xerrors.Errorf("array element %d in %q: type mismatch", i, source) + } + } + return ASTArray{Value: nodes, Source: source}, nil +} + +func (ASTArray) UseAs() Node { return ASTArray{} } + +func (a ASTArray) ContainsSQL(cfg *SQLGenerator, needle Node) (string, error) { + // If we have no elements in our set, then our needle is never in the set. + if len(a.Value) == 0 { + return "false", nil + } + + // This condition supports any contains function if the needle type is + // the same as the ASTArray element type. + if reflect.TypeOf(a.MyType().UseAs()) != reflect.TypeOf(needle.UseAs()) { + return "ArrayContainsError", xerrors.Errorf("array contains %q: type mismatch (%T, %T)", + a.Source, a.MyType(), needle) + } + + return fmt.Sprintf("%s = ANY(%s)", needle.SQLString(cfg), a.SQLString(cfg)), nil +} + +func (a ASTArray) SQLString(cfg *SQLGenerator) string { + switch a.MyType().UseAs().(type) { + case invalidNode: + cfg.AddError(xerrors.Errorf("array %q: empty array", a.Source)) + return "ArrayError" + case AstNumber, AstString, AstBoolean: + // Primitive types + values := make([]string, 0, len(a.Value)) + for _, v := range a.Value { + values = append(values, v.SQLString(cfg)) + } + return fmt.Sprintf("ARRAY [%s]", strings.Join(values, ",")) + } + + cfg.AddError(xerrors.Errorf("array %q: unsupported type %T", a.Source, a.MyType())) + return "ArrayError" +} + +func (a ASTArray) MyType() Node { + if len(a.Value) == 0 { + return invalidNode{} + } + return a.Value[0] +} diff --git a/coderd/rbac/regosql/sqltypes/binary.go b/coderd/rbac/regosql/sqltypes/binary.go new file mode 100644 index 0000000000..c53dcfb7d6 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/binary.go @@ -0,0 +1,77 @@ +package sqltypes + +import ( + "strings" + + "golang.org/x/xerrors" +) + +type binaryOperator int + +const ( + _ binaryOperator = iota + binaryOpOR + binaryOpAND +) + +type binaryOp struct { + source RegoSource + op binaryOperator + + Terms []BooleanNode +} + +func (binaryOp) UseAs() Node { return binaryOp{} } +func (binaryOp) IsBooleanNode() {} + +func Or(source RegoSource, terms ...BooleanNode) BooleanNode { + return newBinaryOp(source, binaryOpOR, terms...) +} + +func And(source RegoSource, terms ...BooleanNode) BooleanNode { + return newBinaryOp(source, binaryOpAND, terms...) +} + +func newBinaryOp(source RegoSource, op binaryOperator, terms ...BooleanNode) BooleanNode { + if len(terms) == 0 { + // TODO: How to handle 0 terms? + return Bool(false) + } + + opTerms := make([]BooleanNode, 0, len(terms)) + for i := range terms { + // Always wrap terms in parentheses to be safe. + opTerms = append(opTerms, BoolParenthesis(terms[i])) + } + + if len(opTerms) == 1 { + return opTerms[0] + } + + return binaryOp{ + Terms: opTerms, + op: op, + source: source, + } +} + +func (b binaryOp) SQLString(cfg *SQLGenerator) string { + sqlOp := "" + switch b.op { + case binaryOpOR: + sqlOp = "OR" + case binaryOpAND: + sqlOp = "AND" + default: + cfg.AddError(xerrors.Errorf("unsupported binary operator: %s (%d)", b.source, b.op)) + return "BinaryOpError" + } + + terms := make([]string, 0, len(b.Terms)) + for _, term := range b.Terms { + termSQL := term.SQLString(cfg) + terms = append(terms, termSQL) + } + + return strings.Join(terms, " "+sqlOp+" ") +} diff --git a/coderd/rbac/regosql/sqltypes/bool.go b/coderd/rbac/regosql/sqltypes/bool.go new file mode 100644 index 0000000000..691548c647 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/bool.go @@ -0,0 +1,26 @@ +package sqltypes + +import ( + "strconv" +) + +// AstBoolean is a literal true/false value. +type AstBoolean struct { + Source RegoSource + Value bool +} + +func Bool(t bool) BooleanNode { + return AstBoolean{Value: t, Source: RegoSource(strconv.FormatBool(t))} +} + +func (AstBoolean) IsBooleanNode() {} +func (AstBoolean) UseAs() Node { return AstBoolean{} } + +func (b AstBoolean) SQLString(_ *SQLGenerator) string { + return strconv.FormatBool(b.Value) +} + +func (b AstBoolean) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) { + return boolEqualsSQLString(cfg, b, not, other) +} diff --git a/coderd/rbac/regosql/sqltypes/doc.go b/coderd/rbac/regosql/sqltypes/doc.go new file mode 100644 index 0000000000..5aa38c57d3 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/doc.go @@ -0,0 +1,5 @@ +// Package sqltypes contains the types used to convert rego queries into SQL. +// The rego ast is converted into these types to better control the SQL +// generation. It allows writing the SQL generation for types in an easier to +// read way. +package sqltypes diff --git a/coderd/rbac/regosql/sqltypes/equality.go b/coderd/rbac/regosql/sqltypes/equality.go new file mode 100644 index 0000000000..84134123a8 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/equality.go @@ -0,0 +1,102 @@ +package sqltypes + +import ( + "fmt" + + "golang.org/x/xerrors" +) + +// SupportsEquality is an interface that can be implemented by types that +// support equality with other types. We defer to other types to implement this +// as it is much easier to implement this in the context of the type. +type SupportsEquality interface { + // EqualsSQLString intentionally returns an error. This is so if + // left = right is not supported, we can try right = left. + EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) +} + +var _ BooleanNode = equality{} +var _ Node = equality{} +var _ SupportsEquality = equality{} + +type equality struct { + Left Node + Right Node + + // Not just inverses the result of the comparison. We could implement this + // as a Not node wrapping the equality, but this is more efficient. + Not bool +} + +func Equality(notEquals bool, a, b Node) BooleanNode { + return equality{ + Left: a, + Right: b, + Not: notEquals, + } +} + +func (equality) IsBooleanNode() {} + +// UseAs returns an ASTBoolean as equalities resolve to boolean values +func (equality) UseAs() Node { return AstBoolean{} } + +func (e equality) SQLString(cfg *SQLGenerator) string { + // Equalities can be flipped without changing the result, so we can + // try both left = right and right = left. + if eq, ok := e.Left.(SupportsEquality); ok { + v, err := eq.EqualsSQLString(cfg, e.Not, e.Right) + if err == nil { + return v + } + } + + if eq, ok := e.Right.(SupportsEquality); ok { + v, err := eq.EqualsSQLString(cfg, e.Not, e.Left) + if err == nil { + return v + } + } + + cfg.AddError(xerrors.Errorf("unsupported equality: %T %s %T", e.Left, equalsOp(e.Not), e.Right)) + return "EqualityError" +} + +func (e equality) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) { + return boolEqualsSQLString(cfg, e, not, other) +} + +func boolEqualsSQLString(cfg *SQLGenerator, a BooleanNode, not bool, other Node) (string, error) { + switch other.UseAs().(type) { + case BooleanNode: + bn, ok := other.(BooleanNode) + if !ok { + return "", xerrors.Errorf("not a boolean node: %T", other) + } + + // Always wrap both sides in parens to ensure the correct precedence. + return fmt.Sprintf("%s %s %s", + BoolParenthesis(a).SQLString(cfg), + equalsOp(not), + BoolParenthesis(bn).SQLString(cfg), + ), nil + default: + return "", xerrors.Errorf("unsupported equality: %T %s %T", a, equalsOp(not), other) + } +} + +// nolint:revive +func equalsOp(not bool) string { + if not { + return "!=" + } + return "=" +} + +func basicSQLEquality(cfg *SQLGenerator, not bool, a, b Node) string { + return fmt.Sprintf("%s %s %s", + a.SQLString(cfg), + equalsOp(not), + b.SQLString(cfg), + ) +} diff --git a/coderd/rbac/regosql/sqltypes/equality_test.go b/coderd/rbac/regosql/sqltypes/equality_test.go new file mode 100644 index 0000000000..8764508ad8 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/equality_test.go @@ -0,0 +1,130 @@ +package sqltypes_test + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/rbac/regosql/sqltypes" +) + +func TestEquality(t *testing.T) { + t.Parallel() + testCases := []struct { + Name string + Equality sqltypes.Node + ExpectedSQL string + ExpectedErrors int + }{ + { + Name: "String=String", + Equality: sqltypes.Equality(false, + sqltypes.String("foo"), + sqltypes.String("bar"), + ), + ExpectedSQL: "'foo' = 'bar'", + }, + { + Name: "Number=Number", + Equality: sqltypes.Equality(false, + sqltypes.Number("", json.Number("5")), + sqltypes.Number("", json.Number("22")), + ), + ExpectedSQL: "5 = 22", + }, + { + Name: "Bool=Bool", + Equality: sqltypes.Equality(false, + sqltypes.Bool(true), + sqltypes.Bool(false), + ), + ExpectedSQL: "true = false", + }, + { + Name: "Bool=Equality", + Equality: sqltypes.Equality(false, + sqltypes.Bool(true), + sqltypes.Equality(true, + sqltypes.Equality(true, + sqltypes.String("foo"), + sqltypes.String("bar"), + ), + sqltypes.Bool(false), + ), + ), + ExpectedSQL: "true = (('foo' != 'bar') != false)", + }, + { + Name: "Equality=Equality", + Equality: sqltypes.Equality(false, + sqltypes.Equality(true, + sqltypes.Bool(true), + sqltypes.Bool(false), + ), + sqltypes.Equality(false, + sqltypes.String("foo"), + sqltypes.String("foo"), + ), + ), + ExpectedSQL: "(true != false) = ('foo' = 'foo')", + }, + { + Name: "Membership=Membership", + Equality: sqltypes.Equality(false, + sqltypes.Equality(true, + sqltypes.MemberOf( + sqltypes.String("foo"), + must(sqltypes.Array("", + sqltypes.String("foo"), + sqltypes.String("bar"), + )), + ), + sqltypes.Bool(false), + ), + sqltypes.Equality(false, + sqltypes.Bool(true), + sqltypes.MemberOf( + sqltypes.Number("", "2"), + must(sqltypes.Array("", + sqltypes.Number("", "5"), + sqltypes.Number("", "2"), + )), + ), + ), + ), + ExpectedSQL: "(('foo' = ANY(ARRAY ['foo','bar'])) != false) = (true = (2 = ANY(ARRAY [5,2])))", + }, + { + Name: "AlwaysFalse=String", + Equality: sqltypes.Equality(false, + sqltypes.AlwaysFalseNode(sqltypes.String("foo")), + sqltypes.String("foo"), + ), + ExpectedSQL: "false", + }, + { + Name: "String=AlwaysFalse", + Equality: sqltypes.Equality(false, + sqltypes.String("foo"), + sqltypes.AlwaysFalseNode(sqltypes.String("foo")), + ), + ExpectedSQL: "false", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + gen := sqltypes.NewSQLGenerator() + found := tc.Equality.SQLString(gen) + if tc.ExpectedErrors > 0 { + require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected AstNumber of errors") + } else { + require.Equal(t, tc.ExpectedSQL, found, "expected sql") + } + }) + } +} diff --git a/coderd/rbac/regosql/sqltypes/gen.go b/coderd/rbac/regosql/sqltypes/gen.go new file mode 100644 index 0000000000..3a06c98b8c --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/gen.go @@ -0,0 +1,19 @@ +package sqltypes + +type SQLGenerator struct { + errors []error +} + +func NewSQLGenerator() *SQLGenerator { + return &SQLGenerator{} +} + +func (g *SQLGenerator) AddError(err error) { + if err != nil { + g.errors = append(g.errors, err) + } +} + +func (g *SQLGenerator) Errors() []error { + return g.errors +} diff --git a/coderd/rbac/regosql/sqltypes/member.go b/coderd/rbac/regosql/sqltypes/member.go new file mode 100644 index 0000000000..7f022eb2a8 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/member.go @@ -0,0 +1,61 @@ +package sqltypes + +import ( + "golang.org/x/xerrors" +) + +// SupportsContains is an interface that can be implemented by types that +// support "me.Contains(other)". This is `internal_member2` in the rego. +type SupportsContains interface { + ContainsSQL(cfg *SQLGenerator, other Node) (string, error) +} + +// SupportsContainedIn is the inverse of SupportsContains. It is implemented +// from the "needle" rather than the haystack. +type SupportsContainedIn interface { + ContainedInSQL(cfg *SQLGenerator, other Node) (string, error) +} + +var _ BooleanNode = memberOf{} +var _ Node = memberOf{} +var _ SupportsEquality = memberOf{} + +type memberOf struct { + Needle Node + Haystack Node +} + +func MemberOf(needle, haystack Node) BooleanNode { + return memberOf{ + Needle: needle, + Haystack: haystack, + } +} + +func (memberOf) IsBooleanNode() {} +func (memberOf) UseAs() Node { return AstBoolean{} } + +func (e memberOf) SQLString(cfg *SQLGenerator) string { + // Equalities can be flipped without changing the result, so we can + // try both left = right and right = left. + if sc, ok := e.Haystack.(SupportsContains); ok { + v, err := sc.ContainsSQL(cfg, e.Needle) + if err == nil { + return v + } + } + + if sc, ok := e.Needle.(SupportsContainedIn); ok { + v, err := sc.ContainedInSQL(cfg, e.Haystack) + if err == nil { + return v + } + } + + cfg.AddError(xerrors.Errorf("unsupported contains: %T contains %T", e.Haystack, e.Needle)) + return "MemberOfError" +} + +func (e memberOf) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) { + return boolEqualsSQLString(cfg, e, not, other) +} diff --git a/coderd/rbac/regosql/sqltypes/member_test.go b/coderd/rbac/regosql/sqltypes/member_test.go new file mode 100644 index 0000000000..91259e286e --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/member_test.go @@ -0,0 +1,116 @@ +package sqltypes_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/rbac/regosql/sqltypes" +) + +func TestMembership(t *testing.T) { + t.Parallel() + testCases := []struct { + Name string + Membership sqltypes.Node + ExpectedSQL string + ExpectedErrors int + }{ + { + Name: "StringArray", + Membership: sqltypes.MemberOf( + sqltypes.String("foo"), + must(sqltypes.Array("", + sqltypes.String("bar"), + sqltypes.String("buzz"), + )), + ), + ExpectedSQL: "'foo' = ANY(ARRAY ['bar','buzz'])", + }, + { + Name: "NumberArray", + Membership: sqltypes.MemberOf( + sqltypes.Number("", "5"), + must(sqltypes.Array("", + sqltypes.Number("", "2"), + sqltypes.Number("", "5"), + )), + ), + ExpectedSQL: "5 = ANY(ARRAY [2,5])", + }, + { + Name: "BoolArray", + Membership: sqltypes.MemberOf( + sqltypes.Bool(true), + must(sqltypes.Array("", + sqltypes.Bool(false), + sqltypes.Bool(true), + )), + ), + ExpectedSQL: "true = ANY(ARRAY [false,true])", + }, + { + Name: "EmptyArray", + Membership: sqltypes.MemberOf( + sqltypes.Bool(true), + must(sqltypes.Array("")), + ), + ExpectedSQL: "false", + }, + { + Name: "AlwaysFalseMember", + Membership: sqltypes.MemberOf( + sqltypes.AlwaysFalseNode(sqltypes.Bool(true)), + must(sqltypes.Array("", + sqltypes.Bool(false), + sqltypes.Bool(true), + )), + ), + ExpectedSQL: "false", + }, + { + Name: "AlwaysFalseArray", + Membership: sqltypes.MemberOf( + sqltypes.Bool(true), + sqltypes.AlwaysFalseNode(must(sqltypes.Array("", + sqltypes.Bool(false), + sqltypes.Bool(true), + ))), + ), + ExpectedSQL: "false", + }, + + // Errors + { + Name: "Unsupported", + Membership: sqltypes.MemberOf( + sqltypes.Bool(true), + sqltypes.Bool(true), + ), + ExpectedErrors: 1, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + gen := sqltypes.NewSQLGenerator() + found := tc.Membership.SQLString(gen) + if tc.ExpectedErrors > 0 { + require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected some errors") + } else { + require.Equal(t, tc.ExpectedSQL, found, "expected sql") + require.Equal(t, tc.ExpectedErrors, 0, "expected no errors") + } + }) + } +} + +func must[V any](v V, err error) V { + if err != nil { + panic(err) + } + return v +} diff --git a/coderd/rbac/regosql/sqltypes/node.go b/coderd/rbac/regosql/sqltypes/node.go new file mode 100644 index 0000000000..9d78a71af0 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/node.go @@ -0,0 +1,40 @@ +package sqltypes + +import ( + "golang.org/x/xerrors" +) + +type Node interface { + SQLString(cfg *SQLGenerator) string + // UseAs is a helper function to allow a node to be used as a different + // Node in operators. For example, a variable is really just a "string", so + // having the Equality operator check for "String" or "StringVar" is just + // excessive. Instead, we can just have the variable implement this function. + UseAs() Node +} + +// BooleanNode is a node that returns a true/false when evaluated. +type BooleanNode interface { + Node + IsBooleanNode() +} + +type RegoSource string + +type invalidNode struct{} + +func (invalidNode) UseAs() Node { return invalidNode{} } + +func (invalidNode) SQLString(cfg *SQLGenerator) string { + cfg.AddError(xerrors.Errorf("invalid node called")) + return "invalid_type" +} + +func IsPrimitive(n Node) bool { + switch n.(type) { + case AstBoolean, AstString, AstNumber: + return true + } + + return false +} diff --git a/coderd/rbac/regosql/sqltypes/number.go b/coderd/rbac/regosql/sqltypes/number.go new file mode 100644 index 0000000000..fd0589c6db --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/number.go @@ -0,0 +1,36 @@ +package sqltypes + +import ( + "encoding/json" + + "golang.org/x/xerrors" +) + +type AstNumber struct { + Source RegoSource + // Value is intentionally vague as to if it's an integer or a float. + // This defers that decision to the user. Rego keeps all numbers in this + // type. If we were to source the type from something other than Rego, + // we might want to make a Float and Int type which keep the original + // precision. + Value json.Number +} + +func Number(source RegoSource, v json.Number) Node { + return AstNumber{Value: v, Source: source} +} + +func (AstNumber) UseAs() Node { return AstNumber{} } + +func (n AstNumber) SQLString(_ *SQLGenerator) string { + return n.Value.String() +} + +func (n AstNumber) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) { + switch other.UseAs().(type) { + case AstNumber: + return basicSQLEquality(cfg, not, n, other), nil + default: + return "", xerrors.Errorf("unsupported equality: %T %s %T", n, equalsOp(not), other) + } +} diff --git a/coderd/rbac/regosql/sqltypes/parens.go b/coderd/rbac/regosql/sqltypes/parens.go new file mode 100644 index 0000000000..82eb0225b7 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/parens.go @@ -0,0 +1,45 @@ +package sqltypes + +import ( + "golang.org/x/xerrors" +) + +type astParenthesis struct { + Value BooleanNode +} + +// BoolParenthesis wraps the given boolean node in parens. +// This is useful for grouping and avoiding ambiguity. This does not work for +// mathematical parenthesis to change order of operations. +func BoolParenthesis(value BooleanNode) BooleanNode { + // Wrapping primitives is useless. + if IsPrimitive(value) { + return value + } + + // Unwrap any existing parens. Do not add excess parens. + if p, ok := value.(astParenthesis); ok { + return BoolParenthesis(p.Value) + } + return astParenthesis{Value: value} +} + +func (astParenthesis) IsBooleanNode() {} +func (p astParenthesis) UseAs() Node { return p.Value.UseAs() } +func (p astParenthesis) SQLString(cfg *SQLGenerator) string { + return "(" + p.Value.SQLString(cfg) + ")" +} + +func (p astParenthesis) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) { + if supp, ok := p.Value.(SupportsEquality); ok { + return supp.EqualsSQLString(cfg, not, other) + } + return "", xerrors.Errorf("unsupported equality: %T %s %T", p.Value, equalsOp(not), other) +} + +func (p astParenthesis) ContainsSQL(cfg *SQLGenerator, other Node) (string, error) { + if supp, ok := p.Value.(SupportsContains); ok { + return supp.ContainsSQL(cfg, other) + } + return "", xerrors.Errorf("unsupported contains: %T %T", p.Value, other) +} diff --git a/coderd/rbac/regosql/sqltypes/string.go b/coderd/rbac/regosql/sqltypes/string.go new file mode 100644 index 0000000000..92060ef34d --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/string.go @@ -0,0 +1,29 @@ +package sqltypes + +import ( + "golang.org/x/xerrors" +) + +type AstString struct { + Source RegoSource + Value string +} + +func String(v string) Node { + return AstString{Value: v, Source: RegoSource(v)} +} + +func (AstString) UseAs() Node { return AstString{} } + +func (s AstString) SQLString(_ *SQLGenerator) string { + return "'" + s.Value + "'" +} + +func (s AstString) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) { + switch other.UseAs().(type) { + case AstString: + return basicSQLEquality(cfg, not, s, other), nil + default: + return "", xerrors.Errorf("unsupported equality: %T %s %T", s, equalsOp(not), other) + } +} diff --git a/coderd/rbac/regosql/sqltypes/variable.go b/coderd/rbac/regosql/sqltypes/variable.go new file mode 100644 index 0000000000..573dedb52f --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/variable.go @@ -0,0 +1,113 @@ +package sqltypes + +import ( + "golang.org/x/xerrors" + + "github.com/open-policy-agent/opa/ast" +) + +type VariableMatcher interface { + ConvertVariable(rego ast.Ref) (Node, bool) +} + +type VariableConverter struct { + converters []VariableMatcher +} + +func NewVariableConverter() *VariableConverter { + return &VariableConverter{} +} + +func (vc *VariableConverter) RegisterMatcher(m ...VariableMatcher) *VariableConverter { + vc.converters = append(vc.converters, m...) + // Returns the VariableConverter for easier instantiation + return vc +} + +func (vc *VariableConverter) ConvertVariable(rego ast.Ref) (Node, bool) { + for _, c := range vc.converters { + if n, ok := c.ConvertVariable(rego); ok { + return n, true + } + } + return nil, false +} + +// RegoVarPath will consume the following terms from the given rego Ref and +// return the remaining terms. If the path does not fully match, an error is +// returned. The first term must always be a Var. +func RegoVarPath(path []string, terms []*ast.Term) ([]*ast.Term, error) { + if len(terms) < len(path) { + return nil, xerrors.Errorf("path %s longer than rego path %s", path, terms) + } + + if len(terms) == 0 || len(path) == 0 { + return nil, xerrors.Errorf("path %s and rego path %s must not be empty", path, terms) + } + + varTerm, ok := terms[0].Value.(ast.Var) + if !ok { + return nil, xerrors.Errorf("expected var, got %T", terms[0]) + } + + if string(varTerm) != path[0] { + return nil, xerrors.Errorf("expected var %s, got %s", path[0], varTerm) + } + + for i := 1; i < len(path); i++ { + nextTerm, ok := terms[i].Value.(ast.String) + if !ok { + return nil, xerrors.Errorf("expected ast.string, got %T", terms[i]) + } + + if string(nextTerm) != path[i] { + return nil, xerrors.Errorf("expected string %s, got %s", path[i], nextTerm) + } + } + + return terms[len(path):], nil +} + +var _ VariableMatcher = astStringVar{} +var _ Node = astStringVar{} + +// astStringVar is any variable that represents a string. +type astStringVar struct { + Source RegoSource + FieldPath []string + ColumnString string +} + +func StringVarMatcher(sqlString string, regoPath []string) VariableMatcher { + return astStringVar{FieldPath: regoPath, ColumnString: sqlString} +} + +func (astStringVar) UseAs() Node { return AstString{} } + +// ConvertVariable will return a new astStringVar Node if the given rego Ref +// matches this astStringVar. +func (s astStringVar) ConvertVariable(rego ast.Ref) (Node, bool) { + left, err := RegoVarPath(s.FieldPath, rego) + if err == nil && len(left) == 0 { + return astStringVar{ + Source: RegoSource(rego.String()), + FieldPath: s.FieldPath, + ColumnString: s.ColumnString, + }, true + } + + return nil, false +} + +func (s astStringVar) SQLString(_ *SQLGenerator) string { + return s.ColumnString +} + +func (s astStringVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) { + switch other.UseAs().(type) { + case AstString: + return basicSQLEquality(cfg, not, s, other), nil + default: + return "", xerrors.Errorf("unsupported equality: %T %s %T", s, equalsOp(not), other) + } +} diff --git a/coderd/templates.go b/coderd/templates.go index 80dcac854c..1a97da0dd5 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -316,25 +316,27 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() organization := httpmw.OrganizationParam(r) - templates, err := api.Database.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{ - OrganizationID: organization.ID, - }) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } + + prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceTemplate.Type) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching templates in organization.", + Message: "Internal error preparing sql filter.", Detail: err.Error(), }) return } // Filter templates based on rbac permissions - templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates) + templates, err := api.Database.GetAuthorizedTemplates(ctx, database.GetTemplatesWithFilterParams{ + OrganizationID: organization.ID, + }, prepared) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching templates.", + Message: "Internal error fetching templates in organization.", Detail: err.Error(), }) return diff --git a/coderd/templates_test.go b/coderd/templates_test.go index c8d196b9d3..b581ebc2ff 100644 --- a/coderd/templates_test.go +++ b/coderd/templates_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/require" "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/agent" "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/coderdtest" diff --git a/coderd/workspaces.go b/coderd/workspaces.go index fe796f370c..e142a0cacd 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -118,7 +118,8 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { filter.OwnerUsername = "" } - sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type) + // Workspaces do not have ACL columns. + prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error preparing sql filter.", @@ -127,7 +128,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) { return } - workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter) + workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, prepared) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspaces.", diff --git a/enterprise/coderd/groups.go b/enterprise/coderd/groups.go index e467ba5b06..ec566c8569 100644 --- a/enterprise/coderd/groups.go +++ b/enterprise/coderd/groups.go @@ -31,7 +31,7 @@ func (api *API) postGroupByOrganization(rw http.ResponseWriter, r *http.Request) ) defer commitAudit() - if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceGroup) { + if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceGroup.InOrg(org.ID)) { http.NotFound(rw, r) return } diff --git a/enterprise/coderd/templates_test.go b/enterprise/coderd/templates_test.go index ee8003c553..f257766b0b 100644 --- a/enterprise/coderd/templates_test.go +++ b/enterprise/coderd/templates_test.go @@ -11,7 +11,9 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" + "github.com/coder/coder/cryptorand" "github.com/coder/coder/enterprise/coderd/coderdenttest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/testutil" @@ -747,3 +749,210 @@ func TestUpdateTemplateACL(t *testing.T) { require.Equal(t, http.StatusNotFound, cerr.StatusCode()) }) } + +// TestTemplateAccess tests the rego -> sql conversion. We need to implement +// this test on at least 1 table type to ensure that the conversion is correct. +// The rbac tests only assert against static SQL queries. +// This is a full rbac test of many of the common role combinations. +// +//nolint:tparallel +func TestTemplateAccess(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + t.Cleanup(cancel) + + ownerClient := coderdenttest.New(t, nil) + owner := coderdtest.CreateFirstUser(t, ownerClient) + _ = coderdenttest.AddLicense(t, ownerClient, coderdenttest.LicenseOptions{ + TemplateRBAC: true, + }) + + type coderUser struct { + *codersdk.Client + User codersdk.User + } + + type orgSetup struct { + Admin coderUser + MemberInGroup coderUser + MemberNoGroup coderUser + + DefaultTemplate codersdk.Template + AllRead codersdk.Template + UserACL codersdk.Template + GroupACL codersdk.Template + + Group codersdk.Group + Org codersdk.Organization + } + + // Create the following users + // - owner: Site wide owner + // - template-admin + // - org-admin (org 1) + // - org-admin (org 2) + // - org-member (org 1) + // - org-member (org 2) + + // Create the following templates in each org + // - template 1, default acls + // - template 2, all_user read + // - template 3, user_acl read for member + // - template 4, group_acl read for groupMember + + templateAdmin := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin()) + + makeTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, acl codersdk.UpdateTemplateACL) codersdk.Template { + version := coderdtest.CreateTemplateVersion(t, client, orgID, nil) + template := coderdtest.CreateTemplate(t, client, orgID, version.ID) + + err := client.UpdateTemplateACL(ctx, template.ID, acl) + require.NoError(t, err, "failed to update template acl") + + return template + } + + makeOrg := func(t *testing.T) orgSetup { + // Make org + orgName, err := cryptorand.String(5) + require.NoError(t, err, "org name") + + // Make users + newOrg, err := ownerClient.CreateOrganization(ctx, codersdk.CreateOrganizationRequest{Name: orgName}) + require.NoError(t, err, "failed to create org") + + adminCli, adminUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgAdmin(newOrg.ID)) + groupMemCli, groupMemUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgMember(newOrg.ID)) + memberCli, memberUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgMember(newOrg.ID)) + + // Make group + group, err := adminCli.CreateGroup(ctx, newOrg.ID, codersdk.CreateGroupRequest{ + Name: "SingleUser", + }) + require.NoError(t, err, "failed to create group") + + group, err = adminCli.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{ + AddUsers: []string{groupMemUsr.ID.String()}, + }) + require.NoError(t, err, "failed to add user to group") + + // Make templates + + return orgSetup{ + Admin: coderUser{Client: adminCli, User: adminUsr}, + MemberInGroup: coderUser{Client: groupMemCli, User: groupMemUsr}, + MemberNoGroup: coderUser{Client: memberCli, User: memberUsr}, + Org: newOrg, + Group: group, + + DefaultTemplate: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{ + GroupPerms: map[string]codersdk.TemplateRole{ + newOrg.ID.String(): codersdk.TemplateRoleDeleted, + }, + }), + AllRead: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{ + GroupPerms: map[string]codersdk.TemplateRole{ + newOrg.ID.String(): codersdk.TemplateRoleUse, + }, + }), + UserACL: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{ + GroupPerms: map[string]codersdk.TemplateRole{ + newOrg.ID.String(): codersdk.TemplateRoleDeleted, + }, + UserPerms: map[string]codersdk.TemplateRole{ + memberUsr.ID.String(): codersdk.TemplateRoleUse, + }, + }), + GroupACL: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{ + GroupPerms: map[string]codersdk.TemplateRole{ + group.ID.String(): codersdk.TemplateRoleUse, + newOrg.ID.String(): codersdk.TemplateRoleDeleted, + }, + }), + } + } + + // Make 2 organizations + orgs := []orgSetup{ + makeOrg(t), + makeOrg(t), + } + + testTemplateRead := func(t *testing.T, org orgSetup, usr *codersdk.Client, read []codersdk.Template) { + found, err := usr.TemplatesByOrganization(ctx, org.Org.ID) + require.NoError(t, err, "failed to get templates") + + exp := make(map[uuid.UUID]codersdk.Template) + for _, tmpl := range read { + exp[tmpl.ID] = tmpl + } + + for _, f := range found { + if _, ok := exp[f.ID]; !ok { + t.Errorf("found unexpected template %q", f.Name) + } + delete(exp, f.ID) + } + require.Len(t, exp, 0, "expected templates not found") + } + + // nolint:paralleltest + t.Run("OwnerReadAll", func(t *testing.T) { + for _, o := range orgs { + // Owners can read all templates in all orgs + exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL} + testTemplateRead(t, o, ownerClient, exp) + } + }) + + // nolint:paralleltest + t.Run("TemplateAdminReadAll", func(t *testing.T) { + for _, o := range orgs { + // Template Admins can read all templates in all orgs + exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL} + testTemplateRead(t, o, templateAdmin, exp) + } + }) + + // nolint:paralleltest + t.Run("OrgAdminReadAllTheirs", func(t *testing.T) { + for i, o := range orgs { + cli := o.Admin.Client + // Only read their own org + exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL} + testTemplateRead(t, o, cli, exp) + + other := orgs[(i+1)%len(orgs)] + require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs") + testTemplateRead(t, other, cli, []codersdk.Template{}) + } + }) + + // nolint:paralleltest + t.Run("TestMemberNoGroup", func(t *testing.T) { + for i, o := range orgs { + cli := o.MemberNoGroup.Client + // Only read their own org + exp := []codersdk.Template{o.AllRead, o.UserACL} + testTemplateRead(t, o, cli, exp) + + other := orgs[(i+1)%len(orgs)] + require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs") + testTemplateRead(t, other, cli, []codersdk.Template{}) + } + }) + + // nolint:paralleltest + t.Run("TestMemberInGroup", func(t *testing.T) { + for i, o := range orgs { + cli := o.MemberInGroup.Client + // Only read their own org + exp := []codersdk.Template{o.AllRead, o.GroupACL} + testTemplateRead(t, o, cli, exp) + + other := orgs[(i+1)%len(orgs)] + require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs") + testTemplateRead(t, other, cli, []codersdk.Template{}) + } + }) +}