chore: Optimize rego policy input allocations (#6135)

* chore: Optimize rego policy evaluation allocations

Manually convert to ast.Value instead of using generic
json.Marshal conversion.

* Add a unit test that guards against regressions in the rego input

The optimized input is always compared against the value produced by the
generic JSON marshal conversion.
This commit is contained in:
Steven Masley 2023-02-09 13:47:17 -06:00 committed by GitHub
parent 22f6400ea5
commit af59e2bcfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 466 additions and 58 deletions

View File

@ -445,7 +445,7 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a
func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) {
// Always fail auth from this point forward
a.authorizer.Wrapped = &FakeAuthorizer{
AlwaysReturn: rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil),
AlwaysReturn: rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), rbac.Subject{}, "", rbac.Object{}, nil),
}
routeMissing := make(map[string]bool)

210
coderd/rbac/astvalue.go Normal file
View File

@ -0,0 +1,210 @@
package rbac
import (
"github.com/open-policy-agent/opa/ast"
"golang.org/x/xerrors"
)
// regoInputValue builds the rego input for a full authorization check of the
// given subject, action, and object. The returned value is already parsed, so
// it can be passed to a rego query as-is. It holds three keys: "subject",
// "action" and "object".
//
// An error is returned only when the subject cannot be converted.
func regoInputValue(subject Subject, action Action, object Object) (ast.Value, error) {
	subjectValue, err := subject.regoValue()
	if err != nil {
		return nil, xerrors.Errorf("subject: %w", err)
	}

	return ast.NewObject(
		[2]*ast.Term{ast.StringTerm("subject"), ast.NewTerm(subjectValue)},
		[2]*ast.Term{ast.StringTerm("action"), ast.StringTerm(string(action))},
		[2]*ast.Term{ast.StringTerm("object"), ast.NewTerm(object.regoValue())},
	), nil
}
// regoPartialInputValue builds the rego input used for partial evaluations.
// It mirrors regoInputValue, except that the "object" key only carries the
// object's "type"; no other object fields are set.
//
// An error is returned only when the subject cannot be converted.
func regoPartialInputValue(subject Subject, action Action, objectType string) (ast.Value, error) {
	subjectValue, err := subject.regoValue()
	if err != nil {
		return nil, xerrors.Errorf("subject: %w", err)
	}

	// The object is reduced to its type alone.
	typeOnly := ast.NewObject(
		[2]*ast.Term{ast.StringTerm("type"), ast.StringTerm(objectType)},
	)

	return ast.NewObject(
		[2]*ast.Term{ast.StringTerm("subject"), ast.NewTerm(subjectValue)},
		[2]*ast.Term{ast.StringTerm("action"), ast.StringTerm(string(action))},
		[2]*ast.Term{ast.StringTerm("object"), ast.NewTerm(typeOnly)},
	), nil
}
// regoValue converts the subject into its ast.Object form with the keys "id",
// "roles", "scope" and "groups". The subject's roles and scope are expanded
// first; a failure of either expansion is returned as a wrapped error.
func (s Subject) regoValue() (ast.Value, error) {
	expandedRoles, err := s.Roles.Expand()
	if err != nil {
		return nil, xerrors.Errorf("expand roles: %w", err)
	}

	expandedScope, err := s.Scope.Expand()
	if err != nil {
		return nil, xerrors.Errorf("expand scope: %w", err)
	}

	rolesTerm := ast.NewTerm(regoSlice(expandedRoles))
	scopeTerm := ast.NewTerm(expandedScope.regoValue())
	groupsTerm := ast.NewTerm(regoSliceString(s.Groups...))

	return ast.NewObject(
		[2]*ast.Term{ast.StringTerm("id"), ast.StringTerm(s.ID)},
		[2]*ast.Term{ast.StringTerm("roles"), rolesTerm},
		[2]*ast.Term{ast.StringTerm("scope"), scopeTerm},
		[2]*ast.Term{ast.StringTerm("groups"), groupsTerm},
	), nil
}
// regoValue converts the object into its ast.Object form. The user and group
// ACL maps each become a nested object keyed by the original map key, with the
// map value converted through regoSlice.
func (z Object) regoValue() ast.Value {
	userList := ast.NewObject()
	for key, entries := range z.ACLUserList {
		userList.Insert(ast.StringTerm(key), ast.NewTerm(regoSlice(entries)))
	}

	groupList := ast.NewObject()
	for key, entries := range z.ACLGroupList {
		groupList.Insert(ast.StringTerm(key), ast.NewTerm(regoSlice(entries)))
	}

	fields := [][2]*ast.Term{
		{ast.StringTerm("id"), ast.StringTerm(z.ID)},
		{ast.StringTerm("owner"), ast.StringTerm(z.Owner)},
		{ast.StringTerm("org_owner"), ast.StringTerm(z.OrgID)},
		{ast.StringTerm("type"), ast.StringTerm(z.Type)},
		{ast.StringTerm("acl_user_list"), ast.NewTerm(userList)},
		{ast.StringTerm("acl_group_list"), ast.NewTerm(groupList)},
	}
	return ast.NewObject(fields...)
}
// regoValue converts the role into an ast.Object with the keys "site", "org"
// and "user". The "org" key holds a nested object with one entry per key of
// role.Org. The role's name fields are not part of the result.
func (role Role) regoValue() ast.Value {
	orgPerms := ast.NewObject()
	for key, perms := range role.Org {
		orgPerms.Insert(ast.StringTerm(key), ast.NewTerm(regoSlice(perms)))
	}

	sitePerms := regoSlice(role.Site)
	userPerms := regoSlice(role.User)

	return ast.NewObject(
		[2]*ast.Term{ast.StringTerm("site"), ast.NewTerm(sitePerms)},
		[2]*ast.Term{ast.StringTerm("org"), ast.NewTerm(orgPerms)},
		[2]*ast.Term{ast.StringTerm("user"), ast.NewTerm(userPerms)},
	)
}
// regoValue converts the scope into an ast.Object. The result is the scope's
// role object with one extra key, "allow_list", holding the scope's
// AllowIDList as an array of strings.
//
// It panics if the role's regoValue is not an ast.Object, which would be a
// programming mistake rather than a runtime condition.
func (s Scope) regoValue() ast.Value {
	roleObj, isObject := s.Role.regoValue().(ast.Object)
	if !isObject {
		panic("developer error: role is not an object")
	}

	allowList := ast.NewTerm(regoSliceString(s.AllowIDList...))
	roleObj.Insert(ast.StringTerm("allow_list"), allowList)
	return roleObj
}
// regoValue converts the permission into an ast.Object with the keys
// "negate" (boolean), "resource_type" (string) and "action" (string).
func (perm Permission) regoValue() ast.Value {
	fields := [][2]*ast.Term{
		{ast.StringTerm("negate"), ast.BooleanTerm(perm.Negate)},
		{ast.StringTerm("resource_type"), ast.StringTerm(perm.ResourceType)},
		{ast.StringTerm("action"), ast.StringTerm(string(perm.Action))},
	}
	return ast.NewObject(fields...)
}
// regoValue returns the action as a rego string value.
//
// The value is created with ast.String directly instead of unwrapping an
// ast.StringTerm, which would build a *ast.Term only to throw it away.
func (act Action) regoValue() ast.Value {
	return ast.String(act)
}
// regoValue is implemented by types that can convert themselves into an
// ast.Value. It is the type constraint accepted by regoSlice.
type regoValue interface {
	// regoValue returns the ast.Value representation of the receiver.
	regoValue() ast.Value
}
// regoSlice converts a slice into an *ast.Array by calling regoValue on every
// element, preserving element order. Element types must satisfy the regoValue
// interface.
func regoSlice[T regoValue](slice []T) *ast.Array {
	converted := make([]*ast.Term, 0, len(slice))
	for _, item := range slice {
		converted = append(converted, ast.NewTerm(item.regoValue()))
	}
	return ast.NewArray(converted...)
}
// regoSliceString converts the given strings into an *ast.Array of string
// terms, preserving their order.
func regoSliceString(slice ...string) *ast.Array {
	converted := make([]*ast.Term, 0, len(slice))
	for idx := range slice {
		converted = append(converted, ast.StringTerm(slice[idx]))
	}
	return ast.NewArray(converted...)
}

View File

@ -263,34 +263,18 @@ func (a RegoAuthorizer) authorize(ctx context.Context, subject Subject, action A
return xerrors.Errorf("subject must have a scope")
}
subjRoles, err := subject.Roles.Expand()
astV, err := regoInputValue(subject, action, object)
if err != nil {
return xerrors.Errorf("expand roles: %w", err)
return xerrors.Errorf("convert input to value: %w", err)
}
subjScope, err := subject.Scope.Expand()
results, err := a.query.Eval(ctx, rego.EvalParsedInput(astV))
if err != nil {
return xerrors.Errorf("expand scope: %w", err)
}
input := map[string]interface{}{
"subject": authSubject{
ID: subject.ID,
Roles: subjRoles,
Groups: subject.Groups,
Scope: subjScope,
},
"object": object,
"action": action,
}
results, err := a.query.Eval(ctx, rego.EvalInput(input))
if err != nil {
return ForbiddenWithInternal(xerrors.Errorf("eval rego: %w", err), input, results)
return ForbiddenWithInternal(xerrors.Errorf("eval rego: %w", err), subject, action, object, results)
}
if !results.Allowed() {
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), input, results)
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), subject, action, object, results)
}
return nil
}

View File

@ -1,6 +1,7 @@
package rbac
import (
"sort"
"strings"
"github.com/google/uuid"
@ -79,6 +80,8 @@ var (
Site: permissions(map[string][]Action{
ResourceWildcard.Type: {WildcardSymbol},
}),
Org: map[string][]Permission{},
User: []Permission{},
}
},
@ -94,6 +97,7 @@ var (
// All users can see the provisioner daemons.
ResourceProvisionerDaemon.Type: {ActionRead},
}),
Org: map[string][]Permission{},
User: permissions(map[string][]Action{
ResourceWildcard.Type: {WildcardSymbol},
}),
@ -113,6 +117,8 @@ var (
ResourceTemplate.Type: {ActionRead},
ResourceAuditLog.Type: {ActionRead},
}),
Org: map[string][]Permission{},
User: []Permission{},
}
},
@ -128,6 +134,8 @@ var (
// CRUD to provisioner daemons for now.
ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
}),
Org: map[string][]Permission{},
User: []Permission{},
}
},
@ -142,6 +150,8 @@ var (
ResourceOrganizationMember.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
ResourceGroup.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
}),
Org: map[string][]Permission{},
User: []Permission{},
}
},
@ -151,6 +161,7 @@ var (
return Role{
Name: roleName(orgAdmin, organizationID),
DisplayName: "Organization Admin",
Site: []Permission{},
Org: map[string][]Permission{
organizationID: {
{
@ -160,6 +171,7 @@ var (
},
},
},
User: []Permission{},
}
},
@ -169,6 +181,7 @@ var (
return Role{
Name: roleName(orgMember, organizationID),
DisplayName: "",
Site: []Permission{},
Org: map[string][]Permission{
organizationID: {
{
@ -192,6 +205,7 @@ var (
},
},
},
User: []Permission{},
}
},
}
@ -422,5 +436,9 @@ func permissions(perms map[string][]Action) []Permission {
})
}
}
// Deterministic ordering of permissions
sort.Slice(list, func(i, j int) bool {
return list[i].ResourceType < list[j].ResourceType
})
return list
}

View File

@ -3,11 +3,191 @@ package rbac
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/google/uuid"
"github.com/open-policy-agent/opa/ast"
"github.com/stretchr/testify/require"
)
// BenchmarkRBACValueAllocation benchmarks the cost of allocating a rego input
// value. By default, `ast.InterfaceToValue` is used to convert the input,
// which uses json marshalling under the hood.
//
// Both sub-benchmarks convert the same complete input (subject, action and
// object), so their results can be compared directly. Note that the JSON
// sub-benchmark expands the roles and scope once up front, whereas
// regoInputValue performs that expansion on every call.
//
// Currently ast.Object.insert() is the slowest part of the process and
// allocates the most bytes. This general approach copies all of our struct
// data and uses a lot of extra memory for handling things like sort order.
// A possible large improvement would be to implement the ast.Value interface directly.
func BenchmarkRBACValueAllocation(b *testing.B) {
	actor := Subject{
		Roles:  RoleNames{RoleOrgMember(uuid.New()), RoleOrgAdmin(uuid.New()), RoleMember()},
		ID:     uuid.NewString(),
		Scope:  ScopeAll,
		Groups: []string{uuid.NewString(), uuid.NewString(), uuid.NewString()},
	}
	obj := ResourceTemplate.
		WithID(uuid.New()).
		InOrg(uuid.New()).
		WithOwner(uuid.NewString()).
		WithGroupACL(map[string][]Action{
			uuid.NewString(): {ActionRead, ActionCreate},
			uuid.NewString(): {ActionRead, ActionCreate},
			uuid.NewString(): {ActionRead, ActionCreate},
		}).WithACLUserList(map[string][]Action{
		uuid.NewString(): {ActionRead, ActionCreate},
		uuid.NewString(): {ActionRead, ActionCreate},
	})

	// The JSON path must convert the whole input, not just the subject,
	// otherwise it does strictly less work than the manual path.
	jsonInput := map[string]interface{}{
		"subject": authSubject{
			ID:     actor.ID,
			Roles:  must(actor.Roles.Expand()),
			Groups: actor.Groups,
			Scope:  must(actor.Scope.Expand()),
		},
		"action": ActionRead,
		"object": obj,
	}

	b.Run("ManualRegoValue", func(b *testing.B) {
		b.ReportAllocs()
		for i := 0; i < b.N; i++ {
			_, err := regoInputValue(actor, ActionRead, obj)
			require.NoError(b, err)
		}
	})
	b.Run("JSONRegoValue", func(b *testing.B) {
		b.ReportAllocs()
		for i := 0; i < b.N; i++ {
			_, err := ast.InterfaceToValue(jsonInput)
			require.NoError(b, err)
		}
	})
}
// TestRegoInputValue checks that the hand-written rego input conversion yields
// the same value as the generic json based conversion. The json conversion is
// treated as the source of truth; the manual conversion exists only to reduce
// allocations, so any difference between the two is a bug.
func TestRegoInputValue(t *testing.T) {
	t.Parallel()

	subject := Subject{
		Roles:  RoleNames{RoleOrgMember(uuid.New()), RoleOrgAdmin(uuid.New()), RoleMember()},
		ID:     uuid.NewString(),
		Scope:  ScopeAll,
		Groups: []string{uuid.NewString(), uuid.NewString(), uuid.NewString()},
	}
	resource := ResourceTemplate.
		WithID(uuid.New()).
		InOrg(uuid.New()).
		WithOwner(uuid.NewString()).
		WithGroupACL(map[string][]Action{
			uuid.NewString(): {ActionRead, ActionCreate},
			uuid.NewString(): {ActionRead, ActionCreate},
			uuid.NewString(): {ActionRead, ActionCreate},
		}).WithACLUserList(map[string][]Action{
		uuid.NewString(): {ActionRead, ActionCreate},
		uuid.NewString(): {ActionRead, ActionCreate},
	})
	act := ActionRead

	// jsonSubject builds the subject the way it is given to the json
	// conversion: with roles and scope already expanded.
	jsonSubject := func() authSubject {
		return authSubject{
			ID:     subject.ID,
			Roles:  must(subject.Roles.Expand()),
			Groups: subject.Groups,
			Scope:  must(subject.Scope.Expand()),
		}
	}

	// requireSame converts want with the generic json conversion and asserts
	// that the outcome compares equal to got.
	requireSame := func(t *testing.T, got ast.Value, want map[string]interface{}) {
		t.Helper()

		general, err := ast.InterfaceToValue(want)
		require.NoError(t, err)

		// The manual conversion leaves the name fields unset because they are
		// not needed. Force them to a common value on both sides so they do
		// not influence the comparison.
		ignoreNames(t, got)
		ignoreNames(t, general)

		require.Equal(t, 0, got.Compare(general), "manual and general input values should be equal")
	}

	t.Run("InputValue", func(t *testing.T) {
		t.Parallel()

		// The input as it would be passed to the rego policy.
		want := map[string]interface{}{
			"subject": jsonSubject(),
			"action":  act,
			"object":  resource,
		}

		got, err := regoInputValue(subject, act, resource)
		require.NoError(t, err)

		requireSame(t, got, want)
	})

	t.Run("PartialInputValue", func(t *testing.T) {
		t.Parallel()

		// The input as it would be passed to the rego policy; only the
		// object type is present.
		want := map[string]interface{}{
			"subject": jsonSubject(),
			"action":  act,
			"object": map[string]interface{}{
				"type": resource.Type,
			},
		}

		got, err := regoPartialInputValue(subject, act, resource.Type)
		require.NoError(t, err)

		requireSame(t, got, want)
	})
}
// ignoreNames overwrites the "name" and "display_name" keys of every subject
// role and of the subject scope with the string "ignore". This lets two input
// values be compared without the name fields affecting the result.
//
// The test fails if "subject.roles" is not an array of objects or if
// "subject.scope" is not an object.
func ignoreNames(t *testing.T, value ast.Value) {
	t.Helper()

	// overwrite replaces both name fields of obj in place.
	overwrite := func(obj ast.Object) {
		obj.Insert(ast.StringTerm("name"), ast.StringTerm("ignore"))
		obj.Insert(ast.StringTerm("display_name"), ast.StringTerm("ignore"))
	}

	// Override the names of the roles.
	roles, err := value.Find(ast.Ref{
		ast.StringTerm("subject"),
		ast.StringTerm("roles"),
	})
	require.NoError(t, err)

	rolesArray, ok := roles.(*ast.Array)
	require.True(t, ok, "roles is expected to be an array")

	rolesArray.Foreach(func(term *ast.Term) {
		// Fail with a clear message instead of panicking on a nil
		// interface when an element is not an object.
		obj, ok := term.Value.(ast.Object)
		require.True(t, ok, "role is expected to be an object")
		overwrite(obj)
	})

	// Override the names of the scope role.
	scope, err := value.Find(ast.Ref{
		ast.StringTerm("subject"),
		ast.StringTerm("scope"),
	})
	require.NoError(t, err)

	scopeObj, ok := scope.(ast.Object)
	require.True(t, ok, "scope is expected to be an object")
	overwrite(scopeObj)
}
func TestRoleByName(t *testing.T) {
t.Parallel()

View File

@ -14,20 +14,25 @@ type UnauthorizedError struct {
// internal is the internal error that should never be shown to the client.
// It is only for debugging purposes.
internal error
input map[string]interface{}
output rego.ResultSet
// These fields are for debugging purposes.
subject Subject
action Action
// Note only the object type is set for partial execution.
object Object
output rego.ResultSet
}
// ForbiddenWithInternal creates a new error that will return a simple
// "forbidden" to the client, logging internally the more detailed message
// provided.
func ForbiddenWithInternal(internal error, input map[string]interface{}, output rego.ResultSet) *UnauthorizedError {
if input == nil {
input = map[string]interface{}{}
}
func ForbiddenWithInternal(internal error, subject Subject, action Action, object Object, output rego.ResultSet) *UnauthorizedError {
return &UnauthorizedError{
internal: internal,
input: input,
subject: subject,
action: action,
object: object,
output: output,
}
}
@ -43,7 +48,11 @@ func (e *UnauthorizedError) Internal() error {
}
func (e *UnauthorizedError) Input() map[string]interface{} {
return e.input
return map[string]interface{}{
"subject": e.subject,
"action": e.action,
"object": e.object,
}
}
// Output contains the results of the Rego query for debugging.

View File

@ -12,7 +12,21 @@
},
"subject": {
"id": "10d03e62-7703-4df5-a358-4f76577d4e2f",
"roles": [],
"roles": [
{
"name": "owner",
"display_name": "Owner",
"site": [
{
"negate": false,
"resource_type": "*",
"action": "*"
}
],
"org": {},
"user": []
}
],
"groups": ["b617a647-b5d0-4cbe-9e40-26f89710bf18"],
"scope": {
"name": "Scope_all",

View File

@ -17,8 +17,12 @@ type PartialAuthorizer struct {
// partialQueries is mainly used for unit testing to assert our rego policy
// can always be compressed into a set of queries.
partialQueries *rego.PartialQueries
// input is used purely for debugging and logging.
input map[string]interface{}
subjectInput Subject
subjectAction Action
subjectResourceType Object
// preparedQueries are the compiled set of queries after partial evaluation.
// Cache these prepared queries to avoid re-compiling the queries.
// If alwaysTrue is true, then ignore these.
@ -54,7 +58,8 @@ func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error
// If we have no queries, then no queries can return 'true'.
// So the result is always 'false'.
if len(pa.preparedQueries) == 0 {
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), pa.input, nil)
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"),
pa.subjectInput, pa.subjectAction, pa.subjectResourceType, nil)
}
parsed, err := ast.InterfaceToValue(map[string]interface{}{
@ -118,7 +123,8 @@ EachQueryLoop:
return nil
}
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), pa.input, nil)
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"),
pa.subjectInput, pa.subjectAction, pa.subjectResourceType, nil)
}
func newPartialAuthorizer(ctx context.Context, subject Subject, action Action, objectType string) (*PartialAuthorizer, error) {
@ -129,27 +135,9 @@ func newPartialAuthorizer(ctx context.Context, subject Subject, action Action, o
return nil, xerrors.Errorf("subject must have a scope")
}
roles, err := subject.Roles.Expand()
input, err := regoPartialInputValue(subject, action, objectType)
if err != nil {
return nil, xerrors.Errorf("expand roles: %w", err)
}
scope, err := subject.Scope.Expand()
if err != nil {
return nil, xerrors.Errorf("expand scope: %w", err)
}
input := map[string]interface{}{
"subject": authSubject{
ID: subject.ID,
Roles: roles,
Scope: scope,
Groups: subject.Groups,
},
"object": map[string]string{
"type": objectType,
},
"action": action,
return nil, xerrors.Errorf("prepare input: %w", err)
}
// Run the rego policy with a few unknown fields. This should simplify our
@ -164,7 +152,7 @@ func newPartialAuthorizer(ctx context.Context, subject Subject, action Action, o
"input.object.acl_user_list",
"input.object.acl_group_list",
}),
rego.Input(input),
rego.ParsedInput(input),
).Partial(ctx)
if err != nil {
return nil, xerrors.Errorf("prepare: %w", err)
@ -173,7 +161,12 @@ func newPartialAuthorizer(ctx context.Context, subject Subject, action Action, o
pAuth := &PartialAuthorizer{
partialQueries: partialQueries,
preparedQueries: []rego.PreparedEvalQuery{},
input: input,
subjectInput: subject,
subjectResourceType: Object{
Type: objectType,
ID: "prepared-object",
},
subjectAction: action,
}
// Prepare each query to optimize the runtime when we iterate over the objects.