mirror of https://github.com/coder/coder.git
feat: Convert rego queries into SQL clauses (#4225)
* feat: Convert rego queries into SQL clauses * Fix postgres quotes to single quotes * Ensure all test cases can compile into SQL clauses * Do not export extra types * Add custom query with rbac filter * First draft of a custom authorized db call * Add comments + tests * Support better regex style matching for variables * Handle jsonb arrays * Remove auth call on workspaces * Fix PG endpoints test * Match psql implementation * Add some comments * Remove unused argument * Add query name for tracking * Handle nested types This solves it without proper types in our AST. Might bite the bullet and implement some better types * Add comment * Renaming function call to GetAuthorizedWorkspaces
This commit is contained in:
parent
6325a9ea91
commit
cd4ab97efa
|
@ -13,6 +13,9 @@ import (
|
|||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// AuthorizeFilter takes a list of objects and returns the filtered list of
|
||||
// objects that the user is authorized to perform the given action on.
|
||||
// This is faster than calling Authorize() on each object.
|
||||
func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) {
|
||||
roles := httpmw.UserAuthorization(r)
|
||||
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objects)
|
||||
|
@ -85,6 +88,26 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
|
|||
return true
|
||||
}
|
||||
|
||||
// AuthorizeSQLFilter returns an authorization filter that can used in a
|
||||
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
|
||||
// from postgres are already authorized, and the caller does not need to
|
||||
// call 'Authorize()' on the returned objects.
|
||||
// Note the authorization is only for the given action and object type.
|
||||
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
|
||||
roles := httpmw.UserAuthorization(r)
|
||||
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("prepare filter: %w", err)
|
||||
}
|
||||
|
||||
filter, err := prepared.Compile()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile filter: %w", err)
|
||||
}
|
||||
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
// checkAuthorization returns if the current API key can use the given
|
||||
// permissions, factoring in the current user's roles and the API key scopes.
|
||||
func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -124,11 +124,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
},
|
||||
"GET:/api/v2/workspaces/": {
|
||||
StatusCode: http.StatusOK,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: workspaceRBACObj,
|
||||
},
|
||||
"GET:/api/v2/organizations/{organization}/templates": {
|
||||
StatusCode: http.StatusOK,
|
||||
AssertAction: rbac.ActionRead,
|
||||
|
@ -246,6 +241,9 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
"PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
"POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
|
||||
// Endpoints that use the SQLQuery filter.
|
||||
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
|
||||
}
|
||||
|
||||
// Routes like proxy routes support all HTTP methods. A helper func to expand
|
||||
|
@ -513,6 +511,12 @@ type RecordingAuthorizer struct {
|
|||
|
||||
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
|
||||
|
||||
// ByRoleNameSQL does not record the call. This matches the postgres behavior
|
||||
// of not calling Authorize()
|
||||
func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []string, _ rbac.Scope, _ rbac.Action, _ rbac.Object) error {
|
||||
return r.AlwaysReturn
|
||||
}
|
||||
|
||||
func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
|
||||
r.Called = &authCall{
|
||||
SubjectID: subjectID,
|
||||
|
@ -526,11 +530,12 @@ func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro
|
|||
|
||||
func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
|
||||
return &fakePreparedAuthorizer{
|
||||
Original: r,
|
||||
SubjectID: subjectID,
|
||||
Roles: roles,
|
||||
Scope: scope,
|
||||
Action: action,
|
||||
Original: r,
|
||||
SubjectID: subjectID,
|
||||
Roles: roles,
|
||||
Scope: scope,
|
||||
Action: action,
|
||||
HardCodedSQLString: "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -539,13 +544,39 @@ func (r *RecordingAuthorizer) reset() {
|
|||
}
|
||||
|
||||
type fakePreparedAuthorizer struct {
|
||||
Original *RecordingAuthorizer
|
||||
SubjectID string
|
||||
Roles []string
|
||||
Scope rbac.Scope
|
||||
Action rbac.Action
|
||||
Original *RecordingAuthorizer
|
||||
SubjectID string
|
||||
Roles []string
|
||||
Scope rbac.Scope
|
||||
Action rbac.Action
|
||||
HardCodedSQLString string
|
||||
HardCodedRegoString string
|
||||
}
|
||||
|
||||
func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error {
|
||||
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object)
|
||||
}
|
||||
|
||||
// Compile returns a compiled version of the authorizer that will work for
|
||||
// in memory databases. This fake version will not work against a SQL database.
|
||||
func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
|
||||
return f.Original.ByRoleNameSQL(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil
|
||||
}
|
||||
|
||||
func (f fakePreparedAuthorizer) RegoString() string {
|
||||
if f.HardCodedRegoString != "" {
|
||||
return f.HardCodedRegoString
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string {
|
||||
if f.HardCodedSQLString != "" {
|
||||
return f.HardCodedSQLString
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
)
|
||||
|
||||
type customQuerier interface {
|
||||
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error)
|
||||
}
|
||||
|
||||
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
|
||||
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
|
||||
// clause.
|
||||
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) {
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig()))
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.OwnerID,
|
||||
arg.OwnerUsername,
|
||||
arg.TemplateName,
|
||||
pq.Array(arg.TemplateIds),
|
||||
arg.Name,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get authorized workspaces: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Workspace
|
||||
for rows.Next() {
|
||||
var i Workspace
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OwnerID,
|
||||
&i.OrganizationID,
|
||||
&i.TemplateID,
|
||||
&i.Deleted,
|
||||
&i.Name,
|
||||
&i.AutostartSchedule,
|
||||
&i.Ttl,
|
||||
&i.LastUsedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
|
@ -520,7 +520,13 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) {
|
||||
func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.Workspace, error) {
|
||||
// A nil auth filter means no auth filter.
|
||||
workspaces, err := q.GetAuthorizedWorkspaces(ctx, arg, nil)
|
||||
return workspaces, err
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAuthorizedWorkspaces(_ context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.Workspace, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
|
@ -560,6 +566,11 @@ func (q *fakeQuerier) GetWorkspaces(_ context.Context, arg database.GetWorkspace
|
|||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
|
||||
continue
|
||||
}
|
||||
workspaces = append(workspaces, workspace)
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
// It extends the generated interface to add transaction support.
|
||||
type Store interface {
|
||||
querier
|
||||
// customQuerier contains custom queries that are not generated.
|
||||
customQuerier
|
||||
|
||||
InTx(func(Store) error) error
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ type Authorizer interface {
|
|||
|
||||
type PreparedAuthorized interface {
|
||||
Authorize(ctx context.Context, object Object) error
|
||||
Compile() (AuthorizeFilter, error)
|
||||
}
|
||||
|
||||
// Filter takes in a list of objects, and will filter the list removing all
|
||||
|
|
|
@ -781,6 +781,11 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
|
|||
partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type)
|
||||
require.NoError(t, err, "make prepared authorizer")
|
||||
|
||||
// Ensure the partial can compile to a SQL clause.
|
||||
// This does not guarantee that the clause is valid SQL.
|
||||
_, err = Compile(partialAuthz.partialQueries)
|
||||
require.NoError(t, err, "compile prepared authorizer")
|
||||
|
||||
// Also check the rego policy can form a valid partial query result.
|
||||
// This ensures we can convert the queries into SQL WHERE clauses in the future.
|
||||
// If this function returns 'Support' sections, then we cannot convert the query into SQL.
|
||||
|
|
|
@ -28,6 +28,14 @@ type PartialAuthorizer struct {
|
|||
|
||||
var _ PreparedAuthorized = (*PartialAuthorizer)(nil)
|
||||
|
||||
func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) {
|
||||
filter, err := Compile(pa.partialQueries)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile: %w", err)
|
||||
}
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error {
|
||||
if pa.alwaysTrue {
|
||||
return nil
|
||||
|
|
|
@ -0,0 +1,616 @@
|
|||
package rbac
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type TermType string
|
||||
|
||||
const (
|
||||
VarTypeJsonbTextArray TermType = "jsonb-text-array"
|
||||
VarTypeText TermType = "text"
|
||||
)
|
||||
|
||||
type SQLColumn struct {
|
||||
// RegoMatch matches the original variable string.
|
||||
// If it is a match, then this variable config will apply.
|
||||
RegoMatch *regexp.Regexp
|
||||
// ColumnSelect is the name of the postgres column to select.
|
||||
// Can use capture groups from RegoMatch with $1, $2, etc.
|
||||
ColumnSelect string
|
||||
|
||||
// Type indicates the postgres type of the column. Some expressions will
|
||||
// need to know this in order to determine what SQL to produce.
|
||||
// An example is if the variable is a jsonb array, the "contains" SQL
|
||||
// query is `variable ? 'value'` instead of `'value' = ANY(variable)`.
|
||||
// This type is only needed to be provided
|
||||
Type TermType
|
||||
}
|
||||
|
||||
type SQLConfig struct {
|
||||
// Variables is a map of rego variable names to SQL columns.
|
||||
// Example:
|
||||
// "input\.object\.org_owner": SQLColumn{
|
||||
// ColumnSelect: "organization_id",
|
||||
// Type: VarTypeUUID
|
||||
// }
|
||||
// "input\.object\.owner": SQLColumn{
|
||||
// ColumnSelect: "owner_id",
|
||||
// Type: VarTypeUUID
|
||||
// }
|
||||
// "input\.object\.group_acl\.(.*)": SQLColumn{
|
||||
// ColumnSelect: "group_acl->$1",
|
||||
// Type: VarTypeJsonbTextArray
|
||||
// }
|
||||
Variables []SQLColumn
|
||||
}
|
||||
|
||||
func DefaultConfig() SQLConfig {
|
||||
return SQLConfig{
|
||||
Variables: []SQLColumn{
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
|
||||
ColumnSelect: "group_acl->$1",
|
||||
Type: VarTypeJsonbTextArray,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
|
||||
ColumnSelect: "user_acl->$1",
|
||||
Type: VarTypeJsonbTextArray,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
|
||||
ColumnSelect: "organization_id :: text",
|
||||
Type: VarTypeText,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
|
||||
ColumnSelect: "owner_id :: text",
|
||||
Type: VarTypeText,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type AuthorizeFilter interface {
|
||||
// RegoString is used in debugging to see the original rego expression.
|
||||
RegoString() string
|
||||
// SQLString returns the SQL expression that can be used in a WHERE clause.
|
||||
SQLString(cfg SQLConfig) string
|
||||
// Eval is required for the fake in memory database to work. The in memory
|
||||
// database can use this function to filter the results.
|
||||
Eval(object Object) bool
|
||||
}
|
||||
|
||||
// Compile will convert a rego query AST into our custom types. The output is
|
||||
// an AST that can be used to generate SQL.
|
||||
func Compile(partialQueries *rego.PartialQueries) (Expression, error) {
|
||||
if len(partialQueries.Support) > 0 {
|
||||
return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support))
|
||||
}
|
||||
|
||||
// 0 queries means the result is "undefined". This is the same as "false".
|
||||
if len(partialQueries.Queries) == 0 {
|
||||
return &termBoolean{
|
||||
base: base{Rego: "false"},
|
||||
Value: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Abort early if any of the "OR"'d expressions are the empty string.
|
||||
// This is the same as "true".
|
||||
for _, query := range partialQueries.Queries {
|
||||
if query.String() == "" {
|
||||
return &termBoolean{
|
||||
base: base{Rego: "true"},
|
||||
Value: true,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]Expression, 0, len(partialQueries.Queries))
|
||||
var builder strings.Builder
|
||||
for i := range partialQueries.Queries {
|
||||
query, err := processQuery(partialQueries.Queries[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, query)
|
||||
if i != 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString(partialQueries.Queries[i].String())
|
||||
}
|
||||
return expOr{
|
||||
base: base{
|
||||
Rego: builder.String(),
|
||||
},
|
||||
Expressions: result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processQuery processes an entire set of expressions and joins them with
|
||||
// "AND".
|
||||
func processQuery(query ast.Body) (Expression, error) {
|
||||
expressions := make([]Expression, 0, len(query))
|
||||
for _, astExpr := range query {
|
||||
expr, err := processExpression(astExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expressions = append(expressions, expr)
|
||||
}
|
||||
|
||||
return expAnd{
|
||||
base: base{
|
||||
Rego: query.String(),
|
||||
},
|
||||
Expressions: expressions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func processExpression(expr *ast.Expr) (Expression, error) {
|
||||
if !expr.IsCall() {
|
||||
// This could be a single term that is a valid expression.
|
||||
if term, ok := expr.Terms.(*ast.Term); ok {
|
||||
value, err := processTerm(term)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("single term expression: %w", err)
|
||||
}
|
||||
if boolExp, ok := value.(Expression); ok {
|
||||
return boolExp, nil
|
||||
}
|
||||
// Default to error.
|
||||
}
|
||||
return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported")
|
||||
}
|
||||
|
||||
op := expr.Operator().String()
|
||||
base := base{Rego: op}
|
||||
switch op {
|
||||
case "neq", "eq", "equal":
|
||||
terms, err := processTerms(2, expr.Operands())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
|
||||
}
|
||||
return &opEqual{
|
||||
base: base,
|
||||
Terms: [2]Term{terms[0], terms[1]},
|
||||
Not: op == "neq",
|
||||
}, nil
|
||||
case "internal.member_2":
|
||||
terms, err := processTerms(2, expr.Operands())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
|
||||
}
|
||||
return &opInternalMember2{
|
||||
base: base,
|
||||
Needle: terms[0],
|
||||
Haystack: terms[1],
|
||||
}, nil
|
||||
default:
|
||||
return nil, xerrors.Errorf("invalid expression: operator %s not supported", op)
|
||||
}
|
||||
}
|
||||
|
||||
func processTerms(expected int, terms []*ast.Term) ([]Term, error) {
|
||||
if len(terms) != expected {
|
||||
return nil, xerrors.Errorf("too many arguments, expect %d found %d", expected, len(terms))
|
||||
}
|
||||
|
||||
result := make([]Term, 0, len(terms))
|
||||
for _, term := range terms {
|
||||
processed, err := processTerm(term)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid term: %w", err)
|
||||
}
|
||||
result = append(result, processed)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func processTerm(term *ast.Term) (Term, error) {
|
||||
base := base{Rego: term.String()}
|
||||
switch v := term.Value.(type) {
|
||||
case ast.Boolean:
|
||||
return &termBoolean{
|
||||
base: base,
|
||||
Value: bool(v),
|
||||
}, nil
|
||||
case ast.Ref:
|
||||
obj := &termObject{
|
||||
base: base,
|
||||
Variables: []termVariable{},
|
||||
}
|
||||
var idx int
|
||||
// A ref is a set of terms. If the first term is a var, then the
|
||||
// following terms are the path to the value.
|
||||
var builder strings.Builder
|
||||
for _, term := range v {
|
||||
if idx == 0 {
|
||||
if _, ok := v[0].Value.(ast.Var); !ok {
|
||||
return nil, xerrors.Errorf("invalid term (%s): ref must start with a var, started with %T", v[0].String(), v[0])
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := term.Value.(ast.Ref); ok {
|
||||
// New obj
|
||||
obj.Variables = append(obj.Variables, termVariable{
|
||||
base: base,
|
||||
Name: builder.String(),
|
||||
})
|
||||
builder.Reset()
|
||||
idx = 0
|
||||
}
|
||||
if builder.Len() != 0 {
|
||||
builder.WriteString(".")
|
||||
}
|
||||
builder.WriteString(trimQuotes(term.String()))
|
||||
idx++
|
||||
}
|
||||
|
||||
obj.Variables = append(obj.Variables, termVariable{
|
||||
base: base,
|
||||
Name: builder.String(),
|
||||
})
|
||||
return obj, nil
|
||||
case ast.Var:
|
||||
return &termVariable{
|
||||
Name: trimQuotes(v.String()),
|
||||
base: base,
|
||||
}, nil
|
||||
case ast.String:
|
||||
return &termString{
|
||||
Value: trimQuotes(v.String()),
|
||||
base: base,
|
||||
}, nil
|
||||
case ast.Set:
|
||||
slice := v.Slice()
|
||||
set := make([]Term, 0, len(slice))
|
||||
for _, elem := range slice {
|
||||
processed, err := processTerm(elem)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid set term %s: %w", elem.String(), err)
|
||||
}
|
||||
set = append(set, processed)
|
||||
}
|
||||
|
||||
return &termSet{
|
||||
Value: set,
|
||||
base: base,
|
||||
}, nil
|
||||
default:
|
||||
return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String())
|
||||
}
|
||||
}
|
||||
|
||||
type base struct {
|
||||
// Rego is the original rego string
|
||||
Rego string
|
||||
}
|
||||
|
||||
func (b base) RegoString() string {
|
||||
return b.Rego
|
||||
}
|
||||
|
||||
// Expression comprises a set of terms, operators, and functions. All
|
||||
// expressions return a boolean value.
|
||||
//
|
||||
// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo"
|
||||
type Expression interface {
|
||||
AuthorizeFilter
|
||||
}
|
||||
|
||||
type expAnd struct {
|
||||
base
|
||||
Expressions []Expression
|
||||
}
|
||||
|
||||
func (t expAnd) SQLString(cfg SQLConfig) string {
|
||||
if len(t.Expressions) == 1 {
|
||||
return t.Expressions[0].SQLString(cfg)
|
||||
}
|
||||
|
||||
exprs := make([]string, 0, len(t.Expressions))
|
||||
for _, expr := range t.Expressions {
|
||||
exprs = append(exprs, expr.SQLString(cfg))
|
||||
}
|
||||
return "(" + strings.Join(exprs, " AND ") + ")"
|
||||
}
|
||||
|
||||
func (t expAnd) Eval(object Object) bool {
|
||||
for _, expr := range t.Expressions {
|
||||
if !expr.Eval(object) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type expOr struct {
|
||||
base
|
||||
Expressions []Expression
|
||||
}
|
||||
|
||||
func (t expOr) SQLString(cfg SQLConfig) string {
|
||||
if len(t.Expressions) == 1 {
|
||||
return t.Expressions[0].SQLString(cfg)
|
||||
}
|
||||
|
||||
exprs := make([]string, 0, len(t.Expressions))
|
||||
for _, expr := range t.Expressions {
|
||||
exprs = append(exprs, expr.SQLString(cfg))
|
||||
}
|
||||
return "(" + strings.Join(exprs, " OR ") + ")"
|
||||
}
|
||||
|
||||
func (t expOr) Eval(object Object) bool {
|
||||
for _, expr := range t.Expressions {
|
||||
if expr.Eval(object) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Operator joins terms together to form an expression.
|
||||
// Operators are also expressions.
|
||||
//
|
||||
// Eg: "=", "neq", "internal.member_2", etc.
|
||||
type Operator interface {
|
||||
Expression
|
||||
}
|
||||
|
||||
type opEqual struct {
|
||||
base
|
||||
Terms [2]Term
|
||||
// For NotEqual
|
||||
Not bool
|
||||
}
|
||||
|
||||
func (t opEqual) SQLString(cfg SQLConfig) string {
|
||||
op := "="
|
||||
if t.Not {
|
||||
op = "!="
|
||||
}
|
||||
return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg))
|
||||
}
|
||||
|
||||
func (t opEqual) Eval(object Object) bool {
|
||||
a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object)
|
||||
if t.Not {
|
||||
return a != b
|
||||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
// opInternalMember2 is checking if the first term is a member of the second term.
|
||||
// The second term is a set or list.
|
||||
type opInternalMember2 struct {
|
||||
base
|
||||
Needle Term
|
||||
Haystack Term
|
||||
}
|
||||
|
||||
func (t opInternalMember2) Eval(object Object) bool {
|
||||
a, b := t.Needle.EvalTerm(object), t.Haystack.EvalTerm(object)
|
||||
bset, ok := b.([]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, elem := range bset {
|
||||
if a == elem {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t opInternalMember2) SQLString(cfg SQLConfig) string {
|
||||
if haystack, ok := t.Haystack.(*termObject); ok {
|
||||
// This is a special case where the haystack is a jsonb array.
|
||||
// There is a more general way to solve this, but that requires a lot
|
||||
// more code to cover a lot more cases that we do not care about.
|
||||
// To handle this more generally we should implement "Array" as a type.
|
||||
// Then have the `contains` function on the Array type. This would defer
|
||||
// knowing the element type to the Array and cover more cases without
|
||||
// having to add more "if" branches here.
|
||||
// But until we need more cases, our basic type system is ok, and
|
||||
// this is the only case we need to handle.
|
||||
if haystack.SQLType(cfg) == VarTypeJsonbTextArray {
|
||||
return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg))
|
||||
}
|
||||
|
||||
// Term is a single value in an expression. Terms can be variables or constants.
|
||||
//
|
||||
// Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner",
|
||||
// "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}"
|
||||
type Term interface {
|
||||
RegoString() string
|
||||
SQLString(cfg SQLConfig) string
|
||||
// Eval will evaluate the term
|
||||
// Terms can eval to any type. The operator/expression will type check.
|
||||
EvalTerm(object Object) interface{}
|
||||
}
|
||||
|
||||
type termString struct {
|
||||
base
|
||||
Value string
|
||||
}
|
||||
|
||||
func (t termString) EvalTerm(_ Object) interface{} {
|
||||
return t.Value
|
||||
}
|
||||
|
||||
func (t termString) SQLString(_ SQLConfig) string {
|
||||
return "'" + t.Value + "'"
|
||||
}
|
||||
|
||||
func (termString) SQLType(_ SQLConfig) TermType {
|
||||
return VarTypeText
|
||||
}
|
||||
|
||||
// termObject is a variable that can be dereferenced. We count some rego objects
|
||||
// as single variables, eg: input.object.org_owner. In reality, it is a nested
|
||||
// object.
|
||||
// In rego, we can dereference the object with the "." operator, which we can
|
||||
// handle with regex.
|
||||
// Or we can dereference the object with the "[]", which we can handle with this
|
||||
// term type.
|
||||
type termObject struct {
|
||||
base
|
||||
Variables []termVariable
|
||||
}
|
||||
|
||||
func (t termObject) EvalTerm(obj Object) interface{} {
|
||||
if len(t.Variables) == 0 {
|
||||
return t.Variables[0].EvalTerm(obj)
|
||||
}
|
||||
panic("no nested structures are supported yet")
|
||||
}
|
||||
|
||||
func (t termObject) SQLType(cfg SQLConfig) TermType {
|
||||
// Without a full type system, let's just assume the type of the first var
|
||||
// is the resulting type. This is correct for our use case.
|
||||
// Solving this more generally requires a full type system, which is
|
||||
// excessive for our mostly static policy.
|
||||
return t.Variables[0].SQLType(cfg)
|
||||
}
|
||||
|
||||
func (t termObject) SQLString(cfg SQLConfig) string {
|
||||
if len(t.Variables) == 1 {
|
||||
return t.Variables[0].SQLString(cfg)
|
||||
}
|
||||
// Combine the last 2 variables into 1 variable.
|
||||
end := t.Variables[len(t.Variables)-1]
|
||||
before := t.Variables[len(t.Variables)-2]
|
||||
|
||||
// Recursively solve the SQLString by removing the last nested reference.
|
||||
// This continues until we have a single variable.
|
||||
return termObject{
|
||||
base: t.base,
|
||||
Variables: append(
|
||||
t.Variables[:len(t.Variables)-2],
|
||||
termVariable{
|
||||
base: base{
|
||||
Rego: before.base.Rego + "[" + end.base.Rego + "]",
|
||||
},
|
||||
// Convert the end to SQL string. We evaluate each term
|
||||
// one at a time.
|
||||
Name: before.Name + "." + end.SQLString(cfg),
|
||||
},
|
||||
),
|
||||
}.SQLString(cfg)
|
||||
}
|
||||
|
||||
type termVariable struct {
|
||||
base
|
||||
Name string
|
||||
}
|
||||
|
||||
func (t termVariable) EvalTerm(obj Object) interface{} {
|
||||
switch t.Name {
|
||||
case "input.object.org_owner":
|
||||
return obj.OrgID
|
||||
case "input.object.owner":
|
||||
return obj.Owner
|
||||
case "input.object.type":
|
||||
return obj.Type
|
||||
default:
|
||||
return fmt.Sprintf("'Unknown variable %s'", t.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (t termVariable) SQLType(cfg SQLConfig) TermType {
|
||||
if col := t.ColumnConfig(cfg); col != nil {
|
||||
return col.Type
|
||||
}
|
||||
return VarTypeText
|
||||
}
|
||||
|
||||
func (t termVariable) SQLString(cfg SQLConfig) string {
|
||||
if col := t.ColumnConfig(cfg); col != nil {
|
||||
matches := col.RegoMatch.FindStringSubmatch(t.Name)
|
||||
if len(matches) > 0 {
|
||||
// This config matches this variable.
|
||||
replace := make([]string, 0, len(matches)*2)
|
||||
for i, m := range matches {
|
||||
replace = append(replace, fmt.Sprintf("$%d", i))
|
||||
replace = append(replace, m)
|
||||
}
|
||||
replacer := strings.NewReplacer(replace...)
|
||||
return replacer.Replace(col.ColumnSelect)
|
||||
}
|
||||
}
|
||||
|
||||
return t.Name
|
||||
}
|
||||
|
||||
// ColumnConfig returns the correct SQLColumn settings for the
|
||||
// term. If there is no configured column, it will return nil.
|
||||
func (t termVariable) ColumnConfig(cfg SQLConfig) *SQLColumn {
|
||||
for _, col := range cfg.Variables {
|
||||
matches := col.RegoMatch.MatchString(t.Name)
|
||||
if matches {
|
||||
return &col
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// termSet is a set of unique terms.
|
||||
type termSet struct {
|
||||
base
|
||||
Value []Term
|
||||
}
|
||||
|
||||
func (t termSet) EvalTerm(obj Object) interface{} {
|
||||
set := make([]interface{}, 0, len(t.Value))
|
||||
for _, term := range t.Value {
|
||||
set = append(set, term.EvalTerm(obj))
|
||||
}
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
func (t termSet) SQLString(cfg SQLConfig) string {
|
||||
elems := make([]string, 0, len(t.Value))
|
||||
for _, v := range t.Value {
|
||||
elems = append(elems, v.SQLString(cfg))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ","))
|
||||
}
|
||||
|
||||
type termBoolean struct {
|
||||
base
|
||||
Value bool
|
||||
}
|
||||
|
||||
func (t termBoolean) Eval(_ Object) bool {
|
||||
return t.Value
|
||||
}
|
||||
|
||||
func (t termBoolean) EvalTerm(_ Object) interface{} {
|
||||
return t.Value
|
||||
}
|
||||
|
||||
func (t termBoolean) SQLString(_ SQLConfig) string {
|
||||
return strconv.FormatBool(t.Value)
|
||||
}
|
||||
|
||||
func trimQuotes(s string) string {
|
||||
return strings.Trim(s, "\"")
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
package rbac
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompileQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := ast.ParserOptions{
|
||||
AllFutureKeywords: true,
|
||||
}
|
||||
t.Run("EmptyQuery", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(®o.PartialQueries{
|
||||
Queries: []ast.Body{
|
||||
must(ast.ParseBody("")),
|
||||
},
|
||||
Support: []*ast.Module{},
|
||||
})
|
||||
require.NoError(t, err, "compile empty")
|
||||
|
||||
require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'")
|
||||
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'")
|
||||
})
|
||||
|
||||
t.Run("TrueQuery", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(®o.PartialQueries{
|
||||
Queries: []ast.Body{
|
||||
must(ast.ParseBody("true")),
|
||||
},
|
||||
Support: []*ast.Module{},
|
||||
})
|
||||
require.NoError(t, err, "compile")
|
||||
|
||||
require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'")
|
||||
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'")
|
||||
})
|
||||
|
||||
t.Run("ACLIn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(®o.PartialQueries{
|
||||
Queries: []ast.Body{
|
||||
ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list.allUsers`, opts),
|
||||
},
|
||||
Support: []*ast.Module{},
|
||||
})
|
||||
require.NoError(t, err, "compile")
|
||||
|
||||
require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member")
|
||||
require.Equal(t, `group_acl->allUsers ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in")
|
||||
})
|
||||
|
||||
t.Run("Complex", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(®o.PartialQueries{
|
||||
Queries: []ast.Body{
|
||||
ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts),
|
||||
ast.MustParseBodyWithOpts(`input.object.org_owner in {"a", "b", "c"}`, opts),
|
||||
ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts),
|
||||
ast.MustParseBodyWithOpts(`"read" in input.object.acl_group_list.allUsers`, opts),
|
||||
ast.MustParseBodyWithOpts(`"read" in input.object.acl_user_list.me`, opts),
|
||||
},
|
||||
Support: []*ast.Module{},
|
||||
})
|
||||
require.NoError(t, err, "compile")
|
||||
require.Equal(t, `(organization_id :: text != '' OR `+
|
||||
`organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+
|
||||
`organization_id :: text != '' OR `+
|
||||
`group_acl->allUsers ? 'read' OR `+
|
||||
`user_acl->me ? 'read')`,
|
||||
expression.SQLString(DefaultConfig()), "complex")
|
||||
})
|
||||
|
||||
t.Run("SetDereference", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(®o.PartialQueries{
|
||||
Queries: []ast.Body{
|
||||
ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list[input.object.org_owner]`, opts),
|
||||
},
|
||||
Support: []*ast.Module{},
|
||||
})
|
||||
require.NoError(t, err, "compile")
|
||||
require.Equal(t, `group_acl->organization_id :: text ? '*'`,
|
||||
expression.SQLString(DefaultConfig()), "set dereference")
|
||||
})
|
||||
}
|
|
@ -113,17 +113,16 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
|
|||
filter.OwnerUsername = ""
|
||||
}
|
||||
|
||||
workspaces, err := api.Database.GetWorkspaces(ctx, filter)
|
||||
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspaces.",
|
||||
Message: "Internal error preparing sql filter.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Only return workspaces the user can read
|
||||
workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces)
|
||||
workspaces, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspaces.",
|
||||
|
|
Loading…
Reference in New Issue