chore: Rewrite rbac rego -> SQL clause (#5138)

* chore: Rewrite rbac rego -> SQL clause

Previous code was challenging to read with edge cases
- bug: OrgAdmin could not make new groups
- Also refactor some function names
This commit is contained in:
Steven Masley 2022-11-28 12:12:34 -06:00 committed by GitHub
parent d5ab4fdeb8
commit ab9298f382
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 2080 additions and 828 deletions

View File

@ -5,7 +5,6 @@ import (
"net/http"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
@ -95,19 +94,14 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
// from postgres are already authorized, and the caller does not need to
// call 'Authorize()' on the returned objects.
// Note the authorization is only for the given action and object type.
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) {
roles := httpmw.UserAuthorization(r)
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType)
if err != nil {
return nil, xerrors.Errorf("prepare filter: %w", err)
}
filter, err := prepared.Compile()
if err != nil {
return nil, xerrors.Errorf("compile filter: %w", err)
}
return filter, nil
return prepared, nil
}
// checkAuthorization returns if the current API key can use the given

View File

@ -9,6 +9,8 @@ import (
"strings"
"testing"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -16,12 +18,19 @@ import (
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
)
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
// For any route using SQL filters, we need to know if the database is an
// in memory fake. This is because the in memory fake does not use SQL, and
// still uses rego. So this boolean indicates how to assert the expected
// behavior.
_, isMemoryDB := a.api.Database.(databasefake.FakeDatabase)
// Some quick reused objects
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
@ -125,11 +134,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,
},
"GET:/api/v2/organizations/{organization}/templates": {
StatusCode: http.StatusOK,
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID),
},
"POST:/api/v2/organizations/{organization}/templates": {
AssertAction: rbac.ActionCreate,
AssertObject: rbac.ResourceTemplate.InOrg(a.Organization.ID),
@ -240,7 +244,18 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"GET:/api/v2/organizations/{organization}/templateversions/{templateversionname}": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
// Endpoints that use the SQLQuery filter.
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
"GET:/api/v2/workspaces/": {
StatusCode: http.StatusOK,
NoAuthorize: !isMemoryDB,
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceWorkspace,
},
"GET:/api/v2/organizations/{organization}/templates": {
StatusCode: http.StatusOK,
NoAuthorize: !isMemoryDB,
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceTemplate,
},
}
// Routes like proxy routes support all HTTP methods. A helper func to expand
@ -549,10 +564,10 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object)
}
// Compile returns a compiled version of the authorizer that will work for
// CompileToSQL returns a compiled version of the authorizer that will work for
// in memory databases. This fake version will not work against a SQL database.
func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
return f, nil
func (fakePreparedAuthorizer) CompileToSQL(_ regosql.ConvertConfig) (string, error) {
return "", xerrors.New("not implemented")
}
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
@ -565,10 +580,3 @@ func (f fakePreparedAuthorizer) RegoString() string {
}
panic("not implemented")
}
func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string {
if f.HardCodedSQLString != "" {
return f.HardCodedSQLString
}
panic("not implemented")
}

View File

@ -20,6 +20,13 @@ import (
"github.com/coder/coder/coderd/util/slice"
)
// FakeDatabase is helpful for knowing if the underlying db is an in memory fake
// database. This is only in the databasefake package, so will only be used
// by unit tests.
type FakeDatabase interface {
IsFakeDB()
}
var errDuplicateKey = &pq.Error{
Code: "23505",
Message: "duplicate key value violates unique constraint",
@ -117,6 +124,7 @@ type data struct {
lastLicenseID int32
}
func (fakeQuerier) IsFakeDB() {}
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
return 0, nil
}
@ -488,11 +496,20 @@ func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.Get
return count, err
}
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
users := append([]database.User{}, q.users...)
users := make([]database.User, 0, len(q.users))
for _, user := range q.users {
// If the filter exists, ensure the object is authorized.
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
continue
}
users = append(users, user)
}
if params.Deleted {
tmp := make([]database.User, 0, len(users))
@ -539,13 +556,6 @@ func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.
users = usersFilteredByRole
}
for _, user := range q.workspaces {
// If the filter exists, ensure the object is authorized.
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
continue
}
}
return int64(len(users)), nil
}
@ -750,7 +760,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa
}
//nolint:gocyclo
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.GetWorkspacesRow, error) {
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -923,7 +933,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
}
// If the filter exists, ensure the object is authorized.
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
continue
}
workspaces = append(workspaces, workspace)
@ -1505,12 +1515,20 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd
return database.Template{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
return q.GetAuthorizedTemplates(ctx, arg, nil)
}
func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var templates []database.Template
for _, template := range q.templates {
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
continue
}
if template.Deleted != arg.Deleted {
continue
}

View File

@ -74,6 +74,7 @@ func TestExactMethods(t *testing.T) {
extraFakeMethods := map[string]string{
// Example
// "SortFakeLists": "Helper function used",
"IsFakeDB": "Helper function used for unit testing",
}
fake := reflect.TypeOf(databasefake.New())

View File

@ -5,12 +5,16 @@ import (
"fmt"
"strings"
"github.com/google/uuid"
"github.com/lib/pq"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/rbac/regosql"
)
"github.com/google/uuid"
"golang.org/x/xerrors"
const (
authorizedQueryPlaceholder = "-- @authorize_filter"
)
// customQuerier encompasses all non-generated queries.
@ -23,10 +27,70 @@ type customQuerier interface {
}
type templateQuerier interface {
GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error)
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
}
func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) {
authorizedFilter, err := prepared.CompileToSQL(regosql.ConvertConfig{
VariableConverter: regosql.TemplateConverter(),
})
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(getTemplatesWithFilter, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedTemplates :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.Deleted,
arg.OrganizationID,
arg.ExactName,
pq.Array(arg.IDs),
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OrganizationID,
&i.Deleted,
&i.Name,
&i.Provisioner,
&i.ActiveVersionID,
&i.Description,
&i.DefaultTTL,
&i.CreatedBy,
&i.Icon,
&i.UserACL,
&i.GroupACL,
&i.DisplayName,
&i.AllowUserCancelWorkspaceJobs,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
type TemplateUser struct {
User
Actions Actions `db:"actions"`
@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
}
type workspaceQuerier interface {
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error)
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error)
}
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
// clause.
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) {
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) {
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
// authorizedFilter between the end of the where clause and those statements.
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
filtered, err := insertAuthorizedFilter(getWorkspaces, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter)
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.Deleted,
arg.Status,
@ -172,12 +245,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
}
type userQuerier interface {
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
}
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
if err != nil {
return -1, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return -1, xerrors.Errorf("insert authorized filter: %w", err)
}
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
row := q.db.QueryRowContext(ctx, query,
arg.Deleted,
arg.Search,
@ -185,6 +267,14 @@ func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFiltered
pq.Array(arg.RbacRole),
)
var count int64
err := row.Scan(&count)
err = row.Scan(&count)
return count, err
}
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
if !strings.Contains(query, authorizedQueryPlaceholder) {
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
}
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
return filtered, nil
}

View File

@ -0,0 +1,15 @@
package database
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsAuthorizedQuery(t *testing.T) {
t.Parallel()
query := `SELECT true;`
_, err := insertAuthorizedFilter(query, "")
require.ErrorContains(t, err, "does not contain authorized replace string", "ensure replace string")
}

View File

@ -3197,6 +3197,8 @@ WHERE
id = ANY($4)
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
-- @authorize_filter
ORDER BY (name, id) ASC
`

View File

@ -34,6 +34,8 @@ WHERE
id = ANY(@ids)
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
-- @authorize_filter
ORDER BY (name, id) ASC
;

View File

@ -10,6 +10,7 @@ import (
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/coderd/tracing"
)
@ -20,7 +21,7 @@ type Authorizer interface {
type PreparedAuthorized interface {
Authorize(ctx context.Context, object Object) error
Compile() (AuthorizeFilter, error)
CompileToSQL(cfg regosql.ConvertConfig) (string, error)
}
// Filter takes in a list of objects, and will filter the list removing all

View File

@ -847,7 +847,7 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
// Ensure the partial can compile to a SQL clause.
// This does not guarantee that the clause is valid SQL.
_, err = Compile(partialAuthz)
_, err = Compile(ConfigWithACL(), partialAuthz)
require.NoError(t, err, "compile prepared authorizer")
// Also check the rego policy can form a valid partial query result.

View File

@ -3,11 +3,11 @@ package rbac
import (
"context"
"golang.org/x/xerrors"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/coderd/tracing"
)
@ -28,12 +28,12 @@ type PartialAuthorizer struct {
var _ PreparedAuthorized = (*PartialAuthorizer)(nil)
func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) {
filter, err := Compile(pa)
func (pa *PartialAuthorizer) CompileToSQL(cfg regosql.ConvertConfig) (string, error) {
filter, err := Compile(cfg, pa)
if err != nil {
return nil, xerrors.Errorf("compile: %w", err)
return "", xerrors.Errorf("compile: %w", err)
}
return filter, nil
return filter.SQLString(), nil
}
func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error {

1
coderd/rbac/partial.rego Normal file
View File

@ -0,0 +1 @@
package partial

View File

@ -2,635 +2,63 @@ package rbac
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
"golang.org/x/xerrors"
)
type TermType string
const (
VarTypeJsonbTextArray TermType = "jsonb-text-array"
VarTypeText TermType = "text"
VarTypeBoolean TermType = "boolean"
// VarTypeSkip means this variable does not exist to use.
VarTypeSkip TermType = "skip"
)
type SQLColumn struct {
// RegoMatch matches the original variable string.
// If it is a match, then this variable config will apply.
RegoMatch *regexp.Regexp
// ColumnSelect is the name of the postgres column to select.
// Can use capture groups from RegoMatch with $1, $2, etc.
ColumnSelect string
// Type indicates the postgres type of the column. Some expressions will
// need to know this in order to determine what SQL to produce.
// An example is if the variable is a jsonb array, the "contains" SQL
// query is `variable ? 'value'` instead of `'value' = ANY(variable)`.
// This type is only needed to be provided
Type TermType
}
type SQLConfig struct {
// Variables is a map of rego variable names to SQL columns.
// Example:
// "input\.object\.org_owner": SQLColumn{
// ColumnSelect: "organization_id",
// Type: VarTypeUUID
// }
// "input\.object\.owner": SQLColumn{
// ColumnSelect: "owner_id",
// Type: VarTypeUUID
// }
// "input\.object\.group_acl\.(.*)": SQLColumn{
// ColumnSelect: "group_acl->$1",
// Type: VarTypeJsonbTextArray
// }
Variables []SQLColumn
}
func DefaultConfig() SQLConfig {
return SQLConfig{
Variables: []SQLColumn{
{
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
ColumnSelect: "group_acl->$1",
Type: VarTypeJsonbTextArray,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
ColumnSelect: "user_acl->$1",
Type: VarTypeJsonbTextArray,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
ColumnSelect: "organization_id :: text",
Type: VarTypeText,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
ColumnSelect: "owner_id :: text",
Type: VarTypeText,
},
},
}
}
func NoACLConfig() SQLConfig {
return SQLConfig{
Variables: []SQLColumn{
{
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
ColumnSelect: "",
Type: VarTypeSkip,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
ColumnSelect: "",
Type: VarTypeSkip,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
ColumnSelect: "organization_id :: text",
Type: VarTypeText,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
ColumnSelect: "owner_id :: text",
Type: VarTypeText,
},
},
}
}
type AuthorizeFilter interface {
Expression
// Eval is required for the fake in memory database to work. The in memory
// database can use this function to filter the results.
Eval(object Object) bool
SQLString() string
}
// expressionTop handles Eval(object Object) for in memory expressions
type expressionTop struct {
Expression
Auth *PartialAuthorizer
type authorizedSQLFilter struct {
sqlString string
auth *PartialAuthorizer
}
func (e expressionTop) Eval(object Object) bool {
return e.Auth.Authorize(context.Background(), object) == nil
func ConfigWithACL() regosql.ConvertConfig {
return regosql.ConvertConfig{
VariableConverter: regosql.DefaultVariableConverter(),
}
}
// Compile will convert a rego query AST into our custom types. The output is
// an AST that can be used to generate SQL.
func Compile(pa *PartialAuthorizer) (AuthorizeFilter, error) {
partialQueries := pa.partialQueries
if len(partialQueries.Support) > 0 {
return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support))
func ConfigWithoutACL() regosql.ConvertConfig {
return regosql.ConvertConfig{
VariableConverter: regosql.NoACLConverter(),
}
}
func Compile(cfg regosql.ConvertConfig, pa *PartialAuthorizer) (AuthorizeFilter, error) {
root, err := regosql.ConvertRegoAst(cfg, pa.partialQueries)
if err != nil {
return nil, xerrors.Errorf("convert rego ast: %w", err)
}
// 0 queries means the result is "undefined". This is the same as "false".
if len(partialQueries.Queries) == 0 {
return &termBoolean{
base: base{Rego: "false"},
Value: false,
}, nil
// Generate the SQL
gen := sqltypes.NewSQLGenerator()
sqlString := root.SQLString(gen)
if len(gen.Errors()) > 0 {
var errStrings []string
for _, err := range gen.Errors() {
errStrings = append(errStrings, err.Error())
}
return nil, xerrors.Errorf("sql generation errors: %v", strings.Join(errStrings, ", "))
}
// Abort early if any of the "OR"'d expressions are the empty string.
// This is the same as "true".
for _, query := range partialQueries.Queries {
if query.String() == "" {
return &termBoolean{
base: base{Rego: "true"},
Value: true,
}, nil
}
}
result := make([]Expression, 0, len(partialQueries.Queries))
var builder strings.Builder
for i := range partialQueries.Queries {
query, err := processQuery(partialQueries.Queries[i])
if err != nil {
return nil, err
}
result = append(result, query)
if i != 0 {
builder.WriteString("\n")
}
builder.WriteString(partialQueries.Queries[i].String())
}
exp := expOr{
base: base{
Rego: builder.String(),
},
Expressions: result,
}
return expressionTop{
Expression: &exp,
Auth: pa,
return &authorizedSQLFilter{
sqlString: sqlString,
auth: pa,
}, nil
}
// processQuery processes an entire set of expressions and joins them with
// "AND".
func processQuery(query ast.Body) (Expression, error) {
expressions := make([]Expression, 0, len(query))
for _, astExpr := range query {
expr, err := processExpression(astExpr)
if err != nil {
return nil, err
}
expressions = append(expressions, expr)
}
return expAnd{
base: base{
Rego: query.String(),
},
Expressions: expressions,
}, nil
func (a *authorizedSQLFilter) Eval(object Object) bool {
return a.auth.Authorize(context.Background(), object) == nil
}
func processExpression(expr *ast.Expr) (Expression, error) {
if !expr.IsCall() {
// This could be a single term that is a valid expression.
if term, ok := expr.Terms.(*ast.Term); ok {
value, err := processTerm(term)
if err != nil {
return nil, xerrors.Errorf("single term expression: %w", err)
}
if boolExp, ok := value.(Expression); ok {
return boolExp, nil
}
// Default to error.
}
return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported")
}
op := expr.Operator().String()
base := base{Rego: op}
switch op {
case "neq", "eq", "equal":
terms, err := processTerms(2, expr.Operands())
if err != nil {
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
}
return &opEqual{
base: base,
Terms: [2]Term{terms[0], terms[1]},
Not: op == "neq",
}, nil
case "internal.member_2":
terms, err := processTerms(2, expr.Operands())
if err != nil {
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
}
return &opInternalMember2{
base: base,
Needle: terms[0],
Haystack: terms[1],
}, nil
default:
return nil, xerrors.Errorf("invalid expression: operator %s not supported", op)
}
}
func processTerms(expected int, terms []*ast.Term) ([]Term, error) {
if len(terms) != expected {
return nil, xerrors.Errorf("too many arguments, expect %d found %d", expected, len(terms))
}
result := make([]Term, 0, len(terms))
for _, term := range terms {
processed, err := processTerm(term)
if err != nil {
return nil, xerrors.Errorf("invalid term: %w", err)
}
result = append(result, processed)
}
return result, nil
}
func processTerm(term *ast.Term) (Term, error) {
termBase := base{Rego: term.String()}
switch v := term.Value.(type) {
case ast.Boolean:
return &termBoolean{
base: termBase,
Value: bool(v),
}, nil
case ast.Ref:
obj := &termObject{
base: termBase,
Path: []Term{},
}
var idx int
// A ref is a set of terms. If the first term is a var, then the
// following terms are the path to the value.
isRef := true
var builder strings.Builder
for _, term := range v {
if idx == 0 {
if _, ok := v[0].Value.(ast.Var); !ok {
return nil, xerrors.Errorf("invalid term (%s): ref must start with a var, started with %T", v[0].String(), v[0])
}
}
_, newRef := term.Value.(ast.Ref)
if newRef ||
// This is an unfortunate hack. To fix this, we need to rewrite
// our SQL config as a path ([]string{"input", "object", "acl_group"}).
// In the rego AST, there is no difference between selecting
// a field by a variable, and selecting a field by a literal (string).
// This was a misunderstanding.
// Example (these are equivalent by AST):
// input.object.acl_group_list['4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75']
// input.object.acl_group_list.organization_id
//
// This is not equivalent
// input.object.acl_group_list[input.object.organization_id]
//
// If this becomes even more hairy, we should fix the sql config.
builder.String() == "input.object.acl_group_list" ||
builder.String() == "input.object.acl_user_list" {
if !newRef {
isRef = false
}
// New obj
obj.Path = append(obj.Path, termVariable{
base: base{
Rego: builder.String(),
},
Name: builder.String(),
})
builder.Reset()
idx = 0
}
if builder.Len() != 0 {
builder.WriteString(".")
}
builder.WriteString(trimQuotes(term.String()))
idx++
}
if isRef {
obj.Path = append(obj.Path, termVariable{
base: base{
Rego: builder.String(),
},
Name: builder.String(),
})
} else {
obj.Path = append(obj.Path, termString{
base: base{
Rego: fmt.Sprintf("%q", builder.String()),
},
Value: builder.String(),
})
}
return obj, nil
case ast.Var:
return &termVariable{
Name: trimQuotes(v.String()),
base: termBase,
}, nil
case ast.String:
return &termString{
Value: trimQuotes(v.String()),
base: termBase,
}, nil
case ast.Set:
slice := v.Slice()
set := make([]Term, 0, len(slice))
for _, elem := range slice {
processed, err := processTerm(elem)
if err != nil {
return nil, xerrors.Errorf("invalid set term %s: %w", elem.String(), err)
}
set = append(set, processed)
}
return &termSet{
Value: set,
base: termBase,
}, nil
default:
return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String())
}
}
type base struct {
// Rego is the original rego string
Rego string
}
func (b base) RegoString() string {
return b.Rego
}
// Expression comprises a set of terms, operators, and functions. All
// expressions return a boolean value.
//
// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo"
type Expression interface {
// RegoString is used in debugging to see the original rego expression.
RegoString() string
// SQLString returns the SQL expression that can be used in a WHERE clause.
SQLString(cfg SQLConfig) string
}
type expAnd struct {
base
Expressions []Expression
}
func (t expAnd) SQLString(cfg SQLConfig) string {
if len(t.Expressions) == 1 {
return t.Expressions[0].SQLString(cfg)
}
exprs := make([]string, 0, len(t.Expressions))
for _, expr := range t.Expressions {
exprs = append(exprs, expr.SQLString(cfg))
}
return "(" + strings.Join(exprs, " AND ") + ")"
}
type expOr struct {
base
Expressions []Expression
}
func (t expOr) SQLString(cfg SQLConfig) string {
if len(t.Expressions) == 1 {
return t.Expressions[0].SQLString(cfg)
}
exprs := make([]string, 0, len(t.Expressions))
for _, expr := range t.Expressions {
exprs = append(exprs, expr.SQLString(cfg))
}
return "(" + strings.Join(exprs, " OR ") + ")"
}
// Operator joins terms together to form an expression.
// Operators are also expressions.
//
// Eg: "=", "neq", "internal.member_2", etc.
type Operator interface {
Expression
}
type opEqual struct {
base
Terms [2]Term
// For NotEqual
Not bool
}
func (t opEqual) SQLString(cfg SQLConfig) string {
op := "="
if t.Not {
op = "!="
}
return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg))
}
// opInternalMember2 is checking if the first term is a member of the second term.
// The second term is a set or list.
type opInternalMember2 struct {
base
Needle Term
Haystack Term
}
func (t opInternalMember2) SQLString(cfg SQLConfig) string {
if haystack, ok := t.Haystack.(*termObject); ok {
// This is a special case where the haystack is a jsonb array.
// There is a more general way to solve this, but that requires a lot
// more code to cover a lot more cases that we do not care about.
// To handle this more generally we should implement "Array" as a type.
// Then have the `contains` function on the Array type. This would defer
// knowing the element type to the Array and cover more cases without
// having to add more "if" branches here.
// But until we need more cases, our basic type system is ok, and
// this is the only case we need to handle.
sqlType := haystack.SQLType(cfg)
if sqlType == VarTypeJsonbTextArray {
return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg))
}
if sqlType == VarTypeSkip {
return "false"
}
}
return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg))
}
// Term is a single value in an expression. Terms can be variables or constants.
//
// Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner",
// "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}"
type Term interface {
RegoString() string
SQLString(cfg SQLConfig) string
SQLType(cfg SQLConfig) TermType
}
type termString struct {
base
Value string
}
func (t termString) SQLString(_ SQLConfig) string {
return "'" + t.Value + "'"
}
func (termString) SQLType(_ SQLConfig) TermType {
return VarTypeText
}
// termObject is a variable that can be dereferenced. We count some rego objects
// as single variables, eg: input.object.org_owner. In reality, it is a nested
// object.
// In rego, we can dereference the object with the "." operator, which we can
// handle with regex.
// Or we can dereference the object with the "[]", which we can handle with this
// term type.
type termObject struct {
base
Path []Term
}
func (t termObject) SQLType(cfg SQLConfig) TermType {
// Without a full type system, let's just assume the type of the first var
// is the resulting type. This is correct for our use case.
// Solving this more generally requires a full type system, which is
// excessive for our mostly static policy.
return t.Path[0].SQLType(cfg)
}
func (t termObject) SQLString(cfg SQLConfig) string {
if len(t.Path) == 1 {
return t.Path[0].SQLString(cfg)
}
// Combine the last 2 variables into 1 variable.
end := t.Path[len(t.Path)-1]
before := t.Path[len(t.Path)-2]
// Recursively solve the SQLString by removing the last nested reference.
// This continues until we have a single variable.
return termObject{
base: t.base,
Path: append(
t.Path[:len(t.Path)-2],
termVariable{
base: base{
Rego: before.RegoString() + "[" + end.RegoString() + "]",
},
// Convert the end to SQL string. We evaluate each term
// one at a time.
Name: before.RegoString() + "." + end.SQLString(cfg),
},
),
}.SQLString(cfg)
}
type termVariable struct {
base
Name string
}
func (t termVariable) SQLType(cfg SQLConfig) TermType {
if col := t.ColumnConfig(cfg); col != nil {
return col.Type
}
return VarTypeText
}
func (t termVariable) SQLString(cfg SQLConfig) string {
if col := t.ColumnConfig(cfg); col != nil {
matches := col.RegoMatch.FindStringSubmatch(t.Name)
if len(matches) > 0 {
// This config matches this variable.
replace := make([]string, 0, len(matches)*2)
for i, m := range matches {
replace = append(replace, fmt.Sprintf("$%d", i))
replace = append(replace, m)
}
replacer := strings.NewReplacer(replace...)
return replacer.Replace(col.ColumnSelect)
}
}
return t.Name
}
// ColumnConfig returns the correct SQLColumn settings for the
// term. If there is no configured column, it will return nil.
func (t termVariable) ColumnConfig(cfg SQLConfig) *SQLColumn {
for _, col := range cfg.Variables {
matches := col.RegoMatch.MatchString(t.Name)
if matches {
return &col
}
}
return nil
}
// termSet is a set of unique terms.
type termSet struct {
base
Value []Term
}
func (t termSet) SQLType(cfg SQLConfig) TermType {
if len(t.Value) == 0 {
return VarTypeText
}
// Without a full type system, let's just assume the type of the first var
// is the resulting type. This is correct for our use case.
// Solving this more generally requires a full type system, which is
// excessive for our mostly static policy.
return t.Value[0].SQLType(cfg)
}
func (t termSet) SQLString(cfg SQLConfig) string {
elems := make([]string, 0, len(t.Value))
for _, v := range t.Value {
elems = append(elems, v.SQLString(cfg))
}
return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ","))
}
type termBoolean struct {
base
Value bool
}
func (termBoolean) SQLType(SQLConfig) TermType {
return VarTypeBoolean
}
func (t termBoolean) Eval(_ Object) bool {
return t.Value
}
func (t termBoolean) SQLString(_ SQLConfig) string {
return strconv.FormatBool(t.Value)
}
func trimQuotes(s string) string {
return strings.Trim(s, "\"")
func (a *authorizedSQLFilter) SQLString() string {
return a.sqlString
}

View File

@ -1,150 +0,0 @@
package rbac
import (
"context"
"testing"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/stretchr/testify/require"
)
func TestCompileQuery(t *testing.T) {
t.Parallel()
t.Run("EmptyQuery", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t, ""))
require.NoError(t, err, "compile empty")
require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'")
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'")
})
t.Run("TrueQuery", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t, "true"))
require.NoError(t, err, "compile")
require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'")
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'")
})
t.Run("ACLIn", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t, `"*" in input.object.acl_group_list.allUsers`))
require.NoError(t, err, "compile")
require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member")
require.Equal(t, `group_acl->'allUsers' ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in")
})
t.Run("Complex", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t,
`input.object.org_owner != ""`,
`input.object.org_owner in {"a", "b", "c"}`,
`input.object.org_owner != ""`,
`"read" in input.object.acl_group_list.allUsers`,
`"read" in input.object.acl_user_list.me`,
))
require.NoError(t, err, "compile")
require.Equal(t, `(organization_id :: text != '' OR `+
`organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+
`organization_id :: text != '' OR `+
`group_acl->'allUsers' ? 'read' OR `+
`user_acl->'me' ? 'read')`,
expression.SQLString(DefaultConfig()), "complex")
})
t.Run("SetDereference", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t,
`"*" in input.object.acl_group_list[input.object.org_owner]`,
))
require.NoError(t, err, "compile")
require.Equal(t, `group_acl->organization_id :: text ? '*'`,
expression.SQLString(DefaultConfig()), "set dereference")
})
t.Run("JsonbLiteralDereference", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t,
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
))
require.NoError(t, err, "compile")
require.Equal(t, `group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'`,
expression.SQLString(DefaultConfig()), "literal dereference")
})
t.Run("NoACLColumns", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t,
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
))
require.NoError(t, err, "compile")
require.Equal(t, `false`,
expression.SQLString(NoACLConfig()), "literal dereference")
})
}
func TestEvalQuery(t *testing.T) {
t.Parallel()
t.Run("GroupACL", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t,
`"read" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
))
require.NoError(t, err, "compile")
result := expression.Eval(Object{
Owner: "not-me",
OrgID: "random",
Type: "workspace",
ACLUserList: map[string][]Action{},
ACLGroupList: map[string][]Action{
"4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75": {"read"},
},
})
require.True(t, result, "eval")
})
}
func partialQueries(t *testing.T, queries ...string) *PartialAuthorizer {
opts := ast.ParserOptions{
AllFutureKeywords: true,
}
astQueries := make([]ast.Body, 0, len(queries))
for _, q := range queries {
astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts))
}
prepareQueries := make([]rego.PreparedEvalQuery, 0, len(queries))
for _, q := range astQueries {
var prepped rego.PreparedEvalQuery
var err error
if q.String() == "" {
prepped, err = rego.New(
rego.Query("true"),
).PrepareForEval(context.Background())
} else {
prepped, err = rego.New(
rego.ParsedQuery(q),
).PrepareForEval(context.Background())
}
require.NoError(t, err, "prepare query")
prepareQueries = append(prepareQueries, prepped)
}
return &PartialAuthorizer{
partialQueries: &rego.PartialQueries{
Queries: astQueries,
Support: []*ast.Module{},
},
preparedQueries: prepareQueries,
input: nil,
alwaysTrue: false,
}
}

View File

@ -0,0 +1,102 @@
package regosql
import (
"fmt"
"golang.org/x/xerrors"
"github.com/open-policy-agent/opa/ast"
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
)
var _ sqltypes.VariableMatcher = ACLGroupVar{}
var _ sqltypes.Node = ACLGroupVar{}
// ACLGroupVar is a variable matcher that handles group_acl and user_acl.
// The sql type is a jsonb object with the following structure:
//
// "group_acl": {
// "<group_name>": ["<actions>"]
// }
//
// This is a custom variable matcher as json objects have arbitrary complexity.
type ACLGroupVar struct {
StructSQL string
// input.object.group_acl -> ["input", "object", "group_acl"]
StructPath []string
// FieldReference handles referencing the subfields, which could be
// more variables. We pass one in as the global one might not be correctly
// scoped.
FieldReference sqltypes.VariableMatcher
// Instance fields
Source sqltypes.RegoSource
GroupNode sqltypes.Node
}
func ACLGroupMatcher(fieldReference sqltypes.VariableMatcher, structSQL string, structPath []string) ACLGroupVar {
return ACLGroupVar{StructSQL: structSQL, StructPath: structPath, FieldReference: fieldReference}
}
func (ACLGroupVar) UseAs() sqltypes.Node { return ACLGroupVar{} }
func (g ACLGroupVar) ConvertVariable(rego ast.Ref) (sqltypes.Node, bool) {
// "left" will be a map of group names to actions in rego.
// {
// "all_users": ["read"]
// }
left, err := sqltypes.RegoVarPath(g.StructPath, rego)
if err != nil {
return nil, false
}
aclGrp := ACLGroupVar{
StructSQL: g.StructSQL,
StructPath: g.StructPath,
FieldReference: g.FieldReference,
Source: sqltypes.RegoSource(rego.String()),
}
// We expect 1 more term. Either a ref or a string.
if len(left) != 1 {
return nil, false
}
// If the remaining is a variable, then we need to convert it.
// Assuming we support variable fields.
ref, ok := left[0].Value.(ast.Ref)
if ok && g.FieldReference != nil {
groupNode, ok := g.FieldReference.ConvertVariable(ref)
if ok {
aclGrp.GroupNode = groupNode
return aclGrp, true
}
}
// If it is a string, we assume it is a literal
groupName, ok := left[0].Value.(ast.String)
if ok {
aclGrp.GroupNode = sqltypes.String(string(groupName))
return aclGrp, true
}
// If we have not matched it yet, then it is something we do not recognize.
return nil, false
}
func (g ACLGroupVar) SQLString(cfg *sqltypes.SQLGenerator) string {
return fmt.Sprintf("%s->%s", g.StructSQL, g.GroupNode.SQLString(cfg))
}
func (g ACLGroupVar) ContainsSQL(cfg *sqltypes.SQLGenerator, other sqltypes.Node) (string, error) {
switch other.UseAs().(type) {
// Only supports containing other strings.
case sqltypes.AstString:
return fmt.Sprintf("%s ? %s", g.SQLString(cfg), other.SQLString(cfg)), nil
default:
return "", xerrors.Errorf("unsupported acl group contains %T", other)
}
}

View File

@ -0,0 +1,230 @@
package regosql
import (
"encoding/json"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
)
// ConvertConfig is required to generate SQL from the rego queries.
type ConvertConfig struct {
// VariableConverter is called each time a var is encountered. This creates
// the SQL ast for the variable. Without this, the SQL generator does not
// know how to convert rego variables into SQL columns.
VariableConverter sqltypes.VariableMatcher
}
// ConvertRegoAst converts partial rego queries into a single SQL where
// clause. If the query equates to "true" then the user should have access.
func ConvertRegoAst(cfg ConvertConfig, partial *rego.PartialQueries) (sqltypes.BooleanNode, error) {
if len(partial.Queries) == 0 {
// Always deny if there are no queries. This means there is no possible
// way this user can access these resources.
return sqltypes.Bool(false), nil
}
for _, q := range partial.Queries {
// An empty query in rego means "true". If any query in the set is
// empty, then the user should have access.
if len(q) == 0 {
// Always allow
return sqltypes.Bool(true), nil
}
}
var queries []sqltypes.BooleanNode
var builder strings.Builder
for i, q := range partial.Queries {
converted, err := convertQuery(cfg, q)
if err != nil {
return nil, xerrors.Errorf("query %s: %w", q.String(), err)
}
if i != 0 {
builder.WriteString("\n")
}
builder.WriteString(q.String())
queries = append(queries, converted)
}
// All queries are OR'd together. This means that if any query is true,
// then the user should have access.
sqlClause := sqltypes.Or(sqltypes.RegoSource(builder.String()), queries...)
// Always wrap in parens to ensure the correct precedence when combining with other
// SQL clauses.
return sqltypes.BoolParenthesis(sqlClause), nil
}
func convertQuery(cfg ConvertConfig, q ast.Body) (sqltypes.BooleanNode, error) {
var expressions []sqltypes.BooleanNode
for _, e := range q {
exp, err := convertExpression(cfg, e)
if err != nil {
return nil, xerrors.Errorf("expression %s: %w", e.String(), err)
}
expressions = append(expressions, exp)
}
// All expressions in a single query are AND'd together. This means that
// all expressions must be true for the user to have access.
return sqltypes.And(sqltypes.RegoSource(q.String()), expressions...), nil
}
func convertExpression(cfg ConvertConfig, e *ast.Expr) (sqltypes.BooleanNode, error) {
if e.IsCall() {
n, err := convertCall(cfg, e.Terms.([]*ast.Term))
if err != nil {
return nil, xerrors.Errorf("call: %w", err)
}
boolN, ok := n.(sqltypes.BooleanNode)
if !ok {
return nil, xerrors.Errorf("call %q: not a boolean expression", e.String())
}
return boolN, nil
}
// If it's not a call, it is a single term
if term, ok := e.Terms.(*ast.Term); ok {
ty, err := convertTerm(cfg, term)
if err != nil {
return nil, xerrors.Errorf("convert term %s: %w", term.String(), err)
}
tyBool, ok := ty.(sqltypes.BooleanNode)
if !ok {
return nil, xerrors.Errorf("convert term %s is not a boolean: %w", term.String(), err)
}
return tyBool, nil
}
return nil, xerrors.Errorf("expression %s not supported", e.String())
}
// convertCall converts a function call to a SQL expression.
func convertCall(cfg ConvertConfig, call ast.Call) (sqltypes.Node, error) {
if len(call) == 0 {
return nil, xerrors.Errorf("empty call")
}
// Operator is the first term
op := call[0]
var args []*ast.Term
if len(call) > 1 {
args = call[1:]
}
opString := op.String()
// Supported operators.
switch op.String() {
case "neq", "eq", "equals", "equal":
args, err := convertTerms(cfg, args, 2)
if err != nil {
return nil, xerrors.Errorf("arguments: %w", err)
}
not := false
if opString == "neq" || opString == "notequals" || opString == "notequal" {
not = true
}
equality := sqltypes.Equality(not, args[0], args[1])
return sqltypes.BoolParenthesis(equality), nil
case "internal.member_2":
args, err := convertTerms(cfg, args, 2)
if err != nil {
return nil, xerrors.Errorf("arguments: %w", err)
}
member := sqltypes.MemberOf(args[0], args[1])
return sqltypes.BoolParenthesis(member), nil
default:
return nil, xerrors.Errorf("operator %s not supported", op)
}
}
func convertTerms(cfg ConvertConfig, terms []*ast.Term, expected int) ([]sqltypes.Node, error) {
if len(terms) != expected {
return nil, xerrors.Errorf("expected %d terms, got %d", expected, len(terms))
}
result := make([]sqltypes.Node, 0, len(terms))
for _, t := range terms {
term, err := convertTerm(cfg, t)
if err != nil {
return nil, xerrors.Errorf("term: %w", err)
}
result = append(result, term)
}
return result, nil
}
func convertTerm(cfg ConvertConfig, term *ast.Term) (sqltypes.Node, error) {
source := sqltypes.RegoSource(term.String())
switch t := term.Value.(type) {
case ast.Var:
// All vars should be contained in ast.Ref's.
return nil, xerrors.New("var not yet supported")
case ast.Ref:
if len(t) == 0 {
// A reference with no text is a variable with no name?
// This makes no sense.
return nil, xerrors.New("empty ref not supported")
}
if cfg.VariableConverter == nil {
return nil, xerrors.New("no variable converter provided to handle variables")
}
// The structure of references is as follows:
// 1. All variables start with a regoAst.Var as the first term.
// 2. The next term is either a regoAst.String or a regoAst.Var.
// - regoAst.String if a static field name or index.
// - regoAst.Var if the field reference is a variable itself. Such as
// the wildcard "[_]"
// 3. Repeat 1-2 until the end of the reference.
node, ok := cfg.VariableConverter.ConvertVariable(t)
if !ok {
return nil, xerrors.Errorf("variable %q cannot be converted", t.String())
}
return node, nil
case ast.String:
return sqltypes.String(string(t)), nil
case ast.Number:
return sqltypes.Number(source, json.Number(t)), nil
case ast.Boolean:
return sqltypes.Bool(bool(t)), nil
case *ast.Array:
elems := make([]sqltypes.Node, 0, t.Len())
for i := 0; i < t.Len(); i++ {
value, err := convertTerm(cfg, t.Elem(i))
if err != nil {
return nil, xerrors.Errorf("array element %d in %q: %w", i, t.String(), err)
}
elems = append(elems, value)
}
return sqltypes.Array(source, elems...)
case ast.Object:
return nil, xerrors.New("object not yet supported")
case ast.Set:
// Just treat a set like an array for now.
arr := t.Sorted()
return convertTerm(cfg, &ast.Term{
Value: arr,
Location: term.Location,
})
case ast.Call:
// This is a function call
return convertCall(cfg, t)
default:
return nil, xerrors.Errorf("%T not yet supported", t)
}
}

View File

@ -0,0 +1,307 @@
package regosql_test
import (
"testing"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
)
// TestRegoQueriesNoVariables handles cases without variables. These should be
// very simple and straight forward.
func TestRegoQueries(t *testing.T) {
t.Parallel()
p := func(v string) string {
return "(" + v + ")"
}
testCases := []struct {
Name string
Queries []string
ExpectedSQL string
ExpectError bool
ExpectedSQLGenError bool
VariableConverter sqltypes.VariableMatcher
}{
{
Name: "Empty",
Queries: []string{``},
ExpectedSQL: "true",
},
{
Name: "True",
Queries: []string{`true`},
ExpectedSQL: "true",
},
{
Name: "False",
Queries: []string{`false`},
ExpectedSQL: "false",
},
{
Name: "MultipleBool",
Queries: []string{"true", "false"},
ExpectedSQL: "(true OR false)",
},
{
Name: "Numbers",
Queries: []string{
"(1 != 2) = true",
"5 == 5",
},
ExpectedSQL: p("((1 != 2) = true) OR (5 = 5)"),
},
// Variables
{
// Always return a constant string for all variables.
Name: "V_Basic",
Queries: []string{
`input.x = "hello_world"`,
},
ExpectedSQL: p("only_var = 'hello_world'"),
VariableConverter: sqltypes.NewVariableConverter().RegisterMatcher(
sqltypes.StringVarMatcher("only_var", []string{
"input", "x",
}),
),
},
// Coder Variables
{
// Always return a constant string for all variables.
Name: "GroupACL",
Queries: []string{
`"read" in input.object.acl_group_list.allUsers`,
},
ExpectedSQL: "(group_acl->'allUsers' ? 'read')",
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "GroupWildcard",
Queries: []string{`"*" in input.object.acl_group_list.allUsers`},
ExpectedSQL: "(group_acl->'allUsers' ? '*')",
VariableConverter: regosql.DefaultVariableConverter(),
},
{
// Always return a constant string for all variables.
Name: "GroupACLWithVarField",
Queries: []string{
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: "(group_acl->organization_id :: text ? 'read')",
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "VarInArray",
Queries: []string{
`input.object.org_owner in {"a", "b", "c"}`,
},
ExpectedSQL: p("organization_id :: text = ANY(ARRAY ['a','b','c'])"),
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "SetDereference",
Queries: []string{`"*" in input.object.acl_group_list[input.object.org_owner]`},
ExpectedSQL: p("group_acl->organization_id :: text ? '*'"),
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "JsonbLiteralDereference",
Queries: []string{`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`},
ExpectedSQL: p("group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'"),
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "Complex",
Queries: []string{
`input.object.org_owner != ""`,
`input.object.org_owner in {"a", "b", "c"}`,
`input.object.org_owner != ""`,
`"read" in input.object.acl_group_list.allUsers`,
`"read" in input.object.acl_user_list.me`,
},
ExpectedSQL: `((organization_id :: text != '') OR ` +
`(organization_id :: text = ANY(ARRAY ['a','b','c'])) OR ` +
`(organization_id :: text != '') OR ` +
`(group_acl->'allUsers' ? 'read') OR ` +
`(user_acl->'me' ? 'read'))`,
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "NoACLs",
Queries: []string{
`"read" in input.object.acl_group_list[input.object.org_owner]`,
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
},
// Special case where the bool is wrapped
ExpectedSQL: p("(false) OR (false)"),
VariableConverter: regosql.NoACLConverter(),
},
{
Name: "TwoExpressions",
Queries: []string{
`true; true`,
},
ExpectedSQL: p("true AND true"),
VariableConverter: regosql.DefaultVariableConverter(),
},
// Actual vectors from production
{
Name: "FromOwner",
Queries: []string{
``,
`"05f58202-4bfc-43ce-9ba4-5ff6e0174a71" = input.object.org_owner`,
`"read" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
},
ExpectedSQL: "true",
VariableConverter: regosql.NoACLConverter(),
},
{
Name: "OrgAdmin",
Queries: []string{
`input.object.org_owner != "";
input.object.org_owner in {"05f58202-4bfc-43ce-9ba4-5ff6e0174a71"};
input.object.owner != "";
"d5389ccc-57a4-4b13-8c3f-31747bcdc9f1" = input.object.owner`,
},
ExpectedSQL: "((organization_id :: text != '') AND " +
"(organization_id :: text = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'])) AND " +
"(owner_id :: text != '') AND " +
"('d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' = owner_id :: text))",
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "UserACLAllow",
Queries: []string{
`"read" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
`"*" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
},
ExpectedSQL: "((user_acl->'d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' ? 'read') OR " +
"(user_acl->'d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' ? '*'))",
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "NoACLConfig",
Queries: []string{
`input.object.org_owner != "";
input.object.org_owner in {"05f58202-4bfc-43ce-9ba4-5ff6e0174a71"};
"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: "((organization_id :: text != '') AND (organization_id :: text = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'])) AND (false))",
VariableConverter: regosql.NoACLConverter(),
},
{
Name: "EmptyACLListNoACLs",
Queries: []string{
`input.object.org_owner != "";
input.object.org_owner in set();
"create" in input.object.acl_group_list[input.object.org_owner]`,
`input.object.org_owner != "";
input.object.org_owner in set();
"*" in input.object.acl_group_list[input.object.org_owner]`,
`"create" in input.object.acl_user_list.me`,
`"*" in input.object.acl_user_list.me`,
},
ExpectedSQL: p(p("(organization_id :: text != '') AND (false) AND (group_acl->organization_id :: text ? 'create')") + " OR " +
p("(organization_id :: text != '') AND (false) AND (group_acl->organization_id :: text ? '*')") + " OR " +
p("user_acl->'me' ? 'create'") + " OR " +
p("user_acl->'me' ? '*'")),
VariableConverter: regosql.DefaultVariableConverter(),
},
{
Name: "TemplateOwner",
Queries: []string{
`neq(input.object.org_owner, "");
internal.member_2(input.object.org_owner, {"3bf82434-e40b-44ae-b3d8-d0115bba9bad", "5630fda3-26ab-462c-9014-a88a62d7a415", "c304877a-bc0d-4e9b-9623-a38eae412929"});
neq(input.object.owner, "");
"806dd721-775f-4c85-9ce3-63fbbd975954" = input.object.owner`,
},
ExpectedSQL: p(p("organization_id :: text != ''") + " AND " +
p("organization_id :: text = ANY(ARRAY ['3bf82434-e40b-44ae-b3d8-d0115bba9bad','5630fda3-26ab-462c-9014-a88a62d7a415','c304877a-bc0d-4e9b-9623-a38eae412929'])") + " AND " +
p("false") + " AND " +
p("false")),
VariableConverter: regosql.TemplateConverter(),
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
part := partialQueries(tc.Queries...)
cfg := regosql.ConvertConfig{
VariableConverter: tc.VariableConverter,
}
requireConvert(t, convertTestCase{
part: part,
cfg: cfg,
expectSQL: tc.ExpectedSQL,
expectConvertError: tc.ExpectError,
expectSQLGenError: tc.ExpectedSQLGenError,
})
})
}
}
type convertTestCase struct {
part *rego.PartialQueries
cfg regosql.ConvertConfig
expectConvertError bool
expectSQL string
expectSQLGenError bool
}
func requireConvert(t *testing.T, tc convertTestCase) {
t.Helper()
for i, q := range tc.part.Queries {
t.Logf("Query %d: %s", i, q.String())
}
for i, s := range tc.part.Support {
t.Logf("Support %d: %s", i, s.String())
}
root, err := regosql.ConvertRegoAst(tc.cfg, tc.part)
if tc.expectConvertError {
require.Error(t, err)
} else {
require.NoError(t, err, "compile")
gen := sqltypes.NewSQLGenerator()
sqlString := root.SQLString(gen)
if tc.expectSQLGenError {
require.True(t, len(gen.Errors()) > 0, "expected SQL generation error")
} else {
require.NoError(t, err, "sql gen")
require.Equal(t, tc.expectSQL, sqlString, "sql match")
}
}
}
func partialQueries(queries ...string) *rego.PartialQueries {
opts := ast.ParserOptions{
AllFutureKeywords: true,
}
astQueries := make([]ast.Body, 0, len(queries))
for _, q := range queries {
astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts))
}
return &rego.PartialQueries{
Queries: astQueries,
Support: []*ast.Module{},
}
}

View File

@ -0,0 +1,60 @@
package regosql
import "github.com/coder/coder/coderd/rbac/regosql/sqltypes"
func organizationOwnerMatcher() sqltypes.VariableMatcher {
return sqltypes.StringVarMatcher("organization_id :: text", []string{"input", "object", "org_owner"})
}
func userOwnerMatcher() sqltypes.VariableMatcher {
return sqltypes.StringVarMatcher("owner_id :: text", []string{"input", "object", "owner"})
}
func groupACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher {
return ACLGroupMatcher(m, "group_acl", []string{"input", "object", "acl_group_list"})
}
func userACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher {
return ACLGroupMatcher(m, "user_acl", []string{"input", "object", "acl_user_list"})
}
func TemplateConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
organizationOwnerMatcher(),
// Templates have no user owner, only owner by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
matcher.RegisterMatcher(
groupACLMatcher(matcher),
userACLMatcher(matcher),
)
return matcher
}
// NoACLConverter should be used when the target SQL table does not contain
// group or user ACL columns.
func NoACLConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
organizationOwnerMatcher(),
userOwnerMatcher(),
)
matcher.RegisterMatcher(
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
)
return matcher
}
func DefaultVariableConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
organizationOwnerMatcher(),
userOwnerMatcher(),
)
matcher.RegisterMatcher(
groupACLMatcher(matcher),
userACLMatcher(matcher),
)
return matcher
}

View File

@ -0,0 +1,3 @@
// Package regosql converts rego queries into SQL WHERE clauses. This is so
// the rego queries can be used to filter the results of a SQL query.
package regosql

View File

@ -0,0 +1,61 @@
package sqltypes
import (
"github.com/open-policy-agent/opa/ast"
)
var _ Node = alwaysFalse{}
var _ VariableMatcher = alwaysFalse{}
type alwaysFalse struct {
Matcher VariableMatcher
InnerNode Node
}
// AlwaysFalse overrides the inner node with a constant "false".
func AlwaysFalse(m VariableMatcher) VariableMatcher {
return alwaysFalse{
Matcher: m,
}
}
// AlwaysFalseNode is mainly used for unit testing to make a Node immediately.
func AlwaysFalseNode(n Node) Node {
return alwaysFalse{
InnerNode: n,
Matcher: nil,
}
}
// UseAs uses a type no one supports to always override with false.
func (alwaysFalse) UseAs() Node { return alwaysFalse{} }
func (f alwaysFalse) ConvertVariable(rego ast.Ref) (Node, bool) {
if f.Matcher != nil {
n, ok := f.Matcher.ConvertVariable(rego)
if ok {
return alwaysFalse{
Matcher: f.Matcher,
InnerNode: n,
}, true
}
}
return nil, false
}
func (alwaysFalse) SQLString(_ *SQLGenerator) string {
return "false"
}
func (alwaysFalse) ContainsSQL(_ *SQLGenerator, _ Node) (string, error) {
return "false", nil
}
func (alwaysFalse) ContainedInSQL(_ *SQLGenerator, _ Node) (string, error) {
return "false", nil
}
func (alwaysFalse) EqualsSQLString(_ *SQLGenerator, _ bool, _ Node) (string, error) {
return "false", nil
}

View File

@ -0,0 +1,69 @@
package sqltypes
import (
"fmt"
"reflect"
"strings"
"golang.org/x/xerrors"
)
type ASTArray struct {
Source RegoSource
Value []Node
}
// Array is typed to whatever the first element is. If there is not first
// element, the array element type is invalid.
func Array(source RegoSource, nodes ...Node) (Node, error) {
for i := 1; i < len(nodes); i++ {
if reflect.TypeOf(nodes[0]) != reflect.TypeOf(nodes[i]) {
// Do not allow mixed types in arrays
return nil, xerrors.Errorf("array element %d in %q: type mismatch", i, source)
}
}
return ASTArray{Value: nodes, Source: source}, nil
}
func (ASTArray) UseAs() Node { return ASTArray{} }
func (a ASTArray) ContainsSQL(cfg *SQLGenerator, needle Node) (string, error) {
// If we have no elements in our set, then our needle is never in the set.
if len(a.Value) == 0 {
return "false", nil
}
// This condition supports any contains function if the needle type is
// the same as the ASTArray element type.
if reflect.TypeOf(a.MyType().UseAs()) != reflect.TypeOf(needle.UseAs()) {
return "ArrayContainsError", xerrors.Errorf("array contains %q: type mismatch (%T, %T)",
a.Source, a.MyType(), needle)
}
return fmt.Sprintf("%s = ANY(%s)", needle.SQLString(cfg), a.SQLString(cfg)), nil
}
func (a ASTArray) SQLString(cfg *SQLGenerator) string {
switch a.MyType().UseAs().(type) {
case invalidNode:
cfg.AddError(xerrors.Errorf("array %q: empty array", a.Source))
return "ArrayError"
case AstNumber, AstString, AstBoolean:
// Primitive types
values := make([]string, 0, len(a.Value))
for _, v := range a.Value {
values = append(values, v.SQLString(cfg))
}
return fmt.Sprintf("ARRAY [%s]", strings.Join(values, ","))
}
cfg.AddError(xerrors.Errorf("array %q: unsupported type %T", a.Source, a.MyType()))
return "ArrayError"
}
func (a ASTArray) MyType() Node {
if len(a.Value) == 0 {
return invalidNode{}
}
return a.Value[0]
}

View File

@ -0,0 +1,77 @@
package sqltypes
import (
"strings"
"golang.org/x/xerrors"
)
type binaryOperator int
const (
_ binaryOperator = iota
binaryOpOR
binaryOpAND
)
type binaryOp struct {
source RegoSource
op binaryOperator
Terms []BooleanNode
}
func (binaryOp) UseAs() Node { return binaryOp{} }
func (binaryOp) IsBooleanNode() {}
func Or(source RegoSource, terms ...BooleanNode) BooleanNode {
return newBinaryOp(source, binaryOpOR, terms...)
}
func And(source RegoSource, terms ...BooleanNode) BooleanNode {
return newBinaryOp(source, binaryOpAND, terms...)
}
func newBinaryOp(source RegoSource, op binaryOperator, terms ...BooleanNode) BooleanNode {
if len(terms) == 0 {
// TODO: How to handle 0 terms?
return Bool(false)
}
opTerms := make([]BooleanNode, 0, len(terms))
for i := range terms {
// Always wrap terms in parentheses to be safe.
opTerms = append(opTerms, BoolParenthesis(terms[i]))
}
if len(opTerms) == 1 {
return opTerms[0]
}
return binaryOp{
Terms: opTerms,
op: op,
source: source,
}
}
func (b binaryOp) SQLString(cfg *SQLGenerator) string {
sqlOp := ""
switch b.op {
case binaryOpOR:
sqlOp = "OR"
case binaryOpAND:
sqlOp = "AND"
default:
cfg.AddError(xerrors.Errorf("unsupported binary operator: %s (%d)", b.source, b.op))
return "BinaryOpError"
}
terms := make([]string, 0, len(b.Terms))
for _, term := range b.Terms {
termSQL := term.SQLString(cfg)
terms = append(terms, termSQL)
}
return strings.Join(terms, " "+sqlOp+" ")
}

View File

@ -0,0 +1,26 @@
package sqltypes
import (
"strconv"
)
// AstBoolean is a literal true/false value.
type AstBoolean struct {
Source RegoSource
Value bool
}
func Bool(t bool) BooleanNode {
return AstBoolean{Value: t, Source: RegoSource(strconv.FormatBool(t))}
}
func (AstBoolean) IsBooleanNode() {}
func (AstBoolean) UseAs() Node { return AstBoolean{} }
func (b AstBoolean) SQLString(_ *SQLGenerator) string {
return strconv.FormatBool(b.Value)
}
func (b AstBoolean) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
return boolEqualsSQLString(cfg, b, not, other)
}

View File

@ -0,0 +1,5 @@
// Package sqltypes contains the types used to convert rego queries into SQL.
// The rego ast is converted into these types to better control the SQL
// generation. It allows writing the SQL generation for types in an easier to
// read way.
package sqltypes

View File

@ -0,0 +1,102 @@
package sqltypes
import (
"fmt"
"golang.org/x/xerrors"
)
// SupportsEquality is an interface that can be implemented by types that
// support equality with other types. We defer to other types to implement this
// as it is much easier to implement this in the context of the type.
type SupportsEquality interface {
// EqualsSQLString intentionally returns an error. This is so if
// left = right is not supported, we can try right = left.
EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error)
}
var _ BooleanNode = equality{}
var _ Node = equality{}
var _ SupportsEquality = equality{}
type equality struct {
Left Node
Right Node
// Not just inverses the result of the comparison. We could implement this
// as a Not node wrapping the equality, but this is more efficient.
Not bool
}
func Equality(notEquals bool, a, b Node) BooleanNode {
return equality{
Left: a,
Right: b,
Not: notEquals,
}
}
func (equality) IsBooleanNode() {}
// UseAs returns an ASTBoolean as equalities resolve to boolean values
func (equality) UseAs() Node { return AstBoolean{} }
func (e equality) SQLString(cfg *SQLGenerator) string {
// Equalities can be flipped without changing the result, so we can
// try both left = right and right = left.
if eq, ok := e.Left.(SupportsEquality); ok {
v, err := eq.EqualsSQLString(cfg, e.Not, e.Right)
if err == nil {
return v
}
}
if eq, ok := e.Right.(SupportsEquality); ok {
v, err := eq.EqualsSQLString(cfg, e.Not, e.Left)
if err == nil {
return v
}
}
cfg.AddError(xerrors.Errorf("unsupported equality: %T %s %T", e.Left, equalsOp(e.Not), e.Right))
return "EqualityError"
}
func (e equality) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
return boolEqualsSQLString(cfg, e, not, other)
}
func boolEqualsSQLString(cfg *SQLGenerator, a BooleanNode, not bool, other Node) (string, error) {
switch other.UseAs().(type) {
case BooleanNode:
bn, ok := other.(BooleanNode)
if !ok {
return "", xerrors.Errorf("not a boolean node: %T", other)
}
// Always wrap both sides in parens to ensure the correct precedence.
return fmt.Sprintf("%s %s %s",
BoolParenthesis(a).SQLString(cfg),
equalsOp(not),
BoolParenthesis(bn).SQLString(cfg),
), nil
default:
return "", xerrors.Errorf("unsupported equality: %T %s %T", a, equalsOp(not), other)
}
}
// nolint:revive
func equalsOp(not bool) string {
if not {
return "!="
}
return "="
}
func basicSQLEquality(cfg *SQLGenerator, not bool, a, b Node) string {
return fmt.Sprintf("%s %s %s",
a.SQLString(cfg),
equalsOp(not),
b.SQLString(cfg),
)
}

View File

@ -0,0 +1,130 @@
package sqltypes_test
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
)
func TestEquality(t *testing.T) {
t.Parallel()
testCases := []struct {
Name string
Equality sqltypes.Node
ExpectedSQL string
ExpectedErrors int
}{
{
Name: "String=String",
Equality: sqltypes.Equality(false,
sqltypes.String("foo"),
sqltypes.String("bar"),
),
ExpectedSQL: "'foo' = 'bar'",
},
{
Name: "Number=Number",
Equality: sqltypes.Equality(false,
sqltypes.Number("", json.Number("5")),
sqltypes.Number("", json.Number("22")),
),
ExpectedSQL: "5 = 22",
},
{
Name: "Bool=Bool",
Equality: sqltypes.Equality(false,
sqltypes.Bool(true),
sqltypes.Bool(false),
),
ExpectedSQL: "true = false",
},
{
Name: "Bool=Equality",
Equality: sqltypes.Equality(false,
sqltypes.Bool(true),
sqltypes.Equality(true,
sqltypes.Equality(true,
sqltypes.String("foo"),
sqltypes.String("bar"),
),
sqltypes.Bool(false),
),
),
ExpectedSQL: "true = (('foo' != 'bar') != false)",
},
{
Name: "Equality=Equality",
Equality: sqltypes.Equality(false,
sqltypes.Equality(true,
sqltypes.Bool(true),
sqltypes.Bool(false),
),
sqltypes.Equality(false,
sqltypes.String("foo"),
sqltypes.String("foo"),
),
),
ExpectedSQL: "(true != false) = ('foo' = 'foo')",
},
{
Name: "Membership=Membership",
Equality: sqltypes.Equality(false,
sqltypes.Equality(true,
sqltypes.MemberOf(
sqltypes.String("foo"),
must(sqltypes.Array("",
sqltypes.String("foo"),
sqltypes.String("bar"),
)),
),
sqltypes.Bool(false),
),
sqltypes.Equality(false,
sqltypes.Bool(true),
sqltypes.MemberOf(
sqltypes.Number("", "2"),
must(sqltypes.Array("",
sqltypes.Number("", "5"),
sqltypes.Number("", "2"),
)),
),
),
),
ExpectedSQL: "(('foo' = ANY(ARRAY ['foo','bar'])) != false) = (true = (2 = ANY(ARRAY [5,2])))",
},
{
Name: "AlwaysFalse=String",
Equality: sqltypes.Equality(false,
sqltypes.AlwaysFalseNode(sqltypes.String("foo")),
sqltypes.String("foo"),
),
ExpectedSQL: "false",
},
{
Name: "String=AlwaysFalse",
Equality: sqltypes.Equality(false,
sqltypes.String("foo"),
sqltypes.AlwaysFalseNode(sqltypes.String("foo")),
),
ExpectedSQL: "false",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
gen := sqltypes.NewSQLGenerator()
found := tc.Equality.SQLString(gen)
if tc.ExpectedErrors > 0 {
require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected AstNumber of errors")
} else {
require.Equal(t, tc.ExpectedSQL, found, "expected sql")
}
})
}
}

View File

@ -0,0 +1,19 @@
package sqltypes
type SQLGenerator struct {
errors []error
}
func NewSQLGenerator() *SQLGenerator {
return &SQLGenerator{}
}
func (g *SQLGenerator) AddError(err error) {
if err != nil {
g.errors = append(g.errors, err)
}
}
func (g *SQLGenerator) Errors() []error {
return g.errors
}

View File

@ -0,0 +1,61 @@
package sqltypes
import (
"golang.org/x/xerrors"
)
// SupportsContains is an interface that can be implemented by types that
// support "me.Contains(other)". This is `internal_member2` in the rego.
type SupportsContains interface {
ContainsSQL(cfg *SQLGenerator, other Node) (string, error)
}
// SupportsContainedIn is the inverse of SupportsContains. It is implemented
// from the "needle" rather than the haystack.
type SupportsContainedIn interface {
ContainedInSQL(cfg *SQLGenerator, other Node) (string, error)
}
var _ BooleanNode = memberOf{}
var _ Node = memberOf{}
var _ SupportsEquality = memberOf{}
type memberOf struct {
Needle Node
Haystack Node
}
func MemberOf(needle, haystack Node) BooleanNode {
return memberOf{
Needle: needle,
Haystack: haystack,
}
}
func (memberOf) IsBooleanNode() {}
func (memberOf) UseAs() Node { return AstBoolean{} }
func (e memberOf) SQLString(cfg *SQLGenerator) string {
// Equalities can be flipped without changing the result, so we can
// try both left = right and right = left.
if sc, ok := e.Haystack.(SupportsContains); ok {
v, err := sc.ContainsSQL(cfg, e.Needle)
if err == nil {
return v
}
}
if sc, ok := e.Needle.(SupportsContainedIn); ok {
v, err := sc.ContainedInSQL(cfg, e.Haystack)
if err == nil {
return v
}
}
cfg.AddError(xerrors.Errorf("unsupported contains: %T contains %T", e.Haystack, e.Needle))
return "MemberOfError"
}
func (e memberOf) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
return boolEqualsSQLString(cfg, e, not, other)
}

View File

@ -0,0 +1,116 @@
package sqltypes_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
)
func TestMembership(t *testing.T) {
t.Parallel()
testCases := []struct {
Name string
Membership sqltypes.Node
ExpectedSQL string
ExpectedErrors int
}{
{
Name: "StringArray",
Membership: sqltypes.MemberOf(
sqltypes.String("foo"),
must(sqltypes.Array("",
sqltypes.String("bar"),
sqltypes.String("buzz"),
)),
),
ExpectedSQL: "'foo' = ANY(ARRAY ['bar','buzz'])",
},
{
Name: "NumberArray",
Membership: sqltypes.MemberOf(
sqltypes.Number("", "5"),
must(sqltypes.Array("",
sqltypes.Number("", "2"),
sqltypes.Number("", "5"),
)),
),
ExpectedSQL: "5 = ANY(ARRAY [2,5])",
},
{
Name: "BoolArray",
Membership: sqltypes.MemberOf(
sqltypes.Bool(true),
must(sqltypes.Array("",
sqltypes.Bool(false),
sqltypes.Bool(true),
)),
),
ExpectedSQL: "true = ANY(ARRAY [false,true])",
},
{
Name: "EmptyArray",
Membership: sqltypes.MemberOf(
sqltypes.Bool(true),
must(sqltypes.Array("")),
),
ExpectedSQL: "false",
},
{
Name: "AlwaysFalseMember",
Membership: sqltypes.MemberOf(
sqltypes.AlwaysFalseNode(sqltypes.Bool(true)),
must(sqltypes.Array("",
sqltypes.Bool(false),
sqltypes.Bool(true),
)),
),
ExpectedSQL: "false",
},
{
Name: "AlwaysFalseArray",
Membership: sqltypes.MemberOf(
sqltypes.Bool(true),
sqltypes.AlwaysFalseNode(must(sqltypes.Array("",
sqltypes.Bool(false),
sqltypes.Bool(true),
))),
),
ExpectedSQL: "false",
},
// Errors
{
Name: "Unsupported",
Membership: sqltypes.MemberOf(
sqltypes.Bool(true),
sqltypes.Bool(true),
),
ExpectedErrors: 1,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
gen := sqltypes.NewSQLGenerator()
found := tc.Membership.SQLString(gen)
if tc.ExpectedErrors > 0 {
require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected some errors")
} else {
require.Equal(t, tc.ExpectedSQL, found, "expected sql")
require.Equal(t, tc.ExpectedErrors, 0, "expected no errors")
}
})
}
}
func must[V any](v V, err error) V {
if err != nil {
panic(err)
}
return v
}

View File

@ -0,0 +1,40 @@
package sqltypes
import (
"golang.org/x/xerrors"
)
type Node interface {
SQLString(cfg *SQLGenerator) string
// UseAs is a helper function to allow a node to be used as a different
// Node in operators. For example, a variable is really just a "string", so
// having the Equality operator check for "String" or "StringVar" is just
// excessive. Instead, we can just have the variable implement this function.
UseAs() Node
}
// BooleanNode is a node that returns a true/false when evaluated.
type BooleanNode interface {
Node
IsBooleanNode()
}
type RegoSource string
type invalidNode struct{}
func (invalidNode) UseAs() Node { return invalidNode{} }
func (invalidNode) SQLString(cfg *SQLGenerator) string {
cfg.AddError(xerrors.Errorf("invalid node called"))
return "invalid_type"
}
func IsPrimitive(n Node) bool {
switch n.(type) {
case AstBoolean, AstString, AstNumber:
return true
}
return false
}

View File

@ -0,0 +1,36 @@
package sqltypes
import (
"encoding/json"
"golang.org/x/xerrors"
)
type AstNumber struct {
Source RegoSource
// Value is intentionally vague as to if it's an integer or a float.
// This defers that decision to the user. Rego keeps all numbers in this
// type. If we were to source the type from something other than Rego,
// we might want to make a Float and Int type which keep the original
// precision.
Value json.Number
}
func Number(source RegoSource, v json.Number) Node {
return AstNumber{Value: v, Source: source}
}
func (AstNumber) UseAs() Node { return AstNumber{} }
func (n AstNumber) SQLString(_ *SQLGenerator) string {
return n.Value.String()
}
func (n AstNumber) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
switch other.UseAs().(type) {
case AstNumber:
return basicSQLEquality(cfg, not, n, other), nil
default:
return "", xerrors.Errorf("unsupported equality: %T %s %T", n, equalsOp(not), other)
}
}

View File

@ -0,0 +1,45 @@
package sqltypes
import (
"golang.org/x/xerrors"
)
type astParenthesis struct {
Value BooleanNode
}
// BoolParenthesis wraps the given boolean node in parens.
// This is useful for grouping and avoiding ambiguity. This does not work for
// mathematical parenthesis to change order of operations.
func BoolParenthesis(value BooleanNode) BooleanNode {
// Wrapping primitives is useless.
if IsPrimitive(value) {
return value
}
// Unwrap any existing parens. Do not add excess parens.
if p, ok := value.(astParenthesis); ok {
return BoolParenthesis(p.Value)
}
return astParenthesis{Value: value}
}
func (astParenthesis) IsBooleanNode() {}
func (p astParenthesis) UseAs() Node { return p.Value.UseAs() }
func (p astParenthesis) SQLString(cfg *SQLGenerator) string {
return "(" + p.Value.SQLString(cfg) + ")"
}
func (p astParenthesis) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
if supp, ok := p.Value.(SupportsEquality); ok {
return supp.EqualsSQLString(cfg, not, other)
}
return "", xerrors.Errorf("unsupported equality: %T %s %T", p.Value, equalsOp(not), other)
}
func (p astParenthesis) ContainsSQL(cfg *SQLGenerator, other Node) (string, error) {
if supp, ok := p.Value.(SupportsContains); ok {
return supp.ContainsSQL(cfg, other)
}
return "", xerrors.Errorf("unsupported contains: %T %T", p.Value, other)
}

View File

@ -0,0 +1,29 @@
package sqltypes
import (
"golang.org/x/xerrors"
)
type AstString struct {
Source RegoSource
Value string
}
func String(v string) Node {
return AstString{Value: v, Source: RegoSource(v)}
}
func (AstString) UseAs() Node { return AstString{} }
func (s AstString) SQLString(_ *SQLGenerator) string {
return "'" + s.Value + "'"
}
func (s AstString) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
switch other.UseAs().(type) {
case AstString:
return basicSQLEquality(cfg, not, s, other), nil
default:
return "", xerrors.Errorf("unsupported equality: %T %s %T", s, equalsOp(not), other)
}
}

View File

@ -0,0 +1,113 @@
package sqltypes
import (
"golang.org/x/xerrors"
"github.com/open-policy-agent/opa/ast"
)
type VariableMatcher interface {
ConvertVariable(rego ast.Ref) (Node, bool)
}
type VariableConverter struct {
converters []VariableMatcher
}
func NewVariableConverter() *VariableConverter {
return &VariableConverter{}
}
func (vc *VariableConverter) RegisterMatcher(m ...VariableMatcher) *VariableConverter {
vc.converters = append(vc.converters, m...)
// Returns the VariableConverter for easier instantiation
return vc
}
func (vc *VariableConverter) ConvertVariable(rego ast.Ref) (Node, bool) {
for _, c := range vc.converters {
if n, ok := c.ConvertVariable(rego); ok {
return n, true
}
}
return nil, false
}
// RegoVarPath will consume the following terms from the given rego Ref and
// return the remaining terms. If the path does not fully match, an error is
// returned. The first term must always be a Var.
func RegoVarPath(path []string, terms []*ast.Term) ([]*ast.Term, error) {
if len(terms) < len(path) {
return nil, xerrors.Errorf("path %s longer than rego path %s", path, terms)
}
if len(terms) == 0 || len(path) == 0 {
return nil, xerrors.Errorf("path %s and rego path %s must not be empty", path, terms)
}
varTerm, ok := terms[0].Value.(ast.Var)
if !ok {
return nil, xerrors.Errorf("expected var, got %T", terms[0])
}
if string(varTerm) != path[0] {
return nil, xerrors.Errorf("expected var %s, got %s", path[0], varTerm)
}
for i := 1; i < len(path); i++ {
nextTerm, ok := terms[i].Value.(ast.String)
if !ok {
return nil, xerrors.Errorf("expected ast.string, got %T", terms[i])
}
if string(nextTerm) != path[i] {
return nil, xerrors.Errorf("expected string %s, got %s", path[i], nextTerm)
}
}
return terms[len(path):], nil
}
var _ VariableMatcher = astStringVar{}
var _ Node = astStringVar{}
// astStringVar is any variable that represents a string.
type astStringVar struct {
Source RegoSource
FieldPath []string
ColumnString string
}
func StringVarMatcher(sqlString string, regoPath []string) VariableMatcher {
return astStringVar{FieldPath: regoPath, ColumnString: sqlString}
}
func (astStringVar) UseAs() Node { return AstString{} }
// ConvertVariable will return a new astStringVar Node if the given rego Ref
// matches this astStringVar.
func (s astStringVar) ConvertVariable(rego ast.Ref) (Node, bool) {
left, err := RegoVarPath(s.FieldPath, rego)
if err == nil && len(left) == 0 {
return astStringVar{
Source: RegoSource(rego.String()),
FieldPath: s.FieldPath,
ColumnString: s.ColumnString,
}, true
}
return nil, false
}
func (s astStringVar) SQLString(_ *SQLGenerator) string {
return s.ColumnString
}
func (s astStringVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
switch other.UseAs().(type) {
case AstString:
return basicSQLEquality(cfg, not, s, other), nil
default:
return "", xerrors.Errorf("unsupported equality: %T %s %T", s, equalsOp(not), other)
}
}

View File

@ -316,25 +316,27 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
organization := httpmw.OrganizationParam(r)
templates, err := api.Database.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
OrganizationID: organization.ID,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceTemplate.Type)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching templates in organization.",
Message: "Internal error preparing sql filter.",
Detail: err.Error(),
})
return
}
// Filter templates based on rbac permissions
templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates)
templates, err := api.Database.GetAuthorizedTemplates(ctx, database.GetTemplatesWithFilterParams{
OrganizationID: organization.ID,
}, prepared)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching templates.",
Message: "Internal error fetching templates in organization.",
Detail: err.Error(),
})
return

View File

@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/coderdtest"

View File

@ -118,7 +118,8 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
filter.OwnerUsername = ""
}
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
// Workspaces do not have ACL columns.
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error preparing sql filter.",
@ -127,7 +128,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
return
}
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter)
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, prepared)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspaces.",

View File

@ -31,7 +31,7 @@ func (api *API) postGroupByOrganization(rw http.ResponseWriter, r *http.Request)
)
defer commitAudit()
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceGroup) {
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceGroup.InOrg(org.ID)) {
http.NotFound(rw, r)
return
}

View File

@ -11,7 +11,9 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/testutil"
@ -747,3 +749,210 @@ func TestUpdateTemplateACL(t *testing.T) {
require.Equal(t, http.StatusNotFound, cerr.StatusCode())
})
}
// TestTemplateAccess tests the rego -> sql conversion. We need to implement
// this test on at least 1 table type to ensure that the conversion is correct.
// The rbac tests only assert against static SQL queries.
// This is a full rbac test of many of the common role combinations.
//
//nolint:tparallel
func TestTemplateAccess(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
t.Cleanup(cancel)
ownerClient := coderdenttest.New(t, nil)
owner := coderdtest.CreateFirstUser(t, ownerClient)
_ = coderdenttest.AddLicense(t, ownerClient, coderdenttest.LicenseOptions{
TemplateRBAC: true,
})
type coderUser struct {
*codersdk.Client
User codersdk.User
}
type orgSetup struct {
Admin coderUser
MemberInGroup coderUser
MemberNoGroup coderUser
DefaultTemplate codersdk.Template
AllRead codersdk.Template
UserACL codersdk.Template
GroupACL codersdk.Template
Group codersdk.Group
Org codersdk.Organization
}
// Create the following users
// - owner: Site wide owner
// - template-admin
// - org-admin (org 1)
// - org-admin (org 2)
// - org-member (org 1)
// - org-member (org 2)
// Create the following templates in each org
// - template 1, default acls
// - template 2, all_user read
// - template 3, user_acl read for member
// - template 4, group_acl read for groupMember
templateAdmin := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin())
makeTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, acl codersdk.UpdateTemplateACL) codersdk.Template {
version := coderdtest.CreateTemplateVersion(t, client, orgID, nil)
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
err := client.UpdateTemplateACL(ctx, template.ID, acl)
require.NoError(t, err, "failed to update template acl")
return template
}
makeOrg := func(t *testing.T) orgSetup {
// Make org
orgName, err := cryptorand.String(5)
require.NoError(t, err, "org name")
// Make users
newOrg, err := ownerClient.CreateOrganization(ctx, codersdk.CreateOrganizationRequest{Name: orgName})
require.NoError(t, err, "failed to create org")
adminCli, adminUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgAdmin(newOrg.ID))
groupMemCli, groupMemUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgMember(newOrg.ID))
memberCli, memberUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgMember(newOrg.ID))
// Make group
group, err := adminCli.CreateGroup(ctx, newOrg.ID, codersdk.CreateGroupRequest{
Name: "SingleUser",
})
require.NoError(t, err, "failed to create group")
group, err = adminCli.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{
AddUsers: []string{groupMemUsr.ID.String()},
})
require.NoError(t, err, "failed to add user to group")
// Make templates
return orgSetup{
Admin: coderUser{Client: adminCli, User: adminUsr},
MemberInGroup: coderUser{Client: groupMemCli, User: groupMemUsr},
MemberNoGroup: coderUser{Client: memberCli, User: memberUsr},
Org: newOrg,
Group: group,
DefaultTemplate: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
GroupPerms: map[string]codersdk.TemplateRole{
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
},
}),
AllRead: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
GroupPerms: map[string]codersdk.TemplateRole{
newOrg.ID.String(): codersdk.TemplateRoleUse,
},
}),
UserACL: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
GroupPerms: map[string]codersdk.TemplateRole{
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
},
UserPerms: map[string]codersdk.TemplateRole{
memberUsr.ID.String(): codersdk.TemplateRoleUse,
},
}),
GroupACL: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
GroupPerms: map[string]codersdk.TemplateRole{
group.ID.String(): codersdk.TemplateRoleUse,
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
},
}),
}
}
// Make 2 organizations
orgs := []orgSetup{
makeOrg(t),
makeOrg(t),
}
testTemplateRead := func(t *testing.T, org orgSetup, usr *codersdk.Client, read []codersdk.Template) {
found, err := usr.TemplatesByOrganization(ctx, org.Org.ID)
require.NoError(t, err, "failed to get templates")
exp := make(map[uuid.UUID]codersdk.Template)
for _, tmpl := range read {
exp[tmpl.ID] = tmpl
}
for _, f := range found {
if _, ok := exp[f.ID]; !ok {
t.Errorf("found unexpected template %q", f.Name)
}
delete(exp, f.ID)
}
require.Len(t, exp, 0, "expected templates not found")
}
// nolint:paralleltest
t.Run("OwnerReadAll", func(t *testing.T) {
for _, o := range orgs {
// Owners can read all templates in all orgs
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
testTemplateRead(t, o, ownerClient, exp)
}
})
// nolint:paralleltest
t.Run("TemplateAdminReadAll", func(t *testing.T) {
for _, o := range orgs {
// Template Admins can read all templates in all orgs
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
testTemplateRead(t, o, templateAdmin, exp)
}
})
// nolint:paralleltest
t.Run("OrgAdminReadAllTheirs", func(t *testing.T) {
for i, o := range orgs {
cli := o.Admin.Client
// Only read their own org
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
testTemplateRead(t, o, cli, exp)
other := orgs[(i+1)%len(orgs)]
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
testTemplateRead(t, other, cli, []codersdk.Template{})
}
})
// nolint:paralleltest
t.Run("TestMemberNoGroup", func(t *testing.T) {
for i, o := range orgs {
cli := o.MemberNoGroup.Client
// Only read their own org
exp := []codersdk.Template{o.AllRead, o.UserACL}
testTemplateRead(t, o, cli, exp)
other := orgs[(i+1)%len(orgs)]
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
testTemplateRead(t, other, cli, []codersdk.Template{})
}
})
// nolint:paralleltest
t.Run("TestMemberInGroup", func(t *testing.T) {
for i, o := range orgs {
cli := o.MemberInGroup.Client
// Only read their own org
exp := []codersdk.Template{o.AllRead, o.GroupACL}
testTemplateRead(t, o, cli, exp)
other := orgs[(i+1)%len(orgs)]
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
testTemplateRead(t, other, cli, []codersdk.Template{})
}
})
}