feat: Convert rego queries into SQL clauses (#4225)

* feat: Convert rego queries into SQL clauses

* Fix postgres quotes to single quotes

* Ensure all test cases can compile into SQL clauses

* Do not export extra types

* Add custom query with rbac filter

* First draft of a custom authorized db call

* Add comments + tests

* Support better regex style matching for variables

* Handle jsonb arrays

* Remove auth call on workspaces

* Fix PG endpoints test

* Match psql implementation

* Add some comments

* Remove unused argument

* Add query name for tracking

* Handle nested types

This solves it without proper types in our AST.
Might bite the bullet and implement some better types

* Add comment

* Renaming function call to GetAuthorizedWorkspaces
This commit is contained in:
Steven Masley 2022-10-04 11:35:33 -04:00 committed by GitHub
parent 6325a9ea91
commit cd4ab97efa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 870 additions and 20 deletions

View File

@ -13,6 +13,9 @@ import (
"github.com/coder/coder/codersdk"
)
// AuthorizeFilter takes a list of objects and returns the filtered list of
// objects that the user is authorized to perform the given action on.
// This is faster than calling Authorize() on each object.
func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) {
roles := httpmw.UserAuthorization(r)
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objects)
@ -85,6 +88,26 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
return true
}
// AuthorizeSQLFilter returns an authorization filter that can used in a
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
// from postgres are already authorized, and the caller does not need to
// call 'Authorize()' on the returned objects.
// Note the authorization is only for the given action and object type.
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
roles := httpmw.UserAuthorization(r)
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType)
if err != nil {
return nil, xerrors.Errorf("prepare filter: %w", err)
}
filter, err := prepared.Compile()
if err != nil {
return nil, xerrors.Errorf("compile filter: %w", err)
}
return filter, nil
}
// checkAuthorization returns if the current API key can use the given
// permissions, factoring in the current user's roles and the API key scopes.
func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {

View File

@ -124,11 +124,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,
},
"GET:/api/v2/workspaces/": {
StatusCode: http.StatusOK,
AssertAction: rbac.ActionRead,
AssertObject: workspaceRBACObj,
},
"GET:/api/v2/organizations/{organization}/templates": {
StatusCode: http.StatusOK,
AssertAction: rbac.ActionRead,
@ -246,6 +241,9 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true},
"POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
"POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
// Endpoints that use the SQLQuery filter.
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
}
// Routes like proxy routes support all HTTP methods. A helper func to expand
@ -513,6 +511,12 @@ type RecordingAuthorizer struct {
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
// ByRoleNameSQL does not record the call. This matches the postgres behavior
// of not calling Authorize()
func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []string, _ rbac.Scope, _ rbac.Action, _ rbac.Object) error {
return r.AlwaysReturn
}
func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
r.Called = &authCall{
SubjectID: subjectID,
@ -526,11 +530,12 @@ func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro
func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
return &fakePreparedAuthorizer{
Original: r,
SubjectID: subjectID,
Roles: roles,
Scope: scope,
Action: action,
Original: r,
SubjectID: subjectID,
Roles: roles,
Scope: scope,
Action: action,
HardCodedSQLString: "true",
}, nil
}
@ -539,13 +544,39 @@ func (r *RecordingAuthorizer) reset() {
}
type fakePreparedAuthorizer struct {
Original *RecordingAuthorizer
SubjectID string
Roles []string
Scope rbac.Scope
Action rbac.Action
Original *RecordingAuthorizer
SubjectID string
Roles []string
Scope rbac.Scope
Action rbac.Action
HardCodedSQLString string
HardCodedRegoString string
}
func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error {
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object)
}
// Compile returns a compiled version of the authorizer that will work for
// in memory databases. This fake version will not work against a SQL database.
func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
return f, nil
}
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
return f.Original.ByRoleNameSQL(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil
}
func (f fakePreparedAuthorizer) RegoString() string {
if f.HardCodedRegoString != "" {
return f.HardCodedRegoString
}
panic("not implemented")
}
func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string {
if f.HardCodedSQLString != "" {
return f.HardCodedSQLString
}
panic("not implemented")
}

View File

@ -0,0 +1,62 @@
package database
import (
"context"
"fmt"
"github.com/lib/pq"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac"
)
type customQuerier interface {
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error)
}
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
// clause.
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) {
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig()))
rows, err := q.db.QueryContext(ctx, query,
arg.Deleted,
arg.OwnerID,
arg.OwnerUsername,
arg.TemplateName,
pq.Array(arg.TemplateIds),
arg.Name,
)
if err != nil {
return nil, xerrors.Errorf("get authorized workspaces: %w", err)
}
defer rows.Close()
var items []Workspace
for rows.Next() {
var i Workspace
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OwnerID,
&i.OrganizationID,
&i.TemplateID,
&i.Deleted,
&i.Name,
&i.AutostartSchedule,
&i.Ttl,
&i.LastUsedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -520,7 +520,13 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
}, nil
}
func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) {
func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) {
// A nil auth filter means no auth filter.
workspaces, err := q.GetAuthorizedWorkspaces(ctx, arg, nil)
return workspaces, err
}
func (q *fakeQuerier) GetAuthorizedWorkspaces(_ context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -560,6 +566,11 @@ func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspace
continue
}
}
// If the filter exists, ensure the object is authorized.
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
continue
}
workspaces = append(workspaces, workspace)
}

View File

@ -20,6 +20,8 @@ import (
// It extends the generated interface to add transaction support.
type Store interface {
querier
// customQuerier contains custom queries that are not generated.
customQuerier
InTx(func(Store) error) error
}

View File

@ -21,6 +21,7 @@ type Authorizer interface {
type PreparedAuthorized interface {
Authorize(ctx context.Context, object Object) error
Compile() (AuthorizeFilter, error)
}
// Filter takes in a list of objects, and will filter the list removing all

View File

@ -781,6 +781,11 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type)
require.NoError(t, err, "make prepared authorizer")
// Ensure the partial can compile to a SQL clause.
// This does not guarantee that the clause is valid SQL.
_, err = Compile(partialAuthz.partialQueries)
require.NoError(t, err, "compile prepared authorizer")
// Also check the rego policy can form a valid partial query result.
// This ensures we can convert the queries into SQL WHERE clauses in the future.
// If this function returns 'Support' sections, then we cannot convert the query into SQL.

View File

@ -28,6 +28,14 @@ type PartialAuthorizer struct {
var _ PreparedAuthorized = (*PartialAuthorizer)(nil)
func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) {
filter, err := Compile(pa.partialQueries)
if err != nil {
return nil, xerrors.Errorf("compile: %w", err)
}
return filter, nil
}
func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error {
if pa.alwaysTrue {
return nil

616
coderd/rbac/query.go Normal file
View File

@ -0,0 +1,616 @@
package rbac
import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"golang.org/x/xerrors"
)
type TermType string
const (
VarTypeJsonbTextArray TermType = "jsonb-text-array"
VarTypeText TermType = "text"
)
type SQLColumn struct {
// RegoMatch matches the original variable string.
// If it is a match, then this variable config will apply.
RegoMatch *regexp.Regexp
// ColumnSelect is the name of the postgres column to select.
// Can use capture groups from RegoMatch with $1, $2, etc.
ColumnSelect string
// Type indicates the postgres type of the column. Some expressions will
// need to know this in order to determine what SQL to produce.
// An example is if the variable is a jsonb array, the "contains" SQL
// query is `variable ? 'value'` instead of `'value' = ANY(variable)`.
// This type is only needed to be provided
Type TermType
}
type SQLConfig struct {
// Variables is a map of rego variable names to SQL columns.
// Example:
// "input\.object\.org_owner": SQLColumn{
// ColumnSelect: "organization_id",
// Type: VarTypeUUID
// }
// "input\.object\.owner": SQLColumn{
// ColumnSelect: "owner_id",
// Type: VarTypeUUID
// }
// "input\.object\.group_acl\.(.*)": SQLColumn{
// ColumnSelect: "group_acl->$1",
// Type: VarTypeJsonbTextArray
// }
Variables []SQLColumn
}
func DefaultConfig() SQLConfig {
return SQLConfig{
Variables: []SQLColumn{
{
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
ColumnSelect: "group_acl->$1",
Type: VarTypeJsonbTextArray,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
ColumnSelect: "user_acl->$1",
Type: VarTypeJsonbTextArray,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
ColumnSelect: "organization_id :: text",
Type: VarTypeText,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
ColumnSelect: "owner_id :: text",
Type: VarTypeText,
},
},
}
}
type AuthorizeFilter interface {
// RegoString is used in debugging to see the original rego expression.
RegoString() string
// SQLString returns the SQL expression that can be used in a WHERE clause.
SQLString(cfg SQLConfig) string
// Eval is required for the fake in memory database to work. The in memory
// database can use this function to filter the results.
Eval(object Object) bool
}
// Compile will convert a rego query AST into our custom types. The output is
// an AST that can be used to generate SQL.
func Compile(partialQueries *rego.PartialQueries) (Expression, error) {
if len(partialQueries.Support) > 0 {
return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support))
}
// 0 queries means the result is "undefined". This is the same as "false".
if len(partialQueries.Queries) == 0 {
return &termBoolean{
base: base{Rego: "false"},
Value: false,
}, nil
}
// Abort early if any of the "OR"'d expressions are the empty string.
// This is the same as "true".
for _, query := range partialQueries.Queries {
if query.String() == "" {
return &termBoolean{
base: base{Rego: "true"},
Value: true,
}, nil
}
}
result := make([]Expression, 0, len(partialQueries.Queries))
var builder strings.Builder
for i := range partialQueries.Queries {
query, err := processQuery(partialQueries.Queries[i])
if err != nil {
return nil, err
}
result = append(result, query)
if i != 0 {
builder.WriteString("\n")
}
builder.WriteString(partialQueries.Queries[i].String())
}
return expOr{
base: base{
Rego: builder.String(),
},
Expressions: result,
}, nil
}
// processQuery processes an entire set of expressions and joins them with
// "AND".
func processQuery(query ast.Body) (Expression, error) {
expressions := make([]Expression, 0, len(query))
for _, astExpr := range query {
expr, err := processExpression(astExpr)
if err != nil {
return nil, err
}
expressions = append(expressions, expr)
}
return expAnd{
base: base{
Rego: query.String(),
},
Expressions: expressions,
}, nil
}
func processExpression(expr *ast.Expr) (Expression, error) {
if !expr.IsCall() {
// This could be a single term that is a valid expression.
if term, ok := expr.Terms.(*ast.Term); ok {
value, err := processTerm(term)
if err != nil {
return nil, xerrors.Errorf("single term expression: %w", err)
}
if boolExp, ok := value.(Expression); ok {
return boolExp, nil
}
// Default to error.
}
return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported")
}
op := expr.Operator().String()
base := base{Rego: op}
switch op {
case "neq", "eq", "equal":
terms, err := processTerms(2, expr.Operands())
if err != nil {
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
}
return &opEqual{
base: base,
Terms: [2]Term{terms[0], terms[1]},
Not: op == "neq",
}, nil
case "internal.member_2":
terms, err := processTerms(2, expr.Operands())
if err != nil {
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
}
return &opInternalMember2{
base: base,
Needle: terms[0],
Haystack: terms[1],
}, nil
default:
return nil, xerrors.Errorf("invalid expression: operator %s not supported", op)
}
}
func processTerms(expected int, terms []*ast.Term) ([]Term, error) {
if len(terms) != expected {
return nil, xerrors.Errorf("too many arguments, expect %d found %d", expected, len(terms))
}
result := make([]Term, 0, len(terms))
for _, term := range terms {
processed, err := processTerm(term)
if err != nil {
return nil, xerrors.Errorf("invalid term: %w", err)
}
result = append(result, processed)
}
return result, nil
}
func processTerm(term *ast.Term) (Term, error) {
base := base{Rego: term.String()}
switch v := term.Value.(type) {
case ast.Boolean:
return &termBoolean{
base: base,
Value: bool(v),
}, nil
case ast.Ref:
obj := &termObject{
base: base,
Variables: []termVariable{},
}
var idx int
// A ref is a set of terms. If the first term is a var, then the
// following terms are the path to the value.
var builder strings.Builder
for _, term := range v {
if idx == 0 {
if _, ok := v[0].Value.(ast.Var); !ok {
return nil, xerrors.Errorf("invalid term (%s): ref must start with a var, started with %T", v[0].String(), v[0])
}
}
if _, ok := term.Value.(ast.Ref); ok {
// New obj
obj.Variables = append(obj.Variables, termVariable{
base: base,
Name: builder.String(),
})
builder.Reset()
idx = 0
}
if builder.Len() != 0 {
builder.WriteString(".")
}
builder.WriteString(trimQuotes(term.String()))
idx++
}
obj.Variables = append(obj.Variables, termVariable{
base: base,
Name: builder.String(),
})
return obj, nil
case ast.Var:
return &termVariable{
Name: trimQuotes(v.String()),
base: base,
}, nil
case ast.String:
return &termString{
Value: trimQuotes(v.String()),
base: base,
}, nil
case ast.Set:
slice := v.Slice()
set := make([]Term, 0, len(slice))
for _, elem := range slice {
processed, err := processTerm(elem)
if err != nil {
return nil, xerrors.Errorf("invalid set term %s: %w", elem.String(), err)
}
set = append(set, processed)
}
return &termSet{
Value: set,
base: base,
}, nil
default:
return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String())
}
}
type base struct {
// Rego is the original rego string
Rego string
}
func (b base) RegoString() string {
return b.Rego
}
// Expression comprises a set of terms, operators, and functions. All
// expressions return a boolean value.
//
// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo"
type Expression interface {
AuthorizeFilter
}
type expAnd struct {
base
Expressions []Expression
}
func (t expAnd) SQLString(cfg SQLConfig) string {
if len(t.Expressions) == 1 {
return t.Expressions[0].SQLString(cfg)
}
exprs := make([]string, 0, len(t.Expressions))
for _, expr := range t.Expressions {
exprs = append(exprs, expr.SQLString(cfg))
}
return "(" + strings.Join(exprs, " AND ") + ")"
}
func (t expAnd) Eval(object Object) bool {
for _, expr := range t.Expressions {
if !expr.Eval(object) {
return false
}
}
return true
}
type expOr struct {
base
Expressions []Expression
}
func (t expOr) SQLString(cfg SQLConfig) string {
if len(t.Expressions) == 1 {
return t.Expressions[0].SQLString(cfg)
}
exprs := make([]string, 0, len(t.Expressions))
for _, expr := range t.Expressions {
exprs = append(exprs, expr.SQLString(cfg))
}
return "(" + strings.Join(exprs, " OR ") + ")"
}
func (t expOr) Eval(object Object) bool {
for _, expr := range t.Expressions {
if expr.Eval(object) {
return true
}
}
return false
}
// Operator joins terms together to form an expression.
// Operators are also expressions.
//
// Eg: "=", "neq", "internal.member_2", etc.
type Operator interface {
Expression
}
type opEqual struct {
base
Terms [2]Term
// For NotEqual
Not bool
}
func (t opEqual) SQLString(cfg SQLConfig) string {
op := "="
if t.Not {
op = "!="
}
return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg))
}
func (t opEqual) Eval(object Object) bool {
a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object)
if t.Not {
return a != b
}
return a == b
}
// opInternalMember2 is checking if the first term is a member of the second term.
// The second term is a set or list.
type opInternalMember2 struct {
base
Needle Term
Haystack Term
}
func (t opInternalMember2) Eval(object Object) bool {
a, b := t.Needle.EvalTerm(object), t.Haystack.EvalTerm(object)
bset, ok := b.([]interface{})
if !ok {
return false
}
for _, elem := range bset {
if a == elem {
return true
}
}
return false
}
func (t opInternalMember2) SQLString(cfg SQLConfig) string {
if haystack, ok := t.Haystack.(*termObject); ok {
// This is a special case where the haystack is a jsonb array.
// There is a more general way to solve this, but that requires a lot
// more code to cover a lot more cases that we do not care about.
// To handle this more generally we should implement "Array" as a type.
// Then have the `contains` function on the Array type. This would defer
// knowing the element type to the Array and cover more cases without
// having to add more "if" branches here.
// But until we need more cases, our basic type system is ok, and
// this is the only case we need to handle.
if haystack.SQLType(cfg) == VarTypeJsonbTextArray {
return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg))
}
}
return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg))
}
// Term is a single value in an expression. Terms can be variables or constants.
//
// Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner",
// "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}"
type Term interface {
RegoString() string
SQLString(cfg SQLConfig) string
// Eval will evaluate the term
// Terms can eval to any type. The operator/expression will type check.
EvalTerm(object Object) interface{}
}
type termString struct {
base
Value string
}
func (t termString) EvalTerm(_ Object) interface{} {
return t.Value
}
func (t termString) SQLString(_ SQLConfig) string {
return "'" + t.Value + "'"
}
func (termString) SQLType(_ SQLConfig) TermType {
return VarTypeText
}
// termObject is a variable that can be dereferenced. We count some rego objects
// as single variables, eg: input.object.org_owner. In reality, it is a nested
// object.
// In rego, we can dereference the object with the "." operator, which we can
// handle with regex.
// Or we can dereference the object with the "[]", which we can handle with this
// term type.
type termObject struct {
base
Variables []termVariable
}
func (t termObject) EvalTerm(obj Object) interface{} {
if len(t.Variables) == 0 {
return t.Variables[0].EvalTerm(obj)
}
panic("no nested structures are supported yet")
}
func (t termObject) SQLType(cfg SQLConfig) TermType {
// Without a full type system, let's just assume the type of the first var
// is the resulting type. This is correct for our use case.
// Solving this more generally requires a full type system, which is
// excessive for our mostly static policy.
return t.Variables[0].SQLType(cfg)
}
func (t termObject) SQLString(cfg SQLConfig) string {
if len(t.Variables) == 1 {
return t.Variables[0].SQLString(cfg)
}
// Combine the last 2 variables into 1 variable.
end := t.Variables[len(t.Variables)-1]
before := t.Variables[len(t.Variables)-2]
// Recursively solve the SQLString by removing the last nested reference.
// This continues until we have a single variable.
return termObject{
base: t.base,
Variables: append(
t.Variables[:len(t.Variables)-2],
termVariable{
base: base{
Rego: before.base.Rego + "[" + end.base.Rego + "]",
},
// Convert the end to SQL string. We evaluate each term
// one at a time.
Name: before.Name + "." + end.SQLString(cfg),
},
),
}.SQLString(cfg)
}
type termVariable struct {
base
Name string
}
func (t termVariable) EvalTerm(obj Object) interface{} {
switch t.Name {
case "input.object.org_owner":
return obj.OrgID
case "input.object.owner":
return obj.Owner
case "input.object.type":
return obj.Type
default:
return fmt.Sprintf("'Unknown variable %s'", t.Name)
}
}
func (t termVariable) SQLType(cfg SQLConfig) TermType {
if col := t.ColumnConfig(cfg); col != nil {
return col.Type
}
return VarTypeText
}
func (t termVariable) SQLString(cfg SQLConfig) string {
if col := t.ColumnConfig(cfg); col != nil {
matches := col.RegoMatch.FindStringSubmatch(t.Name)
if len(matches) > 0 {
// This config matches this variable.
replace := make([]string, 0, len(matches)*2)
for i, m := range matches {
replace = append(replace, fmt.Sprintf("$%d", i))
replace = append(replace, m)
}
replacer := strings.NewReplacer(replace...)
return replacer.Replace(col.ColumnSelect)
}
}
return t.Name
}
// ColumnConfig returns the correct SQLColumn settings for the
// term. If there is no configured column, it will return nil.
func (t termVariable) ColumnConfig(cfg SQLConfig) *SQLColumn {
for _, col := range cfg.Variables {
matches := col.RegoMatch.MatchString(t.Name)
if matches {
return &col
}
}
return nil
}
// termSet is a set of unique terms.
type termSet struct {
base
Value []Term
}
func (t termSet) EvalTerm(obj Object) interface{} {
set := make([]interface{}, 0, len(t.Value))
for _, term := range t.Value {
set = append(set, term.EvalTerm(obj))
}
return set
}
func (t termSet) SQLString(cfg SQLConfig) string {
elems := make([]string, 0, len(t.Value))
for _, v := range t.Value {
elems = append(elems, v.SQLString(cfg))
}
return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ","))
}
type termBoolean struct {
base
Value bool
}
func (t termBoolean) Eval(_ Object) bool {
return t.Value
}
func (t termBoolean) EvalTerm(_ Object) interface{} {
return t.Value
}
func (t termBoolean) SQLString(_ SQLConfig) string {
return strconv.FormatBool(t.Value)
}
func trimQuotes(s string) string {
return strings.Trim(s, "\"")
}

View File

@ -0,0 +1,92 @@
package rbac
import (
"testing"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/stretchr/testify/require"
)
func TestCompileQuery(t *testing.T) {
t.Parallel()
opts := ast.ParserOptions{
AllFutureKeywords: true,
}
t.Run("EmptyQuery", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
must(ast.ParseBody("")),
},
Support: []*ast.Module{},
})
require.NoError(t, err, "compile empty")
require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'")
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'")
})
t.Run("TrueQuery", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
must(ast.ParseBody("true")),
},
Support: []*ast.Module{},
})
require.NoError(t, err, "compile")
require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'")
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'")
})
t.Run("ACLIn", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list.allUsers`, opts),
},
Support: []*ast.Module{},
})
require.NoError(t, err, "compile")
require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member")
require.Equal(t, `group_acl->allUsers ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in")
})
t.Run("Complex", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts),
ast.MustParseBodyWithOpts(`input.object.org_owner in {"a", "b", "c"}`, opts),
ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts),
ast.MustParseBodyWithOpts(`"read" in input.object.acl_group_list.allUsers`, opts),
ast.MustParseBodyWithOpts(`"read" in input.object.acl_user_list.me`, opts),
},
Support: []*ast.Module{},
})
require.NoError(t, err, "compile")
require.Equal(t, `(organization_id :: text != '' OR `+
`organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+
`organization_id :: text != '' OR `+
`group_acl->allUsers ? 'read' OR `+
`user_acl->me ? 'read')`,
expression.SQLString(DefaultConfig()), "complex")
})
t.Run("SetDereference", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list[input.object.org_owner]`, opts),
},
Support: []*ast.Module{},
})
require.NoError(t, err, "compile")
require.Equal(t, `group_acl->organization_id :: text ? '*'`,
expression.SQLString(DefaultConfig()), "set dereference")
})
}

View File

@ -113,17 +113,16 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
filter.OwnerUsername = ""
}
workspaces, err := api.Database.GetWorkspaces(ctx, filter)
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspaces.",
Message: "Internal error preparing sql filter.",
Detail: err.Error(),
})
return
}
// Only return workspaces the user can read
workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces)
workspaces, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspaces.",