From af59e2bcfa524dd0f93a186b69a8e43d2a31702e Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 9 Feb 2023 13:47:17 -0600 Subject: [PATCH] chore: Optimize rego policy input allocations (#6135) * chore: Optimize rego policy evaluation allocations Manually convert to ast.Value instead of using generic json.Marshal conversion. * Add a unit test that prevents regressions of rego input The optimized input is always compared to the normal json marshal parser. --- coderd/coderdtest/authorize.go | 2 +- coderd/rbac/astvalue.go | 210 +++++++++++++++++++++++++++ coderd/rbac/authz.go | 26 +--- coderd/rbac/builtin.go | 18 +++ coderd/rbac/builtin_internal_test.go | 184 ++++++++++++++++++++++- coderd/rbac/error.go | 25 +++- coderd/rbac/input.json | 16 +- coderd/rbac/partial.go | 43 +++--- 8 files changed, 466 insertions(+), 58 deletions(-) create mode 100644 coderd/rbac/astvalue.go diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 3fcbb63cb6..ac7c6f5ce4 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -445,7 +445,7 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) { // Always fail auth from this point forward a.authorizer.Wrapped = &FakeAuthorizer{ - AlwaysReturn: rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil), + AlwaysReturn: rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), rbac.Subject{}, "", rbac.Object{}, nil), } routeMissing := make(map[string]bool) diff --git a/coderd/rbac/astvalue.go b/coderd/rbac/astvalue.go new file mode 100644 index 0000000000..2feb20d636 --- /dev/null +++ b/coderd/rbac/astvalue.go @@ -0,0 +1,210 @@ +package rbac + +import ( + "github.com/open-policy-agent/opa/ast" + "golang.org/x/xerrors" +) + +// regoInputValue returns a rego input value for the given subject, action, and +// object. This rego input is already parsed and can be used directly in a +// rego query. +func regoInputValue(subject Subject, action Action, object Object) (ast.Value, error) { + regoSubj, err := subject.regoValue() + if err != nil { + return nil, xerrors.Errorf("subject: %w", err) + } + + s := [2]*ast.Term{ + ast.StringTerm("subject"), + ast.NewTerm(regoSubj), + } + a := [2]*ast.Term{ + ast.StringTerm("action"), + ast.StringTerm(string(action)), + } + o := [2]*ast.Term{ + ast.StringTerm("object"), + ast.NewTerm(object.regoValue()), + } + + input := ast.NewObject(s, a, o) + + return input, nil +} + +// regoPartialInputValue is the same as regoInputValue but only includes the +// object type. This is for partial evaluations. +func regoPartialInputValue(subject Subject, action Action, objectType string) (ast.Value, error) { + regoSubj, err := subject.regoValue() + if err != nil { + return nil, xerrors.Errorf("subject: %w", err) + } + + s := [2]*ast.Term{ + ast.StringTerm("subject"), + ast.NewTerm(regoSubj), + } + a := [2]*ast.Term{ + ast.StringTerm("action"), + ast.StringTerm(string(action)), + } + o := [2]*ast.Term{ + ast.StringTerm("object"), + ast.NewTerm(ast.NewObject( + [2]*ast.Term{ + ast.StringTerm("type"), + ast.StringTerm(objectType), + }), + ), + } + + input := ast.NewObject(s, a, o) + + return input, nil +} + +// regoValue returns the ast.Object representation of the subject. +func (s Subject) regoValue() (ast.Value, error) { + subjRoles, err := s.Roles.Expand() + if err != nil { + return nil, xerrors.Errorf("expand roles: %w", err) + } + + subjScope, err := s.Scope.Expand() + if err != nil { + return nil, xerrors.Errorf("expand scope: %w", err) + } + subj := ast.NewObject( + [2]*ast.Term{ + ast.StringTerm("id"), + ast.StringTerm(s.ID), + }, + [2]*ast.Term{ + ast.StringTerm("roles"), + ast.NewTerm(regoSlice(subjRoles)), + }, + [2]*ast.Term{ + ast.StringTerm("scope"), + ast.NewTerm(subjScope.regoValue()), + }, + [2]*ast.Term{ + ast.StringTerm("groups"), + ast.NewTerm(regoSliceString(s.Groups...)), + }, + ) + + return subj, nil +} + +func (z Object) regoValue() ast.Value { + userACL := ast.NewObject() + for k, v := range z.ACLUserList { + userACL.Insert(ast.StringTerm(k), ast.NewTerm(regoSlice(v))) + } + grpACL := ast.NewObject() + for k, v := range z.ACLGroupList { + grpACL.Insert(ast.StringTerm(k), ast.NewTerm(regoSlice(v))) + } + return ast.NewObject( + [2]*ast.Term{ + ast.StringTerm("id"), + ast.StringTerm(z.ID), + }, + [2]*ast.Term{ + ast.StringTerm("owner"), + ast.StringTerm(z.Owner), + }, + [2]*ast.Term{ + ast.StringTerm("org_owner"), + ast.StringTerm(z.OrgID), + }, + [2]*ast.Term{ + ast.StringTerm("type"), + ast.StringTerm(z.Type), + }, + [2]*ast.Term{ + ast.StringTerm("acl_user_list"), + ast.NewTerm(userACL), + }, + [2]*ast.Term{ + ast.StringTerm("acl_group_list"), + ast.NewTerm(grpACL), + }, + ) +} + +func (role Role) regoValue() ast.Value { + orgMap := ast.NewObject() + for k, p := range role.Org { + orgMap.Insert(ast.StringTerm(k), ast.NewTerm(regoSlice(p))) + } + return ast.NewObject( + [2]*ast.Term{ + ast.StringTerm("site"), + ast.NewTerm(regoSlice(role.Site)), + }, + [2]*ast.Term{ + ast.StringTerm("org"), + ast.NewTerm(orgMap), + }, + [2]*ast.Term{ + ast.StringTerm("user"), + ast.NewTerm(regoSlice(role.User)), + }, + ) +} + +func (s Scope) regoValue() ast.Value { + r, ok := s.Role.regoValue().(ast.Object) + if !ok { + panic("developer error: role is not an object") + } + r.Insert( + ast.StringTerm("allow_list"), + ast.NewTerm(regoSliceString(s.AllowIDList...)), + ) + return r +} + +func (perm Permission) regoValue() ast.Value { + return ast.NewObject( + [2]*ast.Term{ + ast.StringTerm("negate"), + ast.BooleanTerm(perm.Negate), + }, + [2]*ast.Term{ + ast.StringTerm("resource_type"), + ast.StringTerm(perm.ResourceType), + }, + [2]*ast.Term{ + ast.StringTerm("action"), + ast.StringTerm(string(perm.Action)), + }, + ) +} + +func (act Action) regoValue() ast.Value { + return ast.StringTerm(string(act)).Value +} + +type regoValue interface { + regoValue() ast.Value +} + +// regoSlice returns the ast.Array representation of the slice. +// The slice must contain only types that implement the regoValue interface. +func regoSlice[T regoValue](slice []T) *ast.Array { + terms := make([]*ast.Term, len(slice)) + for i, v := range slice { + terms[i] = ast.NewTerm(v.regoValue()) + } + return ast.NewArray(terms...) +} + +func regoSliceString(slice ...string) *ast.Array { + terms := make([]*ast.Term, len(slice)) + for i, v := range slice { + terms[i] = ast.StringTerm(v) + } + return ast.NewArray(terms...) +} diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index a15270b59b..e718248f1a 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -263,34 +263,18 @@ func (a RegoAuthorizer) authorize(ctx context.Context, subject Subject, action A return xerrors.Errorf("subject must have a scope") } - subjRoles, err := subject.Roles.Expand() + astV, err := regoInputValue(subject, action, object) if err != nil { - return xerrors.Errorf("expand roles: %w", err) + return xerrors.Errorf("convert input to value: %w", err) } - subjScope, err := subject.Scope.Expand() + results, err := a.query.Eval(ctx, rego.EvalParsedInput(astV)) if err != nil { - return xerrors.Errorf("expand scope: %w", err) - } - - input := map[string]interface{}{ - "subject": authSubject{ - ID: subject.ID, - Roles: subjRoles, - Groups: subject.Groups, - Scope: subjScope, - }, - "object": object, - "action": action, - } - - results, err := a.query.Eval(ctx, rego.EvalInput(input)) - if err != nil { - return ForbiddenWithInternal(xerrors.Errorf("eval rego: %w", err), input, results) + return ForbiddenWithInternal(xerrors.Errorf("eval rego: %w", err), subject, action, object, results) } if !results.Allowed() { - return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), input, results) + return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), subject, action, object, results) } return nil } diff --git a/coderd/rbac/builtin.go b/coderd/rbac/builtin.go index 4839812459..c5b0396629 100644 --- a/coderd/rbac/builtin.go +++ b/coderd/rbac/builtin.go @@ -1,6 +1,7 @@ package rbac import ( + "sort" "strings" "github.com/google/uuid" @@ -79,6 +80,8 @@ var ( Site: permissions(map[string][]Action{ ResourceWildcard.Type: {WildcardSymbol}, }), + Org: map[string][]Permission{}, + User: []Permission{}, } }, @@ -94,6 +97,7 @@ var ( // All users can see the provisioner daemons. ResourceProvisionerDaemon.Type: {ActionRead}, }), + Org: map[string][]Permission{}, User: permissions(map[string][]Action{ ResourceWildcard.Type: {WildcardSymbol}, }), @@ -113,6 +117,8 @@ var ( ResourceTemplate.Type: {ActionRead}, ResourceAuditLog.Type: {ActionRead}, }), + Org: map[string][]Permission{}, + User: []Permission{}, } }, @@ -128,6 +134,8 @@ var ( // CRUD to provisioner daemons for now. ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete}, }), + Org: map[string][]Permission{}, + User: []Permission{}, } }, @@ -142,6 +150,8 @@ var ( ResourceOrganizationMember.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete}, ResourceGroup.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete}, }), + Org: map[string][]Permission{}, + User: []Permission{}, } }, @@ -151,6 +161,7 @@ var ( return Role{ Name: roleName(orgAdmin, organizationID), DisplayName: "Organization Admin", + Site: []Permission{}, Org: map[string][]Permission{ organizationID: { { @@ -160,6 +171,7 @@ var ( }, }, }, + User: []Permission{}, } }, @@ -169,6 +181,7 @@ var ( return Role{ Name: roleName(orgMember, organizationID), DisplayName: "", + Site: []Permission{}, Org: map[string][]Permission{ organizationID: { { @@ -192,6 +205,7 @@ var ( }, }, }, + User: []Permission{}, } }, } @@ -422,5 +436,9 @@ func permissions(perms map[string][]Action) []Permission { }) } } + // Deterministic ordering of permissions + sort.Slice(list, func(i, j int) bool { + return list[i].ResourceType < list[j].ResourceType + }) return list } diff --git a/coderd/rbac/builtin_internal_test.go b/coderd/rbac/builtin_internal_test.go index 0921cb361a..4c86a71356 100644 --- a/coderd/rbac/builtin_internal_test.go +++ b/coderd/rbac/builtin_internal_test.go @@ -3,11 +3,191 @@ package rbac import ( "testing" - "github.com/stretchr/testify/require" - "github.com/google/uuid" + "github.com/open-policy-agent/opa/ast" + "github.com/stretchr/testify/require" ) +// BenchmarkRBACValueAllocation benchmarks the cost of allocating a rego input +// value. By default, `ast.InterfaceToValue` is used to convert the input, +// which uses json marshalling under the hood. +// +// Currently ast.Object.insert() is the slowest part of the process and allocates +// the most amount of bytes. This general approach copies all of our struct +// data and uses a lot of extra memory for handling things like sort order. +// A possible large improvement would be to implement the ast.Value interface directly. +func BenchmarkRBACValueAllocation(b *testing.B) { + actor := Subject{ + Roles: RoleNames{RoleOrgMember(uuid.New()), RoleOrgAdmin(uuid.New()), RoleMember()}, + ID: uuid.NewString(), + Scope: ScopeAll, + Groups: []string{uuid.NewString(), uuid.NewString(), uuid.NewString()}, + } + obj := ResourceTemplate. + WithID(uuid.New()). + InOrg(uuid.New()). + WithOwner(uuid.NewString()). + WithGroupACL(map[string][]Action{ + uuid.NewString(): {ActionRead, ActionCreate}, + uuid.NewString(): {ActionRead, ActionCreate}, + uuid.NewString(): {ActionRead, ActionCreate}, + }).WithACLUserList(map[string][]Action{ + uuid.NewString(): {ActionRead, ActionCreate}, + uuid.NewString(): {ActionRead, ActionCreate}, + }) + + jsonSubject := authSubject{ + ID: actor.ID, + Roles: must(actor.Roles.Expand()), + Groups: actor.Groups, + Scope: must(actor.Scope.Expand()), + } + + b.Run("ManualRegoValue", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := regoInputValue(actor, ActionRead, obj) + require.NoError(b, err) + } + }) + b.Run("JSONRegoValue", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ast.InterfaceToValue(jsonSubject) + require.NoError(b, err) + } + }) +} + +// TestRegoInputValue ensures the custom rego input parser returns the +// same value as the default json parser. The json parser is always correct, +// and the custom parser is used to reduce allocations. This optimization +// should yield the same results. Anything different is a bug. +func TestRegoInputValue(t *testing.T) { + t.Parallel() + + actor := Subject{ + Roles: RoleNames{RoleOrgMember(uuid.New()), RoleOrgAdmin(uuid.New()), RoleMember()}, + ID: uuid.NewString(), + Scope: ScopeAll, + Groups: []string{uuid.NewString(), uuid.NewString(), uuid.NewString()}, + } + + obj := ResourceTemplate. + WithID(uuid.New()). + InOrg(uuid.New()). + WithOwner(uuid.NewString()). + WithGroupACL(map[string][]Action{ + uuid.NewString(): {ActionRead, ActionCreate}, + uuid.NewString(): {ActionRead, ActionCreate}, + uuid.NewString(): {ActionRead, ActionCreate}, + }).WithACLUserList(map[string][]Action{ + uuid.NewString(): {ActionRead, ActionCreate}, + uuid.NewString(): {ActionRead, ActionCreate}, + }) + + action := ActionRead + + t.Run("InputValue", func(t *testing.T) { + t.Parallel() + + // This is the input that would be passed to the rego policy. + jsonInput := map[string]interface{}{ + "subject": authSubject{ + ID: actor.ID, + Roles: must(actor.Roles.Expand()), + Groups: actor.Groups, + Scope: must(actor.Scope.Expand()), + }, + "action": action, + "object": obj, + } + + manual, err := regoInputValue(actor, action, obj) + require.NoError(t, err) + + general, err := ast.InterfaceToValue(jsonInput) + require.NoError(t, err) + + // The custom parser does not set these fields because they are not needed. + // To ensure the outputs are identical, intentionally overwrite all names + // to the same values. + ignoreNames(t, manual) + ignoreNames(t, general) + + cmp := manual.Compare(general) + require.Equal(t, 0, cmp, "manual and general input values should be equal") + }) + + t.Run("PartialInputValue", func(t *testing.T) { + t.Parallel() + + // This is the input that would be passed to the rego policy. + jsonInput := map[string]interface{}{ + "subject": authSubject{ + ID: actor.ID, + Roles: must(actor.Roles.Expand()), + Groups: actor.Groups, + Scope: must(actor.Scope.Expand()), + }, + "action": action, + "object": map[string]interface{}{ + "type": obj.Type, + }, + } + + manual, err := regoPartialInputValue(actor, action, obj.Type) + require.NoError(t, err) + + general, err := ast.InterfaceToValue(jsonInput) + require.NoError(t, err) + + // The custom parser does not set these fields because they are not needed. + // To ensure the outputs are identical, intentionally overwrite all names + // to the same values. + ignoreNames(t, manual) + ignoreNames(t, general) + + cmp := manual.Compare(general) + require.Equal(t, 0, cmp, "manual and general input values should be equal") + }) +} + +// ignoreNames sets all names to "ignore" to ensure the values are identical. +func ignoreNames(t *testing.T, value ast.Value) { + t.Helper() + + // Override the names of the roles + ref := ast.Ref{ + ast.StringTerm("subject"), + ast.StringTerm("roles"), + } + roles, err := value.Find(ref) + require.NoError(t, err) + + rolesArray, ok := roles.(*ast.Array) + require.True(t, ok, "roles is expected to be an array") + + rolesArray.Foreach(func(term *ast.Term) { + obj, _ := term.Value.(ast.Object) + // Ignore all names + obj.Insert(ast.StringTerm("name"), ast.StringTerm("ignore")) + obj.Insert(ast.StringTerm("display_name"), ast.StringTerm("ignore")) + }) + + // Override the names of the scope role + ref = ast.Ref{ + ast.StringTerm("subject"), + ast.StringTerm("scope"), + } + scope, err := value.Find(ref) + require.NoError(t, err) + + scopeObj, ok := scope.(ast.Object) + require.True(t, ok, "scope is expected to be an object") + + scopeObj.Insert(ast.StringTerm("name"), ast.StringTerm("ignore")) + scopeObj.Insert(ast.StringTerm("display_name"), ast.StringTerm("ignore")) +} + func TestRoleByName(t *testing.T) { t.Parallel() diff --git a/coderd/rbac/error.go b/coderd/rbac/error.go index 6b63bb8860..ec0bf02f8f 100644 --- a/coderd/rbac/error.go +++ b/coderd/rbac/error.go @@ -14,20 +14,25 @@ type UnauthorizedError struct { // internal is the internal error that should never be shown to the client. // It is only for debugging purposes. internal error - input map[string]interface{} - output rego.ResultSet + + // These fields are for debugging purposes. + subject Subject + action Action + // Note only the object type is set for partial execution. + object Object + + output rego.ResultSet } // ForbiddenWithInternal creates a new error that will return a simple // "forbidden" to the client, logging internally the more detailed message // provided. -func ForbiddenWithInternal(internal error, input map[string]interface{}, output rego.ResultSet) *UnauthorizedError { - if input == nil { - input = map[string]interface{}{} - } +func ForbiddenWithInternal(internal error, subject Subject, action Action, object Object, output rego.ResultSet) *UnauthorizedError { return &UnauthorizedError{ internal: internal, - input: input, + subject: subject, + action: action, + object: object, output: output, } } @@ -43,7 +48,11 @@ func (e *UnauthorizedError) Internal() error { } func (e *UnauthorizedError) Input() map[string]interface{} { - return e.input + return map[string]interface{}{ + "subject": e.subject, + "action": e.action, + "object": e.object, + } } // Output contains the results of the Rego query for debugging. diff --git a/coderd/rbac/input.json b/coderd/rbac/input.json index f762de96ba..5e464168ac 100644 --- a/coderd/rbac/input.json +++ b/coderd/rbac/input.json @@ -12,7 +12,21 @@ }, "subject": { "id": "10d03e62-7703-4df5-a358-4f76577d4e2f", - "roles": [], + "roles": [ + { + "name": "owner", + "display_name": "Owner", + "site": [ + { + "negate": false, + "resource_type": "*", + "action": "*" + } + ], + "org": {}, + "user": [] + } + ], "groups": ["b617a647-b5d0-4cbe-9e40-26f89710bf18"], "scope": { "name": "Scope_all", diff --git a/coderd/rbac/partial.go b/coderd/rbac/partial.go index 19ee0a6c80..6347651409 100644 --- a/coderd/rbac/partial.go +++ b/coderd/rbac/partial.go @@ -17,8 +17,12 @@ type PartialAuthorizer struct { // partialQueries is mainly used for unit testing to assert our rego policy // can always be compressed into a set of queries. partialQueries *rego.PartialQueries + // input is used purely for debugging and logging. - input map[string]interface{} + subjectInput Subject + subjectAction Action + subjectResourceType Object + // preparedQueries are the compiled set of queries after partial evaluation. // Cache these prepared queries to avoid re-compiling the queries. // If alwaysTrue is true, then ignore these. @@ -54,7 +58,8 @@ func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error // If we have no queries, then no queries can return 'true'. // So the result is always 'false'. if len(pa.preparedQueries) == 0 { - return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), pa.input, nil) + return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), + pa.subjectInput, pa.subjectAction, pa.subjectResourceType, nil) } parsed, err := ast.InterfaceToValue(map[string]interface{}{ @@ -118,7 +123,8 @@ EachQueryLoop: return nil } - return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), pa.input, nil) + return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), + pa.subjectInput, pa.subjectAction, pa.subjectResourceType, nil) } func newPartialAuthorizer(ctx context.Context, subject Subject, action Action, objectType string) (*PartialAuthorizer, error) { @@ -129,27 +135,9 @@ func newPartialAuthorizer(ctx context.Context, subject Subject, action Action, o return nil, xerrors.Errorf("subject must have a scope") } - roles, err := subject.Roles.Expand() + input, err := regoPartialInputValue(subject, action, objectType) if err != nil { - return nil, xerrors.Errorf("expand roles: %w", err) - } - - scope, err := subject.Scope.Expand() - if err != nil { - return nil, xerrors.Errorf("expand scope: %w", err) - } - - input := map[string]interface{}{ - "subject": authSubject{ - ID: subject.ID, - Roles: roles, - Scope: scope, - Groups: subject.Groups, - }, - "object": map[string]string{ - "type": objectType, - }, - "action": action, + return nil, xerrors.Errorf("prepare input: %w", err) } // Run the rego policy with a few unknown fields. This should simplify our @@ -164,7 +152,7 @@ func newPartialAuthorizer(ctx context.Context, subject Subject, action Action, o "input.object.acl_user_list", "input.object.acl_group_list", }), - rego.Input(input), + rego.ParsedInput(input), ).Partial(ctx) if err != nil { return nil, xerrors.Errorf("prepare: %w", err) @@ -173,7 +161,12 @@ func newPartialAuthorizer(ctx context.Context, subject Subject, action Action, o pAuth := &PartialAuthorizer{ partialQueries: partialQueries, preparedQueries: []rego.PreparedEvalQuery{}, - input: input, + subjectInput: subject, + subjectResourceType: Object{ + Type: objectType, + ID: "prepared-object", + }, + subjectAction: action, } // Prepare each query to optimize the runtime when we iterate over the objects.