mirror of https://github.com/coder/coder.git
chore: Rewrite rbac rego -> SQL clause (#5138)
* chore: Rewrite rbac rego -> SQL clause Previous code was challenging to read with edge cases - bug: OrgAdmin could not make new groups - Also refactor some function names
This commit is contained in:
parent
d5ab4fdeb8
commit
ab9298f382
|
@ -5,7 +5,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
@ -95,19 +94,14 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
|
|||
// from postgres are already authorized, and the caller does not need to
|
||||
// call 'Authorize()' on the returned objects.
|
||||
// Note the authorization is only for the given action and object type.
|
||||
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
|
||||
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) {
|
||||
roles := httpmw.UserAuthorization(r)
|
||||
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("prepare filter: %w", err)
|
||||
}
|
||||
|
||||
filter, err := prepared.Compile()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile filter: %w", err)
|
||||
}
|
||||
|
||||
return filter, nil
|
||||
return prepared, nil
|
||||
}
|
||||
|
||||
// checkAuthorization returns if the current API key can use the given
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -16,12 +18,19 @@ import (
|
|||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
)
|
||||
|
||||
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
// For any route using SQL filters, we need to know if the database is an
|
||||
// in memory fake. This is because the in memory fake does not use SQL, and
|
||||
// still uses rego. So this boolean indicates how to assert the expected
|
||||
// behavior.
|
||||
_, isMemoryDB := a.api.Database.(databasefake.FakeDatabase)
|
||||
|
||||
// Some quick reused objects
|
||||
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||
workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||
|
@ -125,11 +134,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
},
|
||||
"GET:/api/v2/organizations/{organization}/templates": {
|
||||
StatusCode: http.StatusOK,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID),
|
||||
},
|
||||
"POST:/api/v2/organizations/{organization}/templates": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: rbac.ResourceTemplate.InOrg(a.Organization.ID),
|
||||
|
@ -240,7 +244,18 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
"GET:/api/v2/organizations/{organization}/templateversions/{templateversionname}": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
|
||||
// Endpoints that use the SQLQuery filter.
|
||||
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
|
||||
"GET:/api/v2/workspaces/": {
|
||||
StatusCode: http.StatusOK,
|
||||
NoAuthorize: !isMemoryDB,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceWorkspace,
|
||||
},
|
||||
"GET:/api/v2/organizations/{organization}/templates": {
|
||||
StatusCode: http.StatusOK,
|
||||
NoAuthorize: !isMemoryDB,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceTemplate,
|
||||
},
|
||||
}
|
||||
|
||||
// Routes like proxy routes support all HTTP methods. A helper func to expand
|
||||
|
@ -549,10 +564,10 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
|
|||
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object)
|
||||
}
|
||||
|
||||
// Compile returns a compiled version of the authorizer that will work for
|
||||
// CompileToSQL returns a compiled version of the authorizer that will work for
|
||||
// in memory databases. This fake version will not work against a SQL database.
|
||||
func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
|
||||
return f, nil
|
||||
func (fakePreparedAuthorizer) CompileToSQL(_ regosql.ConvertConfig) (string, error) {
|
||||
return "", xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
|
||||
|
@ -565,10 +580,3 @@ func (f fakePreparedAuthorizer) RegoString() string {
|
|||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string {
|
||||
if f.HardCodedSQLString != "" {
|
||||
return f.HardCodedSQLString
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
|
|
|
@ -20,6 +20,13 @@ import (
|
|||
"github.com/coder/coder/coderd/util/slice"
|
||||
)
|
||||
|
||||
// FakeDatabase is helpful for knowing if the underlying db is an in memory fake
|
||||
// database. This is only in the databasefake package, so will only be used
|
||||
// by unit tests.
|
||||
type FakeDatabase interface {
|
||||
IsFakeDB()
|
||||
}
|
||||
|
||||
var errDuplicateKey = &pq.Error{
|
||||
Code: "23505",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
|
@ -117,6 +124,7 @@ type data struct {
|
|||
lastLicenseID int32
|
||||
}
|
||||
|
||||
func (fakeQuerier) IsFakeDB() {}
|
||||
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
@ -488,11 +496,20 @@ func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.Get
|
|||
return count, err
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
||||
func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
users := append([]database.User{}, q.users...)
|
||||
users := make([]database.User, 0, len(q.users))
|
||||
|
||||
for _, user := range q.users {
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
if params.Deleted {
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
|
@ -539,13 +556,6 @@ func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.
|
|||
users = usersFilteredByRole
|
||||
}
|
||||
|
||||
for _, user := range q.workspaces {
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return int64(len(users)), nil
|
||||
}
|
||||
|
||||
|
@ -750,7 +760,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa
|
|||
}
|
||||
|
||||
//nolint:gocyclo
|
||||
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.GetWorkspacesRow, error) {
|
||||
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
|
@ -923,7 +933,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
|||
}
|
||||
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
|
||||
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
workspaces = append(workspaces, workspace)
|
||||
|
@ -1505,12 +1515,20 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd
|
|||
return database.Template{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
|
||||
func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
|
||||
return q.GetAuthorizedTemplates(ctx, arg, nil)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var templates []database.Template
|
||||
for _, template := range q.templates {
|
||||
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if template.Deleted != arg.Deleted {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -74,6 +74,7 @@ func TestExactMethods(t *testing.T) {
|
|||
extraFakeMethods := map[string]string{
|
||||
// Example
|
||||
// "SortFakeLists": "Helper function used",
|
||||
"IsFakeDB": "Helper function used for unit testing",
|
||||
}
|
||||
|
||||
fake := reflect.TypeOf(databasefake.New())
|
||||
|
|
|
@ -5,12 +5,16 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
)
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
const (
|
||||
authorizedQueryPlaceholder = "-- @authorize_filter"
|
||||
)
|
||||
|
||||
// customQuerier encompasses all non-generated queries.
|
||||
|
@ -23,10 +27,70 @@ type customQuerier interface {
|
|||
}
|
||||
|
||||
type templateQuerier interface {
|
||||
GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error)
|
||||
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
|
||||
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(regosql.ConvertConfig{
|
||||
VariableConverter: regosql.TemplateConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getTemplatesWithFilter, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedTemplates :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.OrganizationID,
|
||||
arg.ExactName,
|
||||
pq.Array(arg.IDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OrganizationID,
|
||||
&i.Deleted,
|
||||
&i.Name,
|
||||
&i.Provisioner,
|
||||
&i.ActiveVersionID,
|
||||
&i.Description,
|
||||
&i.DefaultTTL,
|
||||
&i.CreatedBy,
|
||||
&i.Icon,
|
||||
&i.UserACL,
|
||||
&i.GroupACL,
|
||||
&i.DisplayName,
|
||||
&i.AllowUserCancelWorkspaceJobs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
type TemplateUser struct {
|
||||
User
|
||||
Actions Actions `db:"actions"`
|
||||
|
@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
|
|||
}
|
||||
|
||||
type workspaceQuerier interface {
|
||||
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error)
|
||||
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error)
|
||||
}
|
||||
|
||||
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
|
||||
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
|
||||
// clause.
|
||||
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) {
|
||||
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
|
||||
// authorizedFilter between the end of the where clause and those statements.
|
||||
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
||||
filtered, err := insertAuthorizedFilter(getWorkspaces, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter)
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.Status,
|
||||
|
@ -172,12 +245,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
|||
}
|
||||
|
||||
type userQuerier interface {
|
||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
|
||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
||||
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
|
||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
|
||||
row := q.db.QueryRowContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.Search,
|
||||
|
@ -185,6 +267,14 @@ func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFiltered
|
|||
pq.Array(arg.RbacRole),
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
err = row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||
if !strings.Contains(query, authorizedQueryPlaceholder) {
|
||||
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
|
||||
}
|
||||
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
|
||||
return filtered, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsAuthorizedQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
query := `SELECT true;`
|
||||
_, err := insertAuthorizedFilter(query, "")
|
||||
require.ErrorContains(t, err, "does not contain authorized replace string", "ensure replace string")
|
||||
}
|
|
@ -3197,6 +3197,8 @@ WHERE
|
|||
id = ANY($4)
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
|
||||
-- @authorize_filter
|
||||
ORDER BY (name, id) ASC
|
||||
`
|
||||
|
||||
|
|
|
@ -34,6 +34,8 @@ WHERE
|
|||
id = ANY(@ids)
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
|
||||
-- @authorize_filter
|
||||
ORDER BY (name, id) ASC
|
||||
;
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
|
@ -20,7 +21,7 @@ type Authorizer interface {
|
|||
|
||||
type PreparedAuthorized interface {
|
||||
Authorize(ctx context.Context, object Object) error
|
||||
Compile() (AuthorizeFilter, error)
|
||||
CompileToSQL(cfg regosql.ConvertConfig) (string, error)
|
||||
}
|
||||
|
||||
// Filter takes in a list of objects, and will filter the list removing all
|
||||
|
|
|
@ -847,7 +847,7 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
|
|||
|
||||
// Ensure the partial can compile to a SQL clause.
|
||||
// This does not guarantee that the clause is valid SQL.
|
||||
_, err = Compile(partialAuthz)
|
||||
_, err = Compile(ConfigWithACL(), partialAuthz)
|
||||
require.NoError(t, err, "compile prepared authorizer")
|
||||
|
||||
// Also check the rego policy can form a valid partial query result.
|
||||
|
|
|
@ -3,11 +3,11 @@ package rbac
|
|||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
|
@ -28,12 +28,12 @@ type PartialAuthorizer struct {
|
|||
|
||||
var _ PreparedAuthorized = (*PartialAuthorizer)(nil)
|
||||
|
||||
func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) {
|
||||
filter, err := Compile(pa)
|
||||
func (pa *PartialAuthorizer) CompileToSQL(cfg regosql.ConvertConfig) (string, error) {
|
||||
filter, err := Compile(cfg, pa)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile: %w", err)
|
||||
return "", xerrors.Errorf("compile: %w", err)
|
||||
}
|
||||
return filter, nil
|
||||
return filter.SQLString(), nil
|
||||
}
|
||||
|
||||
func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error {
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
package partial
|
|
@ -2,635 +2,63 @@ package rbac
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type TermType string
|
||||
|
||||
const (
|
||||
VarTypeJsonbTextArray TermType = "jsonb-text-array"
|
||||
VarTypeText TermType = "text"
|
||||
VarTypeBoolean TermType = "boolean"
|
||||
// VarTypeSkip means this variable does not exist to use.
|
||||
VarTypeSkip TermType = "skip"
|
||||
)
|
||||
|
||||
type SQLColumn struct {
|
||||
// RegoMatch matches the original variable string.
|
||||
// If it is a match, then this variable config will apply.
|
||||
RegoMatch *regexp.Regexp
|
||||
// ColumnSelect is the name of the postgres column to select.
|
||||
// Can use capture groups from RegoMatch with $1, $2, etc.
|
||||
ColumnSelect string
|
||||
|
||||
// Type indicates the postgres type of the column. Some expressions will
|
||||
// need to know this in order to determine what SQL to produce.
|
||||
// An example is if the variable is a jsonb array, the "contains" SQL
|
||||
// query is `variable ? 'value'` instead of `'value' = ANY(variable)`.
|
||||
// This type is only needed to be provided
|
||||
Type TermType
|
||||
}
|
||||
|
||||
type SQLConfig struct {
|
||||
// Variables is a map of rego variable names to SQL columns.
|
||||
// Example:
|
||||
// "input\.object\.org_owner": SQLColumn{
|
||||
// ColumnSelect: "organization_id",
|
||||
// Type: VarTypeUUID
|
||||
// }
|
||||
// "input\.object\.owner": SQLColumn{
|
||||
// ColumnSelect: "owner_id",
|
||||
// Type: VarTypeUUID
|
||||
// }
|
||||
// "input\.object\.group_acl\.(.*)": SQLColumn{
|
||||
// ColumnSelect: "group_acl->$1",
|
||||
// Type: VarTypeJsonbTextArray
|
||||
// }
|
||||
Variables []SQLColumn
|
||||
}
|
||||
|
||||
func DefaultConfig() SQLConfig {
|
||||
return SQLConfig{
|
||||
Variables: []SQLColumn{
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
|
||||
ColumnSelect: "group_acl->$1",
|
||||
Type: VarTypeJsonbTextArray,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
|
||||
ColumnSelect: "user_acl->$1",
|
||||
Type: VarTypeJsonbTextArray,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
|
||||
ColumnSelect: "organization_id :: text",
|
||||
Type: VarTypeText,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
|
||||
ColumnSelect: "owner_id :: text",
|
||||
Type: VarTypeText,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NoACLConfig() SQLConfig {
|
||||
return SQLConfig{
|
||||
Variables: []SQLColumn{
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
|
||||
ColumnSelect: "",
|
||||
Type: VarTypeSkip,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
|
||||
ColumnSelect: "",
|
||||
Type: VarTypeSkip,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
|
||||
ColumnSelect: "organization_id :: text",
|
||||
Type: VarTypeText,
|
||||
},
|
||||
{
|
||||
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
|
||||
ColumnSelect: "owner_id :: text",
|
||||
Type: VarTypeText,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type AuthorizeFilter interface {
|
||||
Expression
|
||||
// Eval is required for the fake in memory database to work. The in memory
|
||||
// database can use this function to filter the results.
|
||||
Eval(object Object) bool
|
||||
SQLString() string
|
||||
}
|
||||
|
||||
// expressionTop handles Eval(object Object) for in memory expressions
|
||||
type expressionTop struct {
|
||||
Expression
|
||||
Auth *PartialAuthorizer
|
||||
type authorizedSQLFilter struct {
|
||||
sqlString string
|
||||
auth *PartialAuthorizer
|
||||
}
|
||||
|
||||
func (e expressionTop) Eval(object Object) bool {
|
||||
return e.Auth.Authorize(context.Background(), object) == nil
|
||||
func ConfigWithACL() regosql.ConvertConfig {
|
||||
return regosql.ConvertConfig{
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
}
|
||||
}
|
||||
|
||||
// Compile will convert a rego query AST into our custom types. The output is
|
||||
// an AST that can be used to generate SQL.
|
||||
func Compile(pa *PartialAuthorizer) (AuthorizeFilter, error) {
|
||||
partialQueries := pa.partialQueries
|
||||
if len(partialQueries.Support) > 0 {
|
||||
return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support))
|
||||
func ConfigWithoutACL() regosql.ConvertConfig {
|
||||
return regosql.ConvertConfig{
|
||||
VariableConverter: regosql.NoACLConverter(),
|
||||
}
|
||||
}
|
||||
|
||||
func Compile(cfg regosql.ConvertConfig, pa *PartialAuthorizer) (AuthorizeFilter, error) {
|
||||
root, err := regosql.ConvertRegoAst(cfg, pa.partialQueries)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("convert rego ast: %w", err)
|
||||
}
|
||||
|
||||
// 0 queries means the result is "undefined". This is the same as "false".
|
||||
if len(partialQueries.Queries) == 0 {
|
||||
return &termBoolean{
|
||||
base: base{Rego: "false"},
|
||||
Value: false,
|
||||
}, nil
|
||||
// Generate the SQL
|
||||
gen := sqltypes.NewSQLGenerator()
|
||||
sqlString := root.SQLString(gen)
|
||||
if len(gen.Errors()) > 0 {
|
||||
var errStrings []string
|
||||
for _, err := range gen.Errors() {
|
||||
errStrings = append(errStrings, err.Error())
|
||||
}
|
||||
return nil, xerrors.Errorf("sql generation errors: %v", strings.Join(errStrings, ", "))
|
||||
}
|
||||
|
||||
// Abort early if any of the "OR"'d expressions are the empty string.
|
||||
// This is the same as "true".
|
||||
for _, query := range partialQueries.Queries {
|
||||
if query.String() == "" {
|
||||
return &termBoolean{
|
||||
base: base{Rego: "true"},
|
||||
Value: true,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]Expression, 0, len(partialQueries.Queries))
|
||||
var builder strings.Builder
|
||||
for i := range partialQueries.Queries {
|
||||
query, err := processQuery(partialQueries.Queries[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, query)
|
||||
if i != 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString(partialQueries.Queries[i].String())
|
||||
}
|
||||
exp := expOr{
|
||||
base: base{
|
||||
Rego: builder.String(),
|
||||
},
|
||||
Expressions: result,
|
||||
}
|
||||
return expressionTop{
|
||||
Expression: &exp,
|
||||
Auth: pa,
|
||||
return &authorizedSQLFilter{
|
||||
sqlString: sqlString,
|
||||
auth: pa,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processQuery processes an entire set of expressions and joins them with
|
||||
// "AND".
|
||||
func processQuery(query ast.Body) (Expression, error) {
|
||||
expressions := make([]Expression, 0, len(query))
|
||||
for _, astExpr := range query {
|
||||
expr, err := processExpression(astExpr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expressions = append(expressions, expr)
|
||||
}
|
||||
|
||||
return expAnd{
|
||||
base: base{
|
||||
Rego: query.String(),
|
||||
},
|
||||
Expressions: expressions,
|
||||
}, nil
|
||||
func (a *authorizedSQLFilter) Eval(object Object) bool {
|
||||
return a.auth.Authorize(context.Background(), object) == nil
|
||||
}
|
||||
|
||||
func processExpression(expr *ast.Expr) (Expression, error) {
|
||||
if !expr.IsCall() {
|
||||
// This could be a single term that is a valid expression.
|
||||
if term, ok := expr.Terms.(*ast.Term); ok {
|
||||
value, err := processTerm(term)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("single term expression: %w", err)
|
||||
}
|
||||
if boolExp, ok := value.(Expression); ok {
|
||||
return boolExp, nil
|
||||
}
|
||||
// Default to error.
|
||||
}
|
||||
return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported")
|
||||
}
|
||||
|
||||
op := expr.Operator().String()
|
||||
base := base{Rego: op}
|
||||
switch op {
|
||||
case "neq", "eq", "equal":
|
||||
terms, err := processTerms(2, expr.Operands())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
|
||||
}
|
||||
return &opEqual{
|
||||
base: base,
|
||||
Terms: [2]Term{terms[0], terms[1]},
|
||||
Not: op == "neq",
|
||||
}, nil
|
||||
case "internal.member_2":
|
||||
terms, err := processTerms(2, expr.Operands())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
|
||||
}
|
||||
return &opInternalMember2{
|
||||
base: base,
|
||||
Needle: terms[0],
|
||||
Haystack: terms[1],
|
||||
}, nil
|
||||
default:
|
||||
return nil, xerrors.Errorf("invalid expression: operator %s not supported", op)
|
||||
}
|
||||
}
|
||||
|
||||
func processTerms(expected int, terms []*ast.Term) ([]Term, error) {
|
||||
if len(terms) != expected {
|
||||
return nil, xerrors.Errorf("too many arguments, expect %d found %d", expected, len(terms))
|
||||
}
|
||||
|
||||
result := make([]Term, 0, len(terms))
|
||||
for _, term := range terms {
|
||||
processed, err := processTerm(term)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid term: %w", err)
|
||||
}
|
||||
result = append(result, processed)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func processTerm(term *ast.Term) (Term, error) {
|
||||
termBase := base{Rego: term.String()}
|
||||
switch v := term.Value.(type) {
|
||||
case ast.Boolean:
|
||||
return &termBoolean{
|
||||
base: termBase,
|
||||
Value: bool(v),
|
||||
}, nil
|
||||
case ast.Ref:
|
||||
obj := &termObject{
|
||||
base: termBase,
|
||||
Path: []Term{},
|
||||
}
|
||||
var idx int
|
||||
// A ref is a set of terms. If the first term is a var, then the
|
||||
// following terms are the path to the value.
|
||||
isRef := true
|
||||
var builder strings.Builder
|
||||
for _, term := range v {
|
||||
if idx == 0 {
|
||||
if _, ok := v[0].Value.(ast.Var); !ok {
|
||||
return nil, xerrors.Errorf("invalid term (%s): ref must start with a var, started with %T", v[0].String(), v[0])
|
||||
}
|
||||
}
|
||||
|
||||
_, newRef := term.Value.(ast.Ref)
|
||||
if newRef ||
|
||||
// This is an unfortunate hack. To fix this, we need to rewrite
|
||||
// our SQL config as a path ([]string{"input", "object", "acl_group"}).
|
||||
// In the rego AST, there is no difference between selecting
|
||||
// a field by a variable, and selecting a field by a literal (string).
|
||||
// This was a misunderstanding.
|
||||
// Example (these are equivalent by AST):
|
||||
// input.object.acl_group_list['4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75']
|
||||
// input.object.acl_group_list.organization_id
|
||||
//
|
||||
// This is not equivalent
|
||||
// input.object.acl_group_list[input.object.organization_id]
|
||||
//
|
||||
// If this becomes even more hairy, we should fix the sql config.
|
||||
builder.String() == "input.object.acl_group_list" ||
|
||||
builder.String() == "input.object.acl_user_list" {
|
||||
if !newRef {
|
||||
isRef = false
|
||||
}
|
||||
// New obj
|
||||
obj.Path = append(obj.Path, termVariable{
|
||||
base: base{
|
||||
Rego: builder.String(),
|
||||
},
|
||||
Name: builder.String(),
|
||||
})
|
||||
builder.Reset()
|
||||
idx = 0
|
||||
}
|
||||
|
||||
if builder.Len() != 0 {
|
||||
builder.WriteString(".")
|
||||
}
|
||||
builder.WriteString(trimQuotes(term.String()))
|
||||
idx++
|
||||
}
|
||||
|
||||
if isRef {
|
||||
obj.Path = append(obj.Path, termVariable{
|
||||
base: base{
|
||||
Rego: builder.String(),
|
||||
},
|
||||
Name: builder.String(),
|
||||
})
|
||||
} else {
|
||||
obj.Path = append(obj.Path, termString{
|
||||
base: base{
|
||||
Rego: fmt.Sprintf("%q", builder.String()),
|
||||
},
|
||||
Value: builder.String(),
|
||||
})
|
||||
}
|
||||
return obj, nil
|
||||
case ast.Var:
|
||||
return &termVariable{
|
||||
Name: trimQuotes(v.String()),
|
||||
base: termBase,
|
||||
}, nil
|
||||
case ast.String:
|
||||
return &termString{
|
||||
Value: trimQuotes(v.String()),
|
||||
base: termBase,
|
||||
}, nil
|
||||
case ast.Set:
|
||||
slice := v.Slice()
|
||||
set := make([]Term, 0, len(slice))
|
||||
for _, elem := range slice {
|
||||
processed, err := processTerm(elem)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid set term %s: %w", elem.String(), err)
|
||||
}
|
||||
set = append(set, processed)
|
||||
}
|
||||
|
||||
return &termSet{
|
||||
Value: set,
|
||||
base: termBase,
|
||||
}, nil
|
||||
default:
|
||||
return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String())
|
||||
}
|
||||
}
|
||||
|
||||
type base struct {
|
||||
// Rego is the original rego string
|
||||
Rego string
|
||||
}
|
||||
|
||||
func (b base) RegoString() string {
|
||||
return b.Rego
|
||||
}
|
||||
|
||||
// Expression comprises a set of terms, operators, and functions. All
|
||||
// expressions return a boolean value.
|
||||
//
|
||||
// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo"
|
||||
type Expression interface {
|
||||
// RegoString is used in debugging to see the original rego expression.
|
||||
RegoString() string
|
||||
// SQLString returns the SQL expression that can be used in a WHERE clause.
|
||||
SQLString(cfg SQLConfig) string
|
||||
}
|
||||
|
||||
type expAnd struct {
|
||||
base
|
||||
Expressions []Expression
|
||||
}
|
||||
|
||||
func (t expAnd) SQLString(cfg SQLConfig) string {
|
||||
if len(t.Expressions) == 1 {
|
||||
return t.Expressions[0].SQLString(cfg)
|
||||
}
|
||||
|
||||
exprs := make([]string, 0, len(t.Expressions))
|
||||
for _, expr := range t.Expressions {
|
||||
exprs = append(exprs, expr.SQLString(cfg))
|
||||
}
|
||||
return "(" + strings.Join(exprs, " AND ") + ")"
|
||||
}
|
||||
|
||||
type expOr struct {
|
||||
base
|
||||
Expressions []Expression
|
||||
}
|
||||
|
||||
func (t expOr) SQLString(cfg SQLConfig) string {
|
||||
if len(t.Expressions) == 1 {
|
||||
return t.Expressions[0].SQLString(cfg)
|
||||
}
|
||||
|
||||
exprs := make([]string, 0, len(t.Expressions))
|
||||
for _, expr := range t.Expressions {
|
||||
exprs = append(exprs, expr.SQLString(cfg))
|
||||
}
|
||||
return "(" + strings.Join(exprs, " OR ") + ")"
|
||||
}
|
||||
|
||||
// Operator joins terms together to form an expression.
|
||||
// Operators are also expressions.
|
||||
//
|
||||
// Eg: "=", "neq", "internal.member_2", etc.
|
||||
type Operator interface {
|
||||
Expression
|
||||
}
|
||||
|
||||
type opEqual struct {
|
||||
base
|
||||
Terms [2]Term
|
||||
// For NotEqual
|
||||
Not bool
|
||||
}
|
||||
|
||||
func (t opEqual) SQLString(cfg SQLConfig) string {
|
||||
op := "="
|
||||
if t.Not {
|
||||
op = "!="
|
||||
}
|
||||
return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg))
|
||||
}
|
||||
|
||||
// opInternalMember2 is checking if the first term is a member of the second term.
|
||||
// The second term is a set or list.
|
||||
type opInternalMember2 struct {
|
||||
base
|
||||
Needle Term
|
||||
Haystack Term
|
||||
}
|
||||
|
||||
func (t opInternalMember2) SQLString(cfg SQLConfig) string {
|
||||
if haystack, ok := t.Haystack.(*termObject); ok {
|
||||
// This is a special case where the haystack is a jsonb array.
|
||||
// There is a more general way to solve this, but that requires a lot
|
||||
// more code to cover a lot more cases that we do not care about.
|
||||
// To handle this more generally we should implement "Array" as a type.
|
||||
// Then have the `contains` function on the Array type. This would defer
|
||||
// knowing the element type to the Array and cover more cases without
|
||||
// having to add more "if" branches here.
|
||||
// But until we need more cases, our basic type system is ok, and
|
||||
// this is the only case we need to handle.
|
||||
sqlType := haystack.SQLType(cfg)
|
||||
if sqlType == VarTypeJsonbTextArray {
|
||||
return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg))
|
||||
}
|
||||
|
||||
if sqlType == VarTypeSkip {
|
||||
return "false"
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg))
|
||||
}
|
||||
|
||||
// Term is a single value in an expression. Terms can be variables or constants.
|
||||
//
|
||||
// Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner",
|
||||
// "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}"
|
||||
type Term interface {
|
||||
RegoString() string
|
||||
SQLString(cfg SQLConfig) string
|
||||
SQLType(cfg SQLConfig) TermType
|
||||
}
|
||||
|
||||
type termString struct {
|
||||
base
|
||||
Value string
|
||||
}
|
||||
|
||||
func (t termString) SQLString(_ SQLConfig) string {
|
||||
return "'" + t.Value + "'"
|
||||
}
|
||||
|
||||
func (termString) SQLType(_ SQLConfig) TermType {
|
||||
return VarTypeText
|
||||
}
|
||||
|
||||
// termObject is a variable that can be dereferenced. We count some rego objects
|
||||
// as single variables, eg: input.object.org_owner. In reality, it is a nested
|
||||
// object.
|
||||
// In rego, we can dereference the object with the "." operator, which we can
|
||||
// handle with regex.
|
||||
// Or we can dereference the object with the "[]", which we can handle with this
|
||||
// term type.
|
||||
type termObject struct {
|
||||
base
|
||||
Path []Term
|
||||
}
|
||||
|
||||
func (t termObject) SQLType(cfg SQLConfig) TermType {
|
||||
// Without a full type system, let's just assume the type of the first var
|
||||
// is the resulting type. This is correct for our use case.
|
||||
// Solving this more generally requires a full type system, which is
|
||||
// excessive for our mostly static policy.
|
||||
return t.Path[0].SQLType(cfg)
|
||||
}
|
||||
|
||||
func (t termObject) SQLString(cfg SQLConfig) string {
|
||||
if len(t.Path) == 1 {
|
||||
return t.Path[0].SQLString(cfg)
|
||||
}
|
||||
// Combine the last 2 variables into 1 variable.
|
||||
end := t.Path[len(t.Path)-1]
|
||||
before := t.Path[len(t.Path)-2]
|
||||
|
||||
// Recursively solve the SQLString by removing the last nested reference.
|
||||
// This continues until we have a single variable.
|
||||
return termObject{
|
||||
base: t.base,
|
||||
Path: append(
|
||||
t.Path[:len(t.Path)-2],
|
||||
termVariable{
|
||||
base: base{
|
||||
Rego: before.RegoString() + "[" + end.RegoString() + "]",
|
||||
},
|
||||
// Convert the end to SQL string. We evaluate each term
|
||||
// one at a time.
|
||||
Name: before.RegoString() + "." + end.SQLString(cfg),
|
||||
},
|
||||
),
|
||||
}.SQLString(cfg)
|
||||
}
|
||||
|
||||
type termVariable struct {
|
||||
base
|
||||
Name string
|
||||
}
|
||||
|
||||
func (t termVariable) SQLType(cfg SQLConfig) TermType {
|
||||
if col := t.ColumnConfig(cfg); col != nil {
|
||||
return col.Type
|
||||
}
|
||||
return VarTypeText
|
||||
}
|
||||
|
||||
func (t termVariable) SQLString(cfg SQLConfig) string {
|
||||
if col := t.ColumnConfig(cfg); col != nil {
|
||||
matches := col.RegoMatch.FindStringSubmatch(t.Name)
|
||||
if len(matches) > 0 {
|
||||
// This config matches this variable.
|
||||
replace := make([]string, 0, len(matches)*2)
|
||||
for i, m := range matches {
|
||||
replace = append(replace, fmt.Sprintf("$%d", i))
|
||||
replace = append(replace, m)
|
||||
}
|
||||
replacer := strings.NewReplacer(replace...)
|
||||
return replacer.Replace(col.ColumnSelect)
|
||||
}
|
||||
}
|
||||
|
||||
return t.Name
|
||||
}
|
||||
|
||||
// ColumnConfig returns the correct SQLColumn settings for the
|
||||
// term. If there is no configured column, it will return nil.
|
||||
func (t termVariable) ColumnConfig(cfg SQLConfig) *SQLColumn {
|
||||
for _, col := range cfg.Variables {
|
||||
matches := col.RegoMatch.MatchString(t.Name)
|
||||
if matches {
|
||||
return &col
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// termSet is a set of unique terms.
|
||||
type termSet struct {
|
||||
base
|
||||
Value []Term
|
||||
}
|
||||
|
||||
func (t termSet) SQLType(cfg SQLConfig) TermType {
|
||||
if len(t.Value) == 0 {
|
||||
return VarTypeText
|
||||
}
|
||||
// Without a full type system, let's just assume the type of the first var
|
||||
// is the resulting type. This is correct for our use case.
|
||||
// Solving this more generally requires a full type system, which is
|
||||
// excessive for our mostly static policy.
|
||||
return t.Value[0].SQLType(cfg)
|
||||
}
|
||||
|
||||
func (t termSet) SQLString(cfg SQLConfig) string {
|
||||
elems := make([]string, 0, len(t.Value))
|
||||
for _, v := range t.Value {
|
||||
elems = append(elems, v.SQLString(cfg))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ","))
|
||||
}
|
||||
|
||||
type termBoolean struct {
|
||||
base
|
||||
Value bool
|
||||
}
|
||||
|
||||
func (termBoolean) SQLType(SQLConfig) TermType {
|
||||
return VarTypeBoolean
|
||||
}
|
||||
|
||||
func (t termBoolean) Eval(_ Object) bool {
|
||||
return t.Value
|
||||
}
|
||||
|
||||
func (t termBoolean) SQLString(_ SQLConfig) string {
|
||||
return strconv.FormatBool(t.Value)
|
||||
}
|
||||
|
||||
func trimQuotes(s string) string {
|
||||
return strings.Trim(s, "\"")
|
||||
func (a *authorizedSQLFilter) SQLString() string {
|
||||
return a.sqlString
|
||||
}
|
||||
|
|
|
@ -1,150 +0,0 @@
|
|||
package rbac
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompileQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyQuery", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(partialQueries(t, ""))
|
||||
require.NoError(t, err, "compile empty")
|
||||
|
||||
require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'")
|
||||
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'")
|
||||
})
|
||||
|
||||
t.Run("TrueQuery", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(partialQueries(t, "true"))
|
||||
require.NoError(t, err, "compile")
|
||||
|
||||
require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'")
|
||||
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'")
|
||||
})
|
||||
|
||||
t.Run("ACLIn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(partialQueries(t, `"*" in input.object.acl_group_list.allUsers`))
|
||||
require.NoError(t, err, "compile")
|
||||
|
||||
require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member")
|
||||
require.Equal(t, `group_acl->'allUsers' ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in")
|
||||
})
|
||||
|
||||
t.Run("Complex", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(partialQueries(t,
|
||||
`input.object.org_owner != ""`,
|
||||
`input.object.org_owner in {"a", "b", "c"}`,
|
||||
`input.object.org_owner != ""`,
|
||||
`"read" in input.object.acl_group_list.allUsers`,
|
||||
`"read" in input.object.acl_user_list.me`,
|
||||
))
|
||||
require.NoError(t, err, "compile")
|
||||
require.Equal(t, `(organization_id :: text != '' OR `+
|
||||
`organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+
|
||||
`organization_id :: text != '' OR `+
|
||||
`group_acl->'allUsers' ? 'read' OR `+
|
||||
`user_acl->'me' ? 'read')`,
|
||||
expression.SQLString(DefaultConfig()), "complex")
|
||||
})
|
||||
|
||||
t.Run("SetDereference", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(partialQueries(t,
|
||||
`"*" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
))
|
||||
require.NoError(t, err, "compile")
|
||||
require.Equal(t, `group_acl->organization_id :: text ? '*'`,
|
||||
expression.SQLString(DefaultConfig()), "set dereference")
|
||||
})
|
||||
|
||||
t.Run("JsonbLiteralDereference", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(partialQueries(t,
|
||||
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
|
||||
))
|
||||
require.NoError(t, err, "compile")
|
||||
require.Equal(t, `group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'`,
|
||||
expression.SQLString(DefaultConfig()), "literal dereference")
|
||||
})
|
||||
|
||||
t.Run("NoACLColumns", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(partialQueries(t,
|
||||
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
|
||||
))
|
||||
require.NoError(t, err, "compile")
|
||||
require.Equal(t, `false`,
|
||||
expression.SQLString(NoACLConfig()), "literal dereference")
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("GroupACL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expression, err := Compile(partialQueries(t,
|
||||
`"read" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
|
||||
))
|
||||
require.NoError(t, err, "compile")
|
||||
|
||||
result := expression.Eval(Object{
|
||||
Owner: "not-me",
|
||||
OrgID: "random",
|
||||
Type: "workspace",
|
||||
ACLUserList: map[string][]Action{},
|
||||
ACLGroupList: map[string][]Action{
|
||||
"4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75": {"read"},
|
||||
},
|
||||
})
|
||||
require.True(t, result, "eval")
|
||||
})
|
||||
}
|
||||
|
||||
func partialQueries(t *testing.T, queries ...string) *PartialAuthorizer {
|
||||
opts := ast.ParserOptions{
|
||||
AllFutureKeywords: true,
|
||||
}
|
||||
|
||||
astQueries := make([]ast.Body, 0, len(queries))
|
||||
for _, q := range queries {
|
||||
astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts))
|
||||
}
|
||||
|
||||
prepareQueries := make([]rego.PreparedEvalQuery, 0, len(queries))
|
||||
for _, q := range astQueries {
|
||||
var prepped rego.PreparedEvalQuery
|
||||
var err error
|
||||
if q.String() == "" {
|
||||
prepped, err = rego.New(
|
||||
rego.Query("true"),
|
||||
).PrepareForEval(context.Background())
|
||||
} else {
|
||||
prepped, err = rego.New(
|
||||
rego.ParsedQuery(q),
|
||||
).PrepareForEval(context.Background())
|
||||
}
|
||||
require.NoError(t, err, "prepare query")
|
||||
prepareQueries = append(prepareQueries, prepped)
|
||||
}
|
||||
return &PartialAuthorizer{
|
||||
partialQueries: ®o.PartialQueries{
|
||||
Queries: astQueries,
|
||||
Support: []*ast.Module{},
|
||||
},
|
||||
preparedQueries: prepareQueries,
|
||||
input: nil,
|
||||
alwaysTrue: false,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
package regosql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||
)
|
||||
|
||||
var _ sqltypes.VariableMatcher = ACLGroupVar{}
|
||||
var _ sqltypes.Node = ACLGroupVar{}
|
||||
|
||||
// ACLGroupVar is a variable matcher that handles group_acl and user_acl.
|
||||
// The sql type is a jsonb object with the following structure:
|
||||
//
|
||||
// "group_acl": {
|
||||
// "<group_name>": ["<actions>"]
|
||||
// }
|
||||
//
|
||||
// This is a custom variable matcher as json objects have arbitrary complexity.
|
||||
type ACLGroupVar struct {
|
||||
StructSQL string
|
||||
// input.object.group_acl -> ["input", "object", "group_acl"]
|
||||
StructPath []string
|
||||
|
||||
// FieldReference handles referencing the subfields, which could be
|
||||
// more variables. We pass one in as the global one might not be correctly
|
||||
// scoped.
|
||||
FieldReference sqltypes.VariableMatcher
|
||||
|
||||
// Instance fields
|
||||
Source sqltypes.RegoSource
|
||||
GroupNode sqltypes.Node
|
||||
}
|
||||
|
||||
func ACLGroupMatcher(fieldReference sqltypes.VariableMatcher, structSQL string, structPath []string) ACLGroupVar {
|
||||
return ACLGroupVar{StructSQL: structSQL, StructPath: structPath, FieldReference: fieldReference}
|
||||
}
|
||||
|
||||
func (ACLGroupVar) UseAs() sqltypes.Node { return ACLGroupVar{} }
|
||||
|
||||
func (g ACLGroupVar) ConvertVariable(rego ast.Ref) (sqltypes.Node, bool) {
|
||||
// "left" will be a map of group names to actions in rego.
|
||||
// {
|
||||
// "all_users": ["read"]
|
||||
// }
|
||||
left, err := sqltypes.RegoVarPath(g.StructPath, rego)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
aclGrp := ACLGroupVar{
|
||||
StructSQL: g.StructSQL,
|
||||
StructPath: g.StructPath,
|
||||
FieldReference: g.FieldReference,
|
||||
|
||||
Source: sqltypes.RegoSource(rego.String()),
|
||||
}
|
||||
|
||||
// We expect 1 more term. Either a ref or a string.
|
||||
if len(left) != 1 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// If the remaining is a variable, then we need to convert it.
|
||||
// Assuming we support variable fields.
|
||||
ref, ok := left[0].Value.(ast.Ref)
|
||||
if ok && g.FieldReference != nil {
|
||||
groupNode, ok := g.FieldReference.ConvertVariable(ref)
|
||||
if ok {
|
||||
aclGrp.GroupNode = groupNode
|
||||
return aclGrp, true
|
||||
}
|
||||
}
|
||||
|
||||
// If it is a string, we assume it is a literal
|
||||
groupName, ok := left[0].Value.(ast.String)
|
||||
if ok {
|
||||
aclGrp.GroupNode = sqltypes.String(string(groupName))
|
||||
return aclGrp, true
|
||||
}
|
||||
|
||||
// If we have not matched it yet, then it is something we do not recognize.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (g ACLGroupVar) SQLString(cfg *sqltypes.SQLGenerator) string {
|
||||
return fmt.Sprintf("%s->%s", g.StructSQL, g.GroupNode.SQLString(cfg))
|
||||
}
|
||||
|
||||
func (g ACLGroupVar) ContainsSQL(cfg *sqltypes.SQLGenerator, other sqltypes.Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
// Only supports containing other strings.
|
||||
case sqltypes.AstString:
|
||||
return fmt.Sprintf("%s ? %s", g.SQLString(cfg), other.SQLString(cfg)), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported acl group contains %T", other)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,230 @@
|
|||
package regosql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||
)
|
||||
|
||||
// ConvertConfig is required to generate SQL from the rego queries.
|
||||
type ConvertConfig struct {
|
||||
// VariableConverter is called each time a var is encountered. This creates
|
||||
// the SQL ast for the variable. Without this, the SQL generator does not
|
||||
// know how to convert rego variables into SQL columns.
|
||||
VariableConverter sqltypes.VariableMatcher
|
||||
}
|
||||
|
||||
// ConvertRegoAst converts partial rego queries into a single SQL where
|
||||
// clause. If the query equates to "true" then the user should have access.
|
||||
func ConvertRegoAst(cfg ConvertConfig, partial *rego.PartialQueries) (sqltypes.BooleanNode, error) {
|
||||
if len(partial.Queries) == 0 {
|
||||
// Always deny if there are no queries. This means there is no possible
|
||||
// way this user can access these resources.
|
||||
return sqltypes.Bool(false), nil
|
||||
}
|
||||
|
||||
for _, q := range partial.Queries {
|
||||
// An empty query in rego means "true". If any query in the set is
|
||||
// empty, then the user should have access.
|
||||
if len(q) == 0 {
|
||||
// Always allow
|
||||
return sqltypes.Bool(true), nil
|
||||
}
|
||||
}
|
||||
|
||||
var queries []sqltypes.BooleanNode
|
||||
var builder strings.Builder
|
||||
for i, q := range partial.Queries {
|
||||
converted, err := convertQuery(cfg, q)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("query %s: %w", q.String(), err)
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString(q.String())
|
||||
queries = append(queries, converted)
|
||||
}
|
||||
|
||||
// All queries are OR'd together. This means that if any query is true,
|
||||
// then the user should have access.
|
||||
sqlClause := sqltypes.Or(sqltypes.RegoSource(builder.String()), queries...)
|
||||
// Always wrap in parens to ensure the correct precedence when combining with other
|
||||
// SQL clauses.
|
||||
return sqltypes.BoolParenthesis(sqlClause), nil
|
||||
}
|
||||
|
||||
func convertQuery(cfg ConvertConfig, q ast.Body) (sqltypes.BooleanNode, error) {
|
||||
var expressions []sqltypes.BooleanNode
|
||||
for _, e := range q {
|
||||
exp, err := convertExpression(cfg, e)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("expression %s: %w", e.String(), err)
|
||||
}
|
||||
|
||||
expressions = append(expressions, exp)
|
||||
}
|
||||
|
||||
// All expressions in a single query are AND'd together. This means that
|
||||
// all expressions must be true for the user to have access.
|
||||
return sqltypes.And(sqltypes.RegoSource(q.String()), expressions...), nil
|
||||
}
|
||||
|
||||
func convertExpression(cfg ConvertConfig, e *ast.Expr) (sqltypes.BooleanNode, error) {
|
||||
if e.IsCall() {
|
||||
n, err := convertCall(cfg, e.Terms.([]*ast.Term))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("call: %w", err)
|
||||
}
|
||||
|
||||
boolN, ok := n.(sqltypes.BooleanNode)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("call %q: not a boolean expression", e.String())
|
||||
}
|
||||
return boolN, nil
|
||||
}
|
||||
|
||||
// If it's not a call, it is a single term
|
||||
if term, ok := e.Terms.(*ast.Term); ok {
|
||||
ty, err := convertTerm(cfg, term)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("convert term %s: %w", term.String(), err)
|
||||
}
|
||||
|
||||
tyBool, ok := ty.(sqltypes.BooleanNode)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("convert term %s is not a boolean: %w", term.String(), err)
|
||||
}
|
||||
|
||||
return tyBool, nil
|
||||
}
|
||||
|
||||
return nil, xerrors.Errorf("expression %s not supported", e.String())
|
||||
}
|
||||
|
||||
// convertCall converts a function call to a SQL expression.
|
||||
func convertCall(cfg ConvertConfig, call ast.Call) (sqltypes.Node, error) {
|
||||
if len(call) == 0 {
|
||||
return nil, xerrors.Errorf("empty call")
|
||||
}
|
||||
|
||||
// Operator is the first term
|
||||
op := call[0]
|
||||
var args []*ast.Term
|
||||
if len(call) > 1 {
|
||||
args = call[1:]
|
||||
}
|
||||
|
||||
opString := op.String()
|
||||
// Supported operators.
|
||||
switch op.String() {
|
||||
case "neq", "eq", "equals", "equal":
|
||||
args, err := convertTerms(cfg, args, 2)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("arguments: %w", err)
|
||||
}
|
||||
|
||||
not := false
|
||||
if opString == "neq" || opString == "notequals" || opString == "notequal" {
|
||||
not = true
|
||||
}
|
||||
|
||||
equality := sqltypes.Equality(not, args[0], args[1])
|
||||
return sqltypes.BoolParenthesis(equality), nil
|
||||
case "internal.member_2":
|
||||
args, err := convertTerms(cfg, args, 2)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("arguments: %w", err)
|
||||
}
|
||||
|
||||
member := sqltypes.MemberOf(args[0], args[1])
|
||||
return sqltypes.BoolParenthesis(member), nil
|
||||
default:
|
||||
return nil, xerrors.Errorf("operator %s not supported", op)
|
||||
}
|
||||
}
|
||||
|
||||
func convertTerms(cfg ConvertConfig, terms []*ast.Term, expected int) ([]sqltypes.Node, error) {
|
||||
if len(terms) != expected {
|
||||
return nil, xerrors.Errorf("expected %d terms, got %d", expected, len(terms))
|
||||
}
|
||||
|
||||
result := make([]sqltypes.Node, 0, len(terms))
|
||||
for _, t := range terms {
|
||||
term, err := convertTerm(cfg, t)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("term: %w", err)
|
||||
}
|
||||
result = append(result, term)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func convertTerm(cfg ConvertConfig, term *ast.Term) (sqltypes.Node, error) {
|
||||
source := sqltypes.RegoSource(term.String())
|
||||
switch t := term.Value.(type) {
|
||||
case ast.Var:
|
||||
// All vars should be contained in ast.Ref's.
|
||||
return nil, xerrors.New("var not yet supported")
|
||||
case ast.Ref:
|
||||
if len(t) == 0 {
|
||||
// A reference with no text is a variable with no name?
|
||||
// This makes no sense.
|
||||
return nil, xerrors.New("empty ref not supported")
|
||||
}
|
||||
|
||||
if cfg.VariableConverter == nil {
|
||||
return nil, xerrors.New("no variable converter provided to handle variables")
|
||||
}
|
||||
|
||||
// The structure of references is as follows:
|
||||
// 1. All variables start with a regoAst.Var as the first term.
|
||||
// 2. The next term is either a regoAst.String or a regoAst.Var.
|
||||
// - regoAst.String if a static field name or index.
|
||||
// - regoAst.Var if the field reference is a variable itself. Such as
|
||||
// the wildcard "[_]"
|
||||
// 3. Repeat 1-2 until the end of the reference.
|
||||
node, ok := cfg.VariableConverter.ConvertVariable(t)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("variable %q cannot be converted", t.String())
|
||||
}
|
||||
return node, nil
|
||||
case ast.String:
|
||||
return sqltypes.String(string(t)), nil
|
||||
case ast.Number:
|
||||
return sqltypes.Number(source, json.Number(t)), nil
|
||||
case ast.Boolean:
|
||||
return sqltypes.Bool(bool(t)), nil
|
||||
case *ast.Array:
|
||||
elems := make([]sqltypes.Node, 0, t.Len())
|
||||
for i := 0; i < t.Len(); i++ {
|
||||
value, err := convertTerm(cfg, t.Elem(i))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("array element %d in %q: %w", i, t.String(), err)
|
||||
}
|
||||
elems = append(elems, value)
|
||||
}
|
||||
return sqltypes.Array(source, elems...)
|
||||
case ast.Object:
|
||||
return nil, xerrors.New("object not yet supported")
|
||||
case ast.Set:
|
||||
// Just treat a set like an array for now.
|
||||
arr := t.Sorted()
|
||||
return convertTerm(cfg, &ast.Term{
|
||||
Value: arr,
|
||||
Location: term.Location,
|
||||
})
|
||||
case ast.Call:
|
||||
// This is a function call
|
||||
return convertCall(cfg, t)
|
||||
default:
|
||||
return nil, xerrors.Errorf("%T not yet supported", t)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,307 @@
|
|||
package regosql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||
)
|
||||
|
||||
// TestRegoQueriesNoVariables handles cases without variables. These should be
|
||||
// very simple and straight forward.
|
||||
func TestRegoQueries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := func(v string) string {
|
||||
return "(" + v + ")"
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Queries []string
|
||||
ExpectedSQL string
|
||||
ExpectError bool
|
||||
ExpectedSQLGenError bool
|
||||
|
||||
VariableConverter sqltypes.VariableMatcher
|
||||
}{
|
||||
{
|
||||
Name: "Empty",
|
||||
Queries: []string{``},
|
||||
ExpectedSQL: "true",
|
||||
},
|
||||
{
|
||||
Name: "True",
|
||||
Queries: []string{`true`},
|
||||
ExpectedSQL: "true",
|
||||
},
|
||||
{
|
||||
Name: "False",
|
||||
Queries: []string{`false`},
|
||||
ExpectedSQL: "false",
|
||||
},
|
||||
{
|
||||
Name: "MultipleBool",
|
||||
Queries: []string{"true", "false"},
|
||||
ExpectedSQL: "(true OR false)",
|
||||
},
|
||||
{
|
||||
Name: "Numbers",
|
||||
Queries: []string{
|
||||
"(1 != 2) = true",
|
||||
"5 == 5",
|
||||
},
|
||||
ExpectedSQL: p("((1 != 2) = true) OR (5 = 5)"),
|
||||
},
|
||||
// Variables
|
||||
{
|
||||
// Always return a constant string for all variables.
|
||||
Name: "V_Basic",
|
||||
Queries: []string{
|
||||
`input.x = "hello_world"`,
|
||||
},
|
||||
ExpectedSQL: p("only_var = 'hello_world'"),
|
||||
VariableConverter: sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
sqltypes.StringVarMatcher("only_var", []string{
|
||||
"input", "x",
|
||||
}),
|
||||
),
|
||||
},
|
||||
// Coder Variables
|
||||
{
|
||||
// Always return a constant string for all variables.
|
||||
Name: "GroupACL",
|
||||
Queries: []string{
|
||||
`"read" in input.object.acl_group_list.allUsers`,
|
||||
},
|
||||
ExpectedSQL: "(group_acl->'allUsers' ? 'read')",
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "GroupWildcard",
|
||||
Queries: []string{`"*" in input.object.acl_group_list.allUsers`},
|
||||
ExpectedSQL: "(group_acl->'allUsers' ? '*')",
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
// Always return a constant string for all variables.
|
||||
Name: "GroupACLWithVarField",
|
||||
Queries: []string{
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: "(group_acl->organization_id :: text ? 'read')",
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "VarInArray",
|
||||
Queries: []string{
|
||||
`input.object.org_owner in {"a", "b", "c"}`,
|
||||
},
|
||||
ExpectedSQL: p("organization_id :: text = ANY(ARRAY ['a','b','c'])"),
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "SetDereference",
|
||||
Queries: []string{`"*" in input.object.acl_group_list[input.object.org_owner]`},
|
||||
ExpectedSQL: p("group_acl->organization_id :: text ? '*'"),
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "JsonbLiteralDereference",
|
||||
Queries: []string{`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`},
|
||||
ExpectedSQL: p("group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'"),
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "Complex",
|
||||
Queries: []string{
|
||||
`input.object.org_owner != ""`,
|
||||
`input.object.org_owner in {"a", "b", "c"}`,
|
||||
`input.object.org_owner != ""`,
|
||||
`"read" in input.object.acl_group_list.allUsers`,
|
||||
`"read" in input.object.acl_user_list.me`,
|
||||
},
|
||||
ExpectedSQL: `((organization_id :: text != '') OR ` +
|
||||
`(organization_id :: text = ANY(ARRAY ['a','b','c'])) OR ` +
|
||||
`(organization_id :: text != '') OR ` +
|
||||
`(group_acl->'allUsers' ? 'read') OR ` +
|
||||
`(user_acl->'me' ? 'read'))`,
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "NoACLs",
|
||||
Queries: []string{
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
|
||||
},
|
||||
// Special case where the bool is wrapped
|
||||
ExpectedSQL: p("(false) OR (false)"),
|
||||
VariableConverter: regosql.NoACLConverter(),
|
||||
},
|
||||
{
|
||||
Name: "TwoExpressions",
|
||||
Queries: []string{
|
||||
`true; true`,
|
||||
},
|
||||
ExpectedSQL: p("true AND true"),
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
|
||||
// Actual vectors from production
|
||||
{
|
||||
Name: "FromOwner",
|
||||
Queries: []string{
|
||||
``,
|
||||
`"05f58202-4bfc-43ce-9ba4-5ff6e0174a71" = input.object.org_owner`,
|
||||
`"read" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
|
||||
},
|
||||
ExpectedSQL: "true",
|
||||
VariableConverter: regosql.NoACLConverter(),
|
||||
},
|
||||
{
|
||||
Name: "OrgAdmin",
|
||||
Queries: []string{
|
||||
`input.object.org_owner != "";
|
||||
input.object.org_owner in {"05f58202-4bfc-43ce-9ba4-5ff6e0174a71"};
|
||||
input.object.owner != "";
|
||||
"d5389ccc-57a4-4b13-8c3f-31747bcdc9f1" = input.object.owner`,
|
||||
},
|
||||
ExpectedSQL: "((organization_id :: text != '') AND " +
|
||||
"(organization_id :: text = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'])) AND " +
|
||||
"(owner_id :: text != '') AND " +
|
||||
"('d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' = owner_id :: text))",
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "UserACLAllow",
|
||||
Queries: []string{
|
||||
`"read" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
|
||||
`"*" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
|
||||
},
|
||||
ExpectedSQL: "((user_acl->'d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' ? 'read') OR " +
|
||||
"(user_acl->'d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' ? '*'))",
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "NoACLConfig",
|
||||
Queries: []string{
|
||||
`input.object.org_owner != "";
|
||||
input.object.org_owner in {"05f58202-4bfc-43ce-9ba4-5ff6e0174a71"};
|
||||
"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: "((organization_id :: text != '') AND (organization_id :: text = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'])) AND (false))",
|
||||
VariableConverter: regosql.NoACLConverter(),
|
||||
},
|
||||
{
|
||||
Name: "EmptyACLListNoACLs",
|
||||
Queries: []string{
|
||||
`input.object.org_owner != "";
|
||||
input.object.org_owner in set();
|
||||
"create" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
|
||||
`input.object.org_owner != "";
|
||||
input.object.org_owner in set();
|
||||
"*" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
|
||||
`"create" in input.object.acl_user_list.me`,
|
||||
|
||||
`"*" in input.object.acl_user_list.me`,
|
||||
},
|
||||
ExpectedSQL: p(p("(organization_id :: text != '') AND (false) AND (group_acl->organization_id :: text ? 'create')") + " OR " +
|
||||
p("(organization_id :: text != '') AND (false) AND (group_acl->organization_id :: text ? '*')") + " OR " +
|
||||
p("user_acl->'me' ? 'create'") + " OR " +
|
||||
p("user_acl->'me' ? '*'")),
|
||||
VariableConverter: regosql.DefaultVariableConverter(),
|
||||
},
|
||||
{
|
||||
Name: "TemplateOwner",
|
||||
Queries: []string{
|
||||
`neq(input.object.org_owner, "");
|
||||
internal.member_2(input.object.org_owner, {"3bf82434-e40b-44ae-b3d8-d0115bba9bad", "5630fda3-26ab-462c-9014-a88a62d7a415", "c304877a-bc0d-4e9b-9623-a38eae412929"});
|
||||
neq(input.object.owner, "");
|
||||
"806dd721-775f-4c85-9ce3-63fbbd975954" = input.object.owner`,
|
||||
},
|
||||
ExpectedSQL: p(p("organization_id :: text != ''") + " AND " +
|
||||
p("organization_id :: text = ANY(ARRAY ['3bf82434-e40b-44ae-b3d8-d0115bba9bad','5630fda3-26ab-462c-9014-a88a62d7a415','c304877a-bc0d-4e9b-9623-a38eae412929'])") + " AND " +
|
||||
p("false") + " AND " +
|
||||
p("false")),
|
||||
VariableConverter: regosql.TemplateConverter(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
part := partialQueries(tc.Queries...)
|
||||
|
||||
cfg := regosql.ConvertConfig{
|
||||
VariableConverter: tc.VariableConverter,
|
||||
}
|
||||
|
||||
requireConvert(t, convertTestCase{
|
||||
part: part,
|
||||
cfg: cfg,
|
||||
expectSQL: tc.ExpectedSQL,
|
||||
expectConvertError: tc.ExpectError,
|
||||
expectSQLGenError: tc.ExpectedSQLGenError,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type convertTestCase struct {
|
||||
part *rego.PartialQueries
|
||||
cfg regosql.ConvertConfig
|
||||
|
||||
expectConvertError bool
|
||||
expectSQL string
|
||||
expectSQLGenError bool
|
||||
}
|
||||
|
||||
func requireConvert(t *testing.T, tc convertTestCase) {
|
||||
t.Helper()
|
||||
|
||||
for i, q := range tc.part.Queries {
|
||||
t.Logf("Query %d: %s", i, q.String())
|
||||
}
|
||||
for i, s := range tc.part.Support {
|
||||
t.Logf("Support %d: %s", i, s.String())
|
||||
}
|
||||
|
||||
root, err := regosql.ConvertRegoAst(tc.cfg, tc.part)
|
||||
if tc.expectConvertError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err, "compile")
|
||||
|
||||
gen := sqltypes.NewSQLGenerator()
|
||||
sqlString := root.SQLString(gen)
|
||||
if tc.expectSQLGenError {
|
||||
require.True(t, len(gen.Errors()) > 0, "expected SQL generation error")
|
||||
} else {
|
||||
require.NoError(t, err, "sql gen")
|
||||
require.Equal(t, tc.expectSQL, sqlString, "sql match")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func partialQueries(queries ...string) *rego.PartialQueries {
|
||||
opts := ast.ParserOptions{
|
||||
AllFutureKeywords: true,
|
||||
}
|
||||
|
||||
astQueries := make([]ast.Body, 0, len(queries))
|
||||
for _, q := range queries {
|
||||
astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts))
|
||||
}
|
||||
|
||||
return ®o.PartialQueries{
|
||||
Queries: astQueries,
|
||||
Support: []*ast.Module{},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package regosql
|
||||
|
||||
import "github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||
|
||||
func organizationOwnerMatcher() sqltypes.VariableMatcher {
|
||||
return sqltypes.StringVarMatcher("organization_id :: text", []string{"input", "object", "org_owner"})
|
||||
}
|
||||
|
||||
func userOwnerMatcher() sqltypes.VariableMatcher {
|
||||
return sqltypes.StringVarMatcher("owner_id :: text", []string{"input", "object", "owner"})
|
||||
}
|
||||
|
||||
func groupACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher {
|
||||
return ACLGroupMatcher(m, "group_acl", []string{"input", "object", "acl_group_list"})
|
||||
}
|
||||
|
||||
func userACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher {
|
||||
return ACLGroupMatcher(m, "user_acl", []string{"input", "object", "acl_user_list"})
|
||||
}
|
||||
|
||||
func TemplateConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
organizationOwnerMatcher(),
|
||||
// Templates have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
matcher.RegisterMatcher(
|
||||
groupACLMatcher(matcher),
|
||||
userACLMatcher(matcher),
|
||||
)
|
||||
return matcher
|
||||
}
|
||||
|
||||
// NoACLConverter should be used when the target SQL table does not contain
|
||||
// group or user ACL columns.
|
||||
func NoACLConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
organizationOwnerMatcher(),
|
||||
userOwnerMatcher(),
|
||||
)
|
||||
matcher.RegisterMatcher(
|
||||
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
|
||||
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
|
||||
)
|
||||
|
||||
return matcher
|
||||
}
|
||||
|
||||
func DefaultVariableConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
organizationOwnerMatcher(),
|
||||
userOwnerMatcher(),
|
||||
)
|
||||
matcher.RegisterMatcher(
|
||||
groupACLMatcher(matcher),
|
||||
userACLMatcher(matcher),
|
||||
)
|
||||
|
||||
return matcher
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
// Package regosql converts rego queries into SQL WHERE clauses. This is so
|
||||
// the rego queries can be used to filter the results of a SQL query.
|
||||
package regosql
|
|
@ -0,0 +1,61 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
var _ Node = alwaysFalse{}
|
||||
var _ VariableMatcher = alwaysFalse{}
|
||||
|
||||
type alwaysFalse struct {
|
||||
Matcher VariableMatcher
|
||||
|
||||
InnerNode Node
|
||||
}
|
||||
|
||||
// AlwaysFalse overrides the inner node with a constant "false".
|
||||
func AlwaysFalse(m VariableMatcher) VariableMatcher {
|
||||
return alwaysFalse{
|
||||
Matcher: m,
|
||||
}
|
||||
}
|
||||
|
||||
// AlwaysFalseNode is mainly used for unit testing to make a Node immediately.
|
||||
func AlwaysFalseNode(n Node) Node {
|
||||
return alwaysFalse{
|
||||
InnerNode: n,
|
||||
Matcher: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// UseAs uses a type no one supports to always override with false.
|
||||
func (alwaysFalse) UseAs() Node { return alwaysFalse{} }
|
||||
func (f alwaysFalse) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
if f.Matcher != nil {
|
||||
n, ok := f.Matcher.ConvertVariable(rego)
|
||||
if ok {
|
||||
return alwaysFalse{
|
||||
Matcher: f.Matcher,
|
||||
InnerNode: n,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (alwaysFalse) SQLString(_ *SQLGenerator) string {
|
||||
return "false"
|
||||
}
|
||||
|
||||
func (alwaysFalse) ContainsSQL(_ *SQLGenerator, _ Node) (string, error) {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
func (alwaysFalse) ContainedInSQL(_ *SQLGenerator, _ Node) (string, error) {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
func (alwaysFalse) EqualsSQLString(_ *SQLGenerator, _ bool, _ Node) (string, error) {
|
||||
return "false", nil
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type ASTArray struct {
|
||||
Source RegoSource
|
||||
Value []Node
|
||||
}
|
||||
|
||||
// Array is typed to whatever the first element is. If there is not first
|
||||
// element, the array element type is invalid.
|
||||
func Array(source RegoSource, nodes ...Node) (Node, error) {
|
||||
for i := 1; i < len(nodes); i++ {
|
||||
if reflect.TypeOf(nodes[0]) != reflect.TypeOf(nodes[i]) {
|
||||
// Do not allow mixed types in arrays
|
||||
return nil, xerrors.Errorf("array element %d in %q: type mismatch", i, source)
|
||||
}
|
||||
}
|
||||
return ASTArray{Value: nodes, Source: source}, nil
|
||||
}
|
||||
|
||||
func (ASTArray) UseAs() Node { return ASTArray{} }
|
||||
|
||||
func (a ASTArray) ContainsSQL(cfg *SQLGenerator, needle Node) (string, error) {
|
||||
// If we have no elements in our set, then our needle is never in the set.
|
||||
if len(a.Value) == 0 {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
// This condition supports any contains function if the needle type is
|
||||
// the same as the ASTArray element type.
|
||||
if reflect.TypeOf(a.MyType().UseAs()) != reflect.TypeOf(needle.UseAs()) {
|
||||
return "ArrayContainsError", xerrors.Errorf("array contains %q: type mismatch (%T, %T)",
|
||||
a.Source, a.MyType(), needle)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(%s)", needle.SQLString(cfg), a.SQLString(cfg)), nil
|
||||
}
|
||||
|
||||
func (a ASTArray) SQLString(cfg *SQLGenerator) string {
|
||||
switch a.MyType().UseAs().(type) {
|
||||
case invalidNode:
|
||||
cfg.AddError(xerrors.Errorf("array %q: empty array", a.Source))
|
||||
return "ArrayError"
|
||||
case AstNumber, AstString, AstBoolean:
|
||||
// Primitive types
|
||||
values := make([]string, 0, len(a.Value))
|
||||
for _, v := range a.Value {
|
||||
values = append(values, v.SQLString(cfg))
|
||||
}
|
||||
return fmt.Sprintf("ARRAY [%s]", strings.Join(values, ","))
|
||||
}
|
||||
|
||||
cfg.AddError(xerrors.Errorf("array %q: unsupported type %T", a.Source, a.MyType()))
|
||||
return "ArrayError"
|
||||
}
|
||||
|
||||
func (a ASTArray) MyType() Node {
|
||||
if len(a.Value) == 0 {
|
||||
return invalidNode{}
|
||||
}
|
||||
return a.Value[0]
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type binaryOperator int
|
||||
|
||||
const (
|
||||
_ binaryOperator = iota
|
||||
binaryOpOR
|
||||
binaryOpAND
|
||||
)
|
||||
|
||||
type binaryOp struct {
|
||||
source RegoSource
|
||||
op binaryOperator
|
||||
|
||||
Terms []BooleanNode
|
||||
}
|
||||
|
||||
func (binaryOp) UseAs() Node { return binaryOp{} }
|
||||
func (binaryOp) IsBooleanNode() {}
|
||||
|
||||
func Or(source RegoSource, terms ...BooleanNode) BooleanNode {
|
||||
return newBinaryOp(source, binaryOpOR, terms...)
|
||||
}
|
||||
|
||||
func And(source RegoSource, terms ...BooleanNode) BooleanNode {
|
||||
return newBinaryOp(source, binaryOpAND, terms...)
|
||||
}
|
||||
|
||||
func newBinaryOp(source RegoSource, op binaryOperator, terms ...BooleanNode) BooleanNode {
|
||||
if len(terms) == 0 {
|
||||
// TODO: How to handle 0 terms?
|
||||
return Bool(false)
|
||||
}
|
||||
|
||||
opTerms := make([]BooleanNode, 0, len(terms))
|
||||
for i := range terms {
|
||||
// Always wrap terms in parentheses to be safe.
|
||||
opTerms = append(opTerms, BoolParenthesis(terms[i]))
|
||||
}
|
||||
|
||||
if len(opTerms) == 1 {
|
||||
return opTerms[0]
|
||||
}
|
||||
|
||||
return binaryOp{
|
||||
Terms: opTerms,
|
||||
op: op,
|
||||
source: source,
|
||||
}
|
||||
}
|
||||
|
||||
func (b binaryOp) SQLString(cfg *SQLGenerator) string {
|
||||
sqlOp := ""
|
||||
switch b.op {
|
||||
case binaryOpOR:
|
||||
sqlOp = "OR"
|
||||
case binaryOpAND:
|
||||
sqlOp = "AND"
|
||||
default:
|
||||
cfg.AddError(xerrors.Errorf("unsupported binary operator: %s (%d)", b.source, b.op))
|
||||
return "BinaryOpError"
|
||||
}
|
||||
|
||||
terms := make([]string, 0, len(b.Terms))
|
||||
for _, term := range b.Terms {
|
||||
termSQL := term.SQLString(cfg)
|
||||
terms = append(terms, termSQL)
|
||||
}
|
||||
|
||||
return strings.Join(terms, " "+sqlOp+" ")
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// AstBoolean is a literal true/false value.
|
||||
type AstBoolean struct {
|
||||
Source RegoSource
|
||||
Value bool
|
||||
}
|
||||
|
||||
func Bool(t bool) BooleanNode {
|
||||
return AstBoolean{Value: t, Source: RegoSource(strconv.FormatBool(t))}
|
||||
}
|
||||
|
||||
func (AstBoolean) IsBooleanNode() {}
|
||||
func (AstBoolean) UseAs() Node { return AstBoolean{} }
|
||||
|
||||
func (b AstBoolean) SQLString(_ *SQLGenerator) string {
|
||||
return strconv.FormatBool(b.Value)
|
||||
}
|
||||
|
||||
func (b AstBoolean) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
return boolEqualsSQLString(cfg, b, not, other)
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
// Package sqltypes contains the types used to convert rego queries into SQL.
|
||||
// The rego ast is converted into these types to better control the SQL
|
||||
// generation. It allows writing the SQL generation for types in an easier to
|
||||
// read way.
|
||||
package sqltypes
|
|
@ -0,0 +1,102 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// SupportsEquality is an interface that can be implemented by types that
|
||||
// support equality with other types. We defer to other types to implement this
|
||||
// as it is much easier to implement this in the context of the type.
|
||||
type SupportsEquality interface {
|
||||
// EqualsSQLString intentionally returns an error. This is so if
|
||||
// left = right is not supported, we can try right = left.
|
||||
EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error)
|
||||
}
|
||||
|
||||
var _ BooleanNode = equality{}
|
||||
var _ Node = equality{}
|
||||
var _ SupportsEquality = equality{}
|
||||
|
||||
type equality struct {
|
||||
Left Node
|
||||
Right Node
|
||||
|
||||
// Not just inverses the result of the comparison. We could implement this
|
||||
// as a Not node wrapping the equality, but this is more efficient.
|
||||
Not bool
|
||||
}
|
||||
|
||||
func Equality(notEquals bool, a, b Node) BooleanNode {
|
||||
return equality{
|
||||
Left: a,
|
||||
Right: b,
|
||||
Not: notEquals,
|
||||
}
|
||||
}
|
||||
|
||||
func (equality) IsBooleanNode() {}
|
||||
|
||||
// UseAs returns an ASTBoolean as equalities resolve to boolean values
|
||||
func (equality) UseAs() Node { return AstBoolean{} }
|
||||
|
||||
func (e equality) SQLString(cfg *SQLGenerator) string {
|
||||
// Equalities can be flipped without changing the result, so we can
|
||||
// try both left = right and right = left.
|
||||
if eq, ok := e.Left.(SupportsEquality); ok {
|
||||
v, err := eq.EqualsSQLString(cfg, e.Not, e.Right)
|
||||
if err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
if eq, ok := e.Right.(SupportsEquality); ok {
|
||||
v, err := eq.EqualsSQLString(cfg, e.Not, e.Left)
|
||||
if err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
cfg.AddError(xerrors.Errorf("unsupported equality: %T %s %T", e.Left, equalsOp(e.Not), e.Right))
|
||||
return "EqualityError"
|
||||
}
|
||||
|
||||
func (e equality) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
return boolEqualsSQLString(cfg, e, not, other)
|
||||
}
|
||||
|
||||
func boolEqualsSQLString(cfg *SQLGenerator, a BooleanNode, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case BooleanNode:
|
||||
bn, ok := other.(BooleanNode)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("not a boolean node: %T", other)
|
||||
}
|
||||
|
||||
// Always wrap both sides in parens to ensure the correct precedence.
|
||||
return fmt.Sprintf("%s %s %s",
|
||||
BoolParenthesis(a).SQLString(cfg),
|
||||
equalsOp(not),
|
||||
BoolParenthesis(bn).SQLString(cfg),
|
||||
), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T", a, equalsOp(not), other)
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:revive
|
||||
func equalsOp(not bool) string {
|
||||
if not {
|
||||
return "!="
|
||||
}
|
||||
return "="
|
||||
}
|
||||
|
||||
func basicSQLEquality(cfg *SQLGenerator, not bool, a, b Node) string {
|
||||
return fmt.Sprintf("%s %s %s",
|
||||
a.SQLString(cfg),
|
||||
equalsOp(not),
|
||||
b.SQLString(cfg),
|
||||
)
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
package sqltypes_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||
)
|
||||
|
||||
func TestEquality(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Equality sqltypes.Node
|
||||
ExpectedSQL string
|
||||
ExpectedErrors int
|
||||
}{
|
||||
{
|
||||
Name: "String=String",
|
||||
Equality: sqltypes.Equality(false,
|
||||
sqltypes.String("foo"),
|
||||
sqltypes.String("bar"),
|
||||
),
|
||||
ExpectedSQL: "'foo' = 'bar'",
|
||||
},
|
||||
{
|
||||
Name: "Number=Number",
|
||||
Equality: sqltypes.Equality(false,
|
||||
sqltypes.Number("", json.Number("5")),
|
||||
sqltypes.Number("", json.Number("22")),
|
||||
),
|
||||
ExpectedSQL: "5 = 22",
|
||||
},
|
||||
{
|
||||
Name: "Bool=Bool",
|
||||
Equality: sqltypes.Equality(false,
|
||||
sqltypes.Bool(true),
|
||||
sqltypes.Bool(false),
|
||||
),
|
||||
ExpectedSQL: "true = false",
|
||||
},
|
||||
{
|
||||
Name: "Bool=Equality",
|
||||
Equality: sqltypes.Equality(false,
|
||||
sqltypes.Bool(true),
|
||||
sqltypes.Equality(true,
|
||||
sqltypes.Equality(true,
|
||||
sqltypes.String("foo"),
|
||||
sqltypes.String("bar"),
|
||||
),
|
||||
sqltypes.Bool(false),
|
||||
),
|
||||
),
|
||||
ExpectedSQL: "true = (('foo' != 'bar') != false)",
|
||||
},
|
||||
{
|
||||
Name: "Equality=Equality",
|
||||
Equality: sqltypes.Equality(false,
|
||||
sqltypes.Equality(true,
|
||||
sqltypes.Bool(true),
|
||||
sqltypes.Bool(false),
|
||||
),
|
||||
sqltypes.Equality(false,
|
||||
sqltypes.String("foo"),
|
||||
sqltypes.String("foo"),
|
||||
),
|
||||
),
|
||||
ExpectedSQL: "(true != false) = ('foo' = 'foo')",
|
||||
},
|
||||
{
|
||||
Name: "Membership=Membership",
|
||||
Equality: sqltypes.Equality(false,
|
||||
sqltypes.Equality(true,
|
||||
sqltypes.MemberOf(
|
||||
sqltypes.String("foo"),
|
||||
must(sqltypes.Array("",
|
||||
sqltypes.String("foo"),
|
||||
sqltypes.String("bar"),
|
||||
)),
|
||||
),
|
||||
sqltypes.Bool(false),
|
||||
),
|
||||
sqltypes.Equality(false,
|
||||
sqltypes.Bool(true),
|
||||
sqltypes.MemberOf(
|
||||
sqltypes.Number("", "2"),
|
||||
must(sqltypes.Array("",
|
||||
sqltypes.Number("", "5"),
|
||||
sqltypes.Number("", "2"),
|
||||
)),
|
||||
),
|
||||
),
|
||||
),
|
||||
ExpectedSQL: "(('foo' = ANY(ARRAY ['foo','bar'])) != false) = (true = (2 = ANY(ARRAY [5,2])))",
|
||||
},
|
||||
{
|
||||
Name: "AlwaysFalse=String",
|
||||
Equality: sqltypes.Equality(false,
|
||||
sqltypes.AlwaysFalseNode(sqltypes.String("foo")),
|
||||
sqltypes.String("foo"),
|
||||
),
|
||||
ExpectedSQL: "false",
|
||||
},
|
||||
{
|
||||
Name: "String=AlwaysFalse",
|
||||
Equality: sqltypes.Equality(false,
|
||||
sqltypes.String("foo"),
|
||||
sqltypes.AlwaysFalseNode(sqltypes.String("foo")),
|
||||
),
|
||||
ExpectedSQL: "false",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gen := sqltypes.NewSQLGenerator()
|
||||
found := tc.Equality.SQLString(gen)
|
||||
if tc.ExpectedErrors > 0 {
|
||||
require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected AstNumber of errors")
|
||||
} else {
|
||||
require.Equal(t, tc.ExpectedSQL, found, "expected sql")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package sqltypes
|
||||
|
||||
type SQLGenerator struct {
|
||||
errors []error
|
||||
}
|
||||
|
||||
func NewSQLGenerator() *SQLGenerator {
|
||||
return &SQLGenerator{}
|
||||
}
|
||||
|
||||
func (g *SQLGenerator) AddError(err error) {
|
||||
if err != nil {
|
||||
g.errors = append(g.errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *SQLGenerator) Errors() []error {
|
||||
return g.errors
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// SupportsContains is an interface that can be implemented by types that
|
||||
// support "me.Contains(other)". This is `internal_member2` in the rego.
|
||||
type SupportsContains interface {
|
||||
ContainsSQL(cfg *SQLGenerator, other Node) (string, error)
|
||||
}
|
||||
|
||||
// SupportsContainedIn is the inverse of SupportsContains. It is implemented
|
||||
// from the "needle" rather than the haystack.
|
||||
type SupportsContainedIn interface {
|
||||
ContainedInSQL(cfg *SQLGenerator, other Node) (string, error)
|
||||
}
|
||||
|
||||
var _ BooleanNode = memberOf{}
|
||||
var _ Node = memberOf{}
|
||||
var _ SupportsEquality = memberOf{}
|
||||
|
||||
type memberOf struct {
|
||||
Needle Node
|
||||
Haystack Node
|
||||
}
|
||||
|
||||
func MemberOf(needle, haystack Node) BooleanNode {
|
||||
return memberOf{
|
||||
Needle: needle,
|
||||
Haystack: haystack,
|
||||
}
|
||||
}
|
||||
|
||||
func (memberOf) IsBooleanNode() {}
|
||||
func (memberOf) UseAs() Node { return AstBoolean{} }
|
||||
|
||||
func (e memberOf) SQLString(cfg *SQLGenerator) string {
|
||||
// Equalities can be flipped without changing the result, so we can
|
||||
// try both left = right and right = left.
|
||||
if sc, ok := e.Haystack.(SupportsContains); ok {
|
||||
v, err := sc.ContainsSQL(cfg, e.Needle)
|
||||
if err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
if sc, ok := e.Needle.(SupportsContainedIn); ok {
|
||||
v, err := sc.ContainedInSQL(cfg, e.Haystack)
|
||||
if err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
cfg.AddError(xerrors.Errorf("unsupported contains: %T contains %T", e.Haystack, e.Needle))
|
||||
return "MemberOfError"
|
||||
}
|
||||
|
||||
func (e memberOf) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
return boolEqualsSQLString(cfg, e, not, other)
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
package sqltypes_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||
)
|
||||
|
||||
func TestMembership(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Membership sqltypes.Node
|
||||
ExpectedSQL string
|
||||
ExpectedErrors int
|
||||
}{
|
||||
{
|
||||
Name: "StringArray",
|
||||
Membership: sqltypes.MemberOf(
|
||||
sqltypes.String("foo"),
|
||||
must(sqltypes.Array("",
|
||||
sqltypes.String("bar"),
|
||||
sqltypes.String("buzz"),
|
||||
)),
|
||||
),
|
||||
ExpectedSQL: "'foo' = ANY(ARRAY ['bar','buzz'])",
|
||||
},
|
||||
{
|
||||
Name: "NumberArray",
|
||||
Membership: sqltypes.MemberOf(
|
||||
sqltypes.Number("", "5"),
|
||||
must(sqltypes.Array("",
|
||||
sqltypes.Number("", "2"),
|
||||
sqltypes.Number("", "5"),
|
||||
)),
|
||||
),
|
||||
ExpectedSQL: "5 = ANY(ARRAY [2,5])",
|
||||
},
|
||||
{
|
||||
Name: "BoolArray",
|
||||
Membership: sqltypes.MemberOf(
|
||||
sqltypes.Bool(true),
|
||||
must(sqltypes.Array("",
|
||||
sqltypes.Bool(false),
|
||||
sqltypes.Bool(true),
|
||||
)),
|
||||
),
|
||||
ExpectedSQL: "true = ANY(ARRAY [false,true])",
|
||||
},
|
||||
{
|
||||
Name: "EmptyArray",
|
||||
Membership: sqltypes.MemberOf(
|
||||
sqltypes.Bool(true),
|
||||
must(sqltypes.Array("")),
|
||||
),
|
||||
ExpectedSQL: "false",
|
||||
},
|
||||
{
|
||||
Name: "AlwaysFalseMember",
|
||||
Membership: sqltypes.MemberOf(
|
||||
sqltypes.AlwaysFalseNode(sqltypes.Bool(true)),
|
||||
must(sqltypes.Array("",
|
||||
sqltypes.Bool(false),
|
||||
sqltypes.Bool(true),
|
||||
)),
|
||||
),
|
||||
ExpectedSQL: "false",
|
||||
},
|
||||
{
|
||||
Name: "AlwaysFalseArray",
|
||||
Membership: sqltypes.MemberOf(
|
||||
sqltypes.Bool(true),
|
||||
sqltypes.AlwaysFalseNode(must(sqltypes.Array("",
|
||||
sqltypes.Bool(false),
|
||||
sqltypes.Bool(true),
|
||||
))),
|
||||
),
|
||||
ExpectedSQL: "false",
|
||||
},
|
||||
|
||||
// Errors
|
||||
{
|
||||
Name: "Unsupported",
|
||||
Membership: sqltypes.MemberOf(
|
||||
sqltypes.Bool(true),
|
||||
sqltypes.Bool(true),
|
||||
),
|
||||
ExpectedErrors: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gen := sqltypes.NewSQLGenerator()
|
||||
found := tc.Membership.SQLString(gen)
|
||||
if tc.ExpectedErrors > 0 {
|
||||
require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected some errors")
|
||||
} else {
|
||||
require.Equal(t, tc.ExpectedSQL, found, "expected sql")
|
||||
require.Equal(t, tc.ExpectedErrors, 0, "expected no errors")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func must[V any](v V, err error) V {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
SQLString(cfg *SQLGenerator) string
|
||||
// UseAs is a helper function to allow a node to be used as a different
|
||||
// Node in operators. For example, a variable is really just a "string", so
|
||||
// having the Equality operator check for "String" or "StringVar" is just
|
||||
// excessive. Instead, we can just have the variable implement this function.
|
||||
UseAs() Node
|
||||
}
|
||||
|
||||
// BooleanNode is a node that returns a true/false when evaluated.
|
||||
type BooleanNode interface {
|
||||
Node
|
||||
IsBooleanNode()
|
||||
}
|
||||
|
||||
type RegoSource string
|
||||
|
||||
type invalidNode struct{}
|
||||
|
||||
func (invalidNode) UseAs() Node { return invalidNode{} }
|
||||
|
||||
func (invalidNode) SQLString(cfg *SQLGenerator) string {
|
||||
cfg.AddError(xerrors.Errorf("invalid node called"))
|
||||
return "invalid_type"
|
||||
}
|
||||
|
||||
func IsPrimitive(n Node) bool {
|
||||
switch n.(type) {
|
||||
case AstBoolean, AstString, AstNumber:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type AstNumber struct {
|
||||
Source RegoSource
|
||||
// Value is intentionally vague as to if it's an integer or a float.
|
||||
// This defers that decision to the user. Rego keeps all numbers in this
|
||||
// type. If we were to source the type from something other than Rego,
|
||||
// we might want to make a Float and Int type which keep the original
|
||||
// precision.
|
||||
Value json.Number
|
||||
}
|
||||
|
||||
func Number(source RegoSource, v json.Number) Node {
|
||||
return AstNumber{Value: v, Source: source}
|
||||
}
|
||||
|
||||
func (AstNumber) UseAs() Node { return AstNumber{} }
|
||||
|
||||
func (n AstNumber) SQLString(_ *SQLGenerator) string {
|
||||
return n.Value.String()
|
||||
}
|
||||
|
||||
func (n AstNumber) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstNumber:
|
||||
return basicSQLEquality(cfg, not, n, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T", n, equalsOp(not), other)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type astParenthesis struct {
|
||||
Value BooleanNode
|
||||
}
|
||||
|
||||
// BoolParenthesis wraps the given boolean node in parens.
|
||||
// This is useful for grouping and avoiding ambiguity. This does not work for
|
||||
// mathematical parenthesis to change order of operations.
|
||||
func BoolParenthesis(value BooleanNode) BooleanNode {
|
||||
// Wrapping primitives is useless.
|
||||
if IsPrimitive(value) {
|
||||
return value
|
||||
}
|
||||
|
||||
// Unwrap any existing parens. Do not add excess parens.
|
||||
if p, ok := value.(astParenthesis); ok {
|
||||
return BoolParenthesis(p.Value)
|
||||
}
|
||||
return astParenthesis{Value: value}
|
||||
}
|
||||
|
||||
func (astParenthesis) IsBooleanNode() {}
|
||||
func (p astParenthesis) UseAs() Node { return p.Value.UseAs() }
|
||||
func (p astParenthesis) SQLString(cfg *SQLGenerator) string {
|
||||
return "(" + p.Value.SQLString(cfg) + ")"
|
||||
}
|
||||
|
||||
func (p astParenthesis) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
if supp, ok := p.Value.(SupportsEquality); ok {
|
||||
return supp.EqualsSQLString(cfg, not, other)
|
||||
}
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T", p.Value, equalsOp(not), other)
|
||||
}
|
||||
|
||||
func (p astParenthesis) ContainsSQL(cfg *SQLGenerator, other Node) (string, error) {
|
||||
if supp, ok := p.Value.(SupportsContains); ok {
|
||||
return supp.ContainsSQL(cfg, other)
|
||||
}
|
||||
return "", xerrors.Errorf("unsupported contains: %T %T", p.Value, other)
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type AstString struct {
|
||||
Source RegoSource
|
||||
Value string
|
||||
}
|
||||
|
||||
func String(v string) Node {
|
||||
return AstString{Value: v, Source: RegoSource(v)}
|
||||
}
|
||||
|
||||
func (AstString) UseAs() Node { return AstString{} }
|
||||
|
||||
func (s AstString) SQLString(_ *SQLGenerator) string {
|
||||
return "'" + s.Value + "'"
|
||||
}
|
||||
|
||||
func (s AstString) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstString:
|
||||
return basicSQLEquality(cfg, not, s, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T", s, equalsOp(not), other)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
type VariableMatcher interface {
|
||||
ConvertVariable(rego ast.Ref) (Node, bool)
|
||||
}
|
||||
|
||||
type VariableConverter struct {
|
||||
converters []VariableMatcher
|
||||
}
|
||||
|
||||
func NewVariableConverter() *VariableConverter {
|
||||
return &VariableConverter{}
|
||||
}
|
||||
|
||||
func (vc *VariableConverter) RegisterMatcher(m ...VariableMatcher) *VariableConverter {
|
||||
vc.converters = append(vc.converters, m...)
|
||||
// Returns the VariableConverter for easier instantiation
|
||||
return vc
|
||||
}
|
||||
|
||||
func (vc *VariableConverter) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
for _, c := range vc.converters {
|
||||
if n, ok := c.ConvertVariable(rego); ok {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// RegoVarPath will consume the following terms from the given rego Ref and
|
||||
// return the remaining terms. If the path does not fully match, an error is
|
||||
// returned. The first term must always be a Var.
|
||||
func RegoVarPath(path []string, terms []*ast.Term) ([]*ast.Term, error) {
|
||||
if len(terms) < len(path) {
|
||||
return nil, xerrors.Errorf("path %s longer than rego path %s", path, terms)
|
||||
}
|
||||
|
||||
if len(terms) == 0 || len(path) == 0 {
|
||||
return nil, xerrors.Errorf("path %s and rego path %s must not be empty", path, terms)
|
||||
}
|
||||
|
||||
varTerm, ok := terms[0].Value.(ast.Var)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("expected var, got %T", terms[0])
|
||||
}
|
||||
|
||||
if string(varTerm) != path[0] {
|
||||
return nil, xerrors.Errorf("expected var %s, got %s", path[0], varTerm)
|
||||
}
|
||||
|
||||
for i := 1; i < len(path); i++ {
|
||||
nextTerm, ok := terms[i].Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("expected ast.string, got %T", terms[i])
|
||||
}
|
||||
|
||||
if string(nextTerm) != path[i] {
|
||||
return nil, xerrors.Errorf("expected string %s, got %s", path[i], nextTerm)
|
||||
}
|
||||
}
|
||||
|
||||
return terms[len(path):], nil
|
||||
}
|
||||
|
||||
var _ VariableMatcher = astStringVar{}
|
||||
var _ Node = astStringVar{}
|
||||
|
||||
// astStringVar is any variable that represents a string.
|
||||
type astStringVar struct {
|
||||
Source RegoSource
|
||||
FieldPath []string
|
||||
ColumnString string
|
||||
}
|
||||
|
||||
func StringVarMatcher(sqlString string, regoPath []string) VariableMatcher {
|
||||
return astStringVar{FieldPath: regoPath, ColumnString: sqlString}
|
||||
}
|
||||
|
||||
func (astStringVar) UseAs() Node { return AstString{} }
|
||||
|
||||
// ConvertVariable will return a new astStringVar Node if the given rego Ref
|
||||
// matches this astStringVar.
|
||||
func (s astStringVar) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
left, err := RegoVarPath(s.FieldPath, rego)
|
||||
if err == nil && len(left) == 0 {
|
||||
return astStringVar{
|
||||
Source: RegoSource(rego.String()),
|
||||
FieldPath: s.FieldPath,
|
||||
ColumnString: s.ColumnString,
|
||||
}, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s astStringVar) SQLString(_ *SQLGenerator) string {
|
||||
return s.ColumnString
|
||||
}
|
||||
|
||||
func (s astStringVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstString:
|
||||
return basicSQLEquality(cfg, not, s, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T", s, equalsOp(not), other)
|
||||
}
|
||||
}
|
|
@ -316,25 +316,27 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
|
|||
func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
organization := httpmw.OrganizationParam(r)
|
||||
templates, err := api.Database.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
|
||||
OrganizationID: organization.ID,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
|
||||
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceTemplate.Type)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching templates in organization.",
|
||||
Message: "Internal error preparing sql filter.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Filter templates based on rbac permissions
|
||||
templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates)
|
||||
templates, err := api.Database.GetAuthorizedTemplates(ctx, database.GetTemplatesWithFilterParams{
|
||||
OrganizationID: organization.ID,
|
||||
}, prepared)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching templates.",
|
||||
Message: "Internal error fetching templates in organization.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
|
|
|
@ -118,7 +118,8 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
|
|||
filter.OwnerUsername = ""
|
||||
}
|
||||
|
||||
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
|
||||
// Workspaces do not have ACL columns.
|
||||
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error preparing sql filter.",
|
||||
|
@ -127,7 +128,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter)
|
||||
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, prepared)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspaces.",
|
||||
|
|
|
@ -31,7 +31,7 @@ func (api *API) postGroupByOrganization(rw http.ResponseWriter, r *http.Request)
|
|||
)
|
||||
defer commitAudit()
|
||||
|
||||
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceGroup) {
|
||||
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceGroup.InOrg(org.ID)) {
|
||||
http.NotFound(rw, r)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -11,7 +11,9 @@ import (
|
|||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/testutil"
|
||||
|
@ -747,3 +749,210 @@ func TestUpdateTemplateACL(t *testing.T) {
|
|||
require.Equal(t, http.StatusNotFound, cerr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
// TestTemplateAccess tests the rego -> sql conversion. We need to implement
|
||||
// this test on at least 1 table type to ensure that the conversion is correct.
|
||||
// The rbac tests only assert against static SQL queries.
|
||||
// This is a full rbac test of many of the common role combinations.
|
||||
//
|
||||
//nolint:tparallel
|
||||
func TestTemplateAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
ownerClient := coderdenttest.New(t, nil)
|
||||
owner := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
_ = coderdenttest.AddLicense(t, ownerClient, coderdenttest.LicenseOptions{
|
||||
TemplateRBAC: true,
|
||||
})
|
||||
|
||||
type coderUser struct {
|
||||
*codersdk.Client
|
||||
User codersdk.User
|
||||
}
|
||||
|
||||
type orgSetup struct {
|
||||
Admin coderUser
|
||||
MemberInGroup coderUser
|
||||
MemberNoGroup coderUser
|
||||
|
||||
DefaultTemplate codersdk.Template
|
||||
AllRead codersdk.Template
|
||||
UserACL codersdk.Template
|
||||
GroupACL codersdk.Template
|
||||
|
||||
Group codersdk.Group
|
||||
Org codersdk.Organization
|
||||
}
|
||||
|
||||
// Create the following users
|
||||
// - owner: Site wide owner
|
||||
// - template-admin
|
||||
// - org-admin (org 1)
|
||||
// - org-admin (org 2)
|
||||
// - org-member (org 1)
|
||||
// - org-member (org 2)
|
||||
|
||||
// Create the following templates in each org
|
||||
// - template 1, default acls
|
||||
// - template 2, all_user read
|
||||
// - template 3, user_acl read for member
|
||||
// - template 4, group_acl read for groupMember
|
||||
|
||||
templateAdmin := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin())
|
||||
|
||||
makeTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, acl codersdk.UpdateTemplateACL) codersdk.Template {
|
||||
version := coderdtest.CreateTemplateVersion(t, client, orgID, nil)
|
||||
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
|
||||
|
||||
err := client.UpdateTemplateACL(ctx, template.ID, acl)
|
||||
require.NoError(t, err, "failed to update template acl")
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
makeOrg := func(t *testing.T) orgSetup {
|
||||
// Make org
|
||||
orgName, err := cryptorand.String(5)
|
||||
require.NoError(t, err, "org name")
|
||||
|
||||
// Make users
|
||||
newOrg, err := ownerClient.CreateOrganization(ctx, codersdk.CreateOrganizationRequest{Name: orgName})
|
||||
require.NoError(t, err, "failed to create org")
|
||||
|
||||
adminCli, adminUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgAdmin(newOrg.ID))
|
||||
groupMemCli, groupMemUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgMember(newOrg.ID))
|
||||
memberCli, memberUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgMember(newOrg.ID))
|
||||
|
||||
// Make group
|
||||
group, err := adminCli.CreateGroup(ctx, newOrg.ID, codersdk.CreateGroupRequest{
|
||||
Name: "SingleUser",
|
||||
})
|
||||
require.NoError(t, err, "failed to create group")
|
||||
|
||||
group, err = adminCli.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{
|
||||
AddUsers: []string{groupMemUsr.ID.String()},
|
||||
})
|
||||
require.NoError(t, err, "failed to add user to group")
|
||||
|
||||
// Make templates
|
||||
|
||||
return orgSetup{
|
||||
Admin: coderUser{Client: adminCli, User: adminUsr},
|
||||
MemberInGroup: coderUser{Client: groupMemCli, User: groupMemUsr},
|
||||
MemberNoGroup: coderUser{Client: memberCli, User: memberUsr},
|
||||
Org: newOrg,
|
||||
Group: group,
|
||||
|
||||
DefaultTemplate: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
|
||||
GroupPerms: map[string]codersdk.TemplateRole{
|
||||
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
|
||||
},
|
||||
}),
|
||||
AllRead: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
|
||||
GroupPerms: map[string]codersdk.TemplateRole{
|
||||
newOrg.ID.String(): codersdk.TemplateRoleUse,
|
||||
},
|
||||
}),
|
||||
UserACL: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
|
||||
GroupPerms: map[string]codersdk.TemplateRole{
|
||||
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
|
||||
},
|
||||
UserPerms: map[string]codersdk.TemplateRole{
|
||||
memberUsr.ID.String(): codersdk.TemplateRoleUse,
|
||||
},
|
||||
}),
|
||||
GroupACL: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
|
||||
GroupPerms: map[string]codersdk.TemplateRole{
|
||||
group.ID.String(): codersdk.TemplateRoleUse,
|
||||
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Make 2 organizations
|
||||
orgs := []orgSetup{
|
||||
makeOrg(t),
|
||||
makeOrg(t),
|
||||
}
|
||||
|
||||
testTemplateRead := func(t *testing.T, org orgSetup, usr *codersdk.Client, read []codersdk.Template) {
|
||||
found, err := usr.TemplatesByOrganization(ctx, org.Org.ID)
|
||||
require.NoError(t, err, "failed to get templates")
|
||||
|
||||
exp := make(map[uuid.UUID]codersdk.Template)
|
||||
for _, tmpl := range read {
|
||||
exp[tmpl.ID] = tmpl
|
||||
}
|
||||
|
||||
for _, f := range found {
|
||||
if _, ok := exp[f.ID]; !ok {
|
||||
t.Errorf("found unexpected template %q", f.Name)
|
||||
}
|
||||
delete(exp, f.ID)
|
||||
}
|
||||
require.Len(t, exp, 0, "expected templates not found")
|
||||
}
|
||||
|
||||
// nolint:paralleltest
|
||||
t.Run("OwnerReadAll", func(t *testing.T) {
|
||||
for _, o := range orgs {
|
||||
// Owners can read all templates in all orgs
|
||||
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
|
||||
testTemplateRead(t, o, ownerClient, exp)
|
||||
}
|
||||
})
|
||||
|
||||
// nolint:paralleltest
|
||||
t.Run("TemplateAdminReadAll", func(t *testing.T) {
|
||||
for _, o := range orgs {
|
||||
// Template Admins can read all templates in all orgs
|
||||
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
|
||||
testTemplateRead(t, o, templateAdmin, exp)
|
||||
}
|
||||
})
|
||||
|
||||
// nolint:paralleltest
|
||||
t.Run("OrgAdminReadAllTheirs", func(t *testing.T) {
|
||||
for i, o := range orgs {
|
||||
cli := o.Admin.Client
|
||||
// Only read their own org
|
||||
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
|
||||
testTemplateRead(t, o, cli, exp)
|
||||
|
||||
other := orgs[(i+1)%len(orgs)]
|
||||
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
|
||||
testTemplateRead(t, other, cli, []codersdk.Template{})
|
||||
}
|
||||
})
|
||||
|
||||
// nolint:paralleltest
|
||||
t.Run("TestMemberNoGroup", func(t *testing.T) {
|
||||
for i, o := range orgs {
|
||||
cli := o.MemberNoGroup.Client
|
||||
// Only read their own org
|
||||
exp := []codersdk.Template{o.AllRead, o.UserACL}
|
||||
testTemplateRead(t, o, cli, exp)
|
||||
|
||||
other := orgs[(i+1)%len(orgs)]
|
||||
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
|
||||
testTemplateRead(t, other, cli, []codersdk.Template{})
|
||||
}
|
||||
})
|
||||
|
||||
// nolint:paralleltest
|
||||
t.Run("TestMemberInGroup", func(t *testing.T) {
|
||||
for i, o := range orgs {
|
||||
cli := o.MemberInGroup.Client
|
||||
// Only read their own org
|
||||
exp := []codersdk.Template{o.AllRead, o.GroupACL}
|
||||
testTemplateRead(t, o, cli, exp)
|
||||
|
||||
other := orgs[(i+1)%len(orgs)]
|
||||
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
|
||||
testTemplateRead(t, other, cli, []codersdk.Template{})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue