mirror of https://github.com/coder/coder.git
chore: Rewrite rbac rego -> SQL clause (#5138)
* chore: Rewrite rbac rego -> SQL clause Previous code was challenging to read with edge cases - bug: OrgAdmin could not make new groups - Also refactor some function names
This commit is contained in:
parent
d5ab4fdeb8
commit
ab9298f382
|
@ -5,7 +5,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
@ -95,19 +94,14 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
|
||||||
// from postgres are already authorized, and the caller does not need to
|
// from postgres are already authorized, and the caller does not need to
|
||||||
// call 'Authorize()' on the returned objects.
|
// call 'Authorize()' on the returned objects.
|
||||||
// Note the authorization is only for the given action and object type.
|
// Note the authorization is only for the given action and object type.
|
||||||
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
|
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) {
|
||||||
roles := httpmw.UserAuthorization(r)
|
roles := httpmw.UserAuthorization(r)
|
||||||
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType)
|
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("prepare filter: %w", err)
|
return nil, xerrors.Errorf("prepare filter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
filter, err := prepared.Compile()
|
return prepared, nil
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("compile filter: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return filter, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkAuthorization returns if the current API key can use the given
|
// checkAuthorization returns if the current API key can use the given
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/database/databasefake"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -16,12 +18,19 @@ import (
|
||||||
|
|
||||||
"github.com/coder/coder/coderd"
|
"github.com/coder/coder/coderd"
|
||||||
"github.com/coder/coder/coderd/rbac"
|
"github.com/coder/coder/coderd/rbac"
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql"
|
||||||
"github.com/coder/coder/codersdk"
|
"github.com/coder/coder/codersdk"
|
||||||
"github.com/coder/coder/provisioner/echo"
|
"github.com/coder/coder/provisioner/echo"
|
||||||
"github.com/coder/coder/provisionersdk/proto"
|
"github.com/coder/coder/provisionersdk/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||||
|
// For any route using SQL filters, we need to know if the database is an
|
||||||
|
// in memory fake. This is because the in memory fake does not use SQL, and
|
||||||
|
// still uses rego. So this boolean indicates how to assert the expected
|
||||||
|
// behavior.
|
||||||
|
_, isMemoryDB := a.api.Database.(databasefake.FakeDatabase)
|
||||||
|
|
||||||
// Some quick reused objects
|
// Some quick reused objects
|
||||||
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||||
workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||||
|
@ -125,11 +134,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||||
AssertAction: rbac.ActionCreate,
|
AssertAction: rbac.ActionCreate,
|
||||||
AssertObject: workspaceExecObj,
|
AssertObject: workspaceExecObj,
|
||||||
},
|
},
|
||||||
"GET:/api/v2/organizations/{organization}/templates": {
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
AssertAction: rbac.ActionRead,
|
|
||||||
AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID),
|
|
||||||
},
|
|
||||||
"POST:/api/v2/organizations/{organization}/templates": {
|
"POST:/api/v2/organizations/{organization}/templates": {
|
||||||
AssertAction: rbac.ActionCreate,
|
AssertAction: rbac.ActionCreate,
|
||||||
AssertObject: rbac.ResourceTemplate.InOrg(a.Organization.ID),
|
AssertObject: rbac.ResourceTemplate.InOrg(a.Organization.ID),
|
||||||
|
@ -240,7 +244,18 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||||
"GET:/api/v2/organizations/{organization}/templateversions/{templateversionname}": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
"GET:/api/v2/organizations/{organization}/templateversions/{templateversionname}": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||||
|
|
||||||
// Endpoints that use the SQLQuery filter.
|
// Endpoints that use the SQLQuery filter.
|
||||||
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
|
"GET:/api/v2/workspaces/": {
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
NoAuthorize: !isMemoryDB,
|
||||||
|
AssertAction: rbac.ActionRead,
|
||||||
|
AssertObject: rbac.ResourceWorkspace,
|
||||||
|
},
|
||||||
|
"GET:/api/v2/organizations/{organization}/templates": {
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
NoAuthorize: !isMemoryDB,
|
||||||
|
AssertAction: rbac.ActionRead,
|
||||||
|
AssertObject: rbac.ResourceTemplate,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Routes like proxy routes support all HTTP methods. A helper func to expand
|
// Routes like proxy routes support all HTTP methods. A helper func to expand
|
||||||
|
@ -549,10 +564,10 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
|
||||||
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object)
|
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile returns a compiled version of the authorizer that will work for
|
// CompileToSQL returns a compiled version of the authorizer that will work for
|
||||||
// in memory databases. This fake version will not work against a SQL database.
|
// in memory databases. This fake version will not work against a SQL database.
|
||||||
func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
|
func (fakePreparedAuthorizer) CompileToSQL(_ regosql.ConvertConfig) (string, error) {
|
||||||
return f, nil
|
return "", xerrors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
|
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
|
||||||
|
@ -565,10 +580,3 @@ func (f fakePreparedAuthorizer) RegoString() string {
|
||||||
}
|
}
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string {
|
|
||||||
if f.HardCodedSQLString != "" {
|
|
||||||
return f.HardCodedSQLString
|
|
||||||
}
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
|
@ -20,6 +20,13 @@ import (
|
||||||
"github.com/coder/coder/coderd/util/slice"
|
"github.com/coder/coder/coderd/util/slice"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// FakeDatabase is helpful for knowing if the underlying db is an in memory fake
|
||||||
|
// database. This is only in the databasefake package, so will only be used
|
||||||
|
// by unit tests.
|
||||||
|
type FakeDatabase interface {
|
||||||
|
IsFakeDB()
|
||||||
|
}
|
||||||
|
|
||||||
var errDuplicateKey = &pq.Error{
|
var errDuplicateKey = &pq.Error{
|
||||||
Code: "23505",
|
Code: "23505",
|
||||||
Message: "duplicate key value violates unique constraint",
|
Message: "duplicate key value violates unique constraint",
|
||||||
|
@ -117,6 +124,7 @@ type data struct {
|
||||||
lastLicenseID int32
|
lastLicenseID int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fakeQuerier) IsFakeDB() {}
|
||||||
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
|
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
@ -488,11 +496,20 @@ func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.Get
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||||
q.mutex.RLock()
|
q.mutex.RLock()
|
||||||
defer q.mutex.RUnlock()
|
defer q.mutex.RUnlock()
|
||||||
|
|
||||||
users := append([]database.User{}, q.users...)
|
users := make([]database.User, 0, len(q.users))
|
||||||
|
|
||||||
|
for _, user := range q.users {
|
||||||
|
// If the filter exists, ensure the object is authorized.
|
||||||
|
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
|
||||||
if params.Deleted {
|
if params.Deleted {
|
||||||
tmp := make([]database.User, 0, len(users))
|
tmp := make([]database.User, 0, len(users))
|
||||||
|
@ -539,13 +556,6 @@ func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.
|
||||||
users = usersFilteredByRole
|
users = usersFilteredByRole
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range q.workspaces {
|
|
||||||
// If the filter exists, ensure the object is authorized.
|
|
||||||
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return int64(len(users)), nil
|
return int64(len(users)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -750,7 +760,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocyclo
|
//nolint:gocyclo
|
||||||
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.GetWorkspacesRow, error) {
|
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||||
q.mutex.RLock()
|
q.mutex.RLock()
|
||||||
defer q.mutex.RUnlock()
|
defer q.mutex.RUnlock()
|
||||||
|
|
||||||
|
@ -923,7 +933,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the filter exists, ensure the object is authorized.
|
// If the filter exists, ensure the object is authorized.
|
||||||
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
|
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
workspaces = append(workspaces, workspace)
|
workspaces = append(workspaces, workspace)
|
||||||
|
@ -1505,12 +1515,20 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd
|
||||||
return database.Template{}, sql.ErrNoRows
|
return database.Template{}, sql.ErrNoRows
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
|
func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
|
||||||
|
return q.GetAuthorizedTemplates(ctx, arg, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||||
q.mutex.RLock()
|
q.mutex.RLock()
|
||||||
defer q.mutex.RUnlock()
|
defer q.mutex.RUnlock()
|
||||||
|
|
||||||
var templates []database.Template
|
var templates []database.Template
|
||||||
for _, template := range q.templates {
|
for _, template := range q.templates {
|
||||||
|
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if template.Deleted != arg.Deleted {
|
if template.Deleted != arg.Deleted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,6 +74,7 @@ func TestExactMethods(t *testing.T) {
|
||||||
extraFakeMethods := map[string]string{
|
extraFakeMethods := map[string]string{
|
||||||
// Example
|
// Example
|
||||||
// "SortFakeLists": "Helper function used",
|
// "SortFakeLists": "Helper function used",
|
||||||
|
"IsFakeDB": "Helper function used for unit testing",
|
||||||
}
|
}
|
||||||
|
|
||||||
fake := reflect.TypeOf(databasefake.New())
|
fake := reflect.TypeOf(databasefake.New())
|
||||||
|
|
|
@ -5,12 +5,16 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/coder/coder/coderd/rbac"
|
"github.com/coder/coder/coderd/rbac"
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql"
|
||||||
|
)
|
||||||
|
|
||||||
"github.com/google/uuid"
|
const (
|
||||||
"golang.org/x/xerrors"
|
authorizedQueryPlaceholder = "-- @authorize_filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
// customQuerier encompasses all non-generated queries.
|
// customQuerier encompasses all non-generated queries.
|
||||||
|
@ -23,10 +27,70 @@ type customQuerier interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type templateQuerier interface {
|
type templateQuerier interface {
|
||||||
|
GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error)
|
||||||
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
|
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
|
||||||
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
|
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) {
|
||||||
|
authorizedFilter, err := prepared.CompileToSQL(regosql.ConvertConfig{
|
||||||
|
VariableConverter: regosql.TemplateConverter(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered, err := insertAuthorizedFilter(getTemplatesWithFilter, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The name comment is for metric tracking
|
||||||
|
query := fmt.Sprintf("-- name: GetAuthorizedTemplates :many\n%s", filtered)
|
||||||
|
rows, err := q.db.QueryContext(ctx, query,
|
||||||
|
arg.Deleted,
|
||||||
|
arg.OrganizationID,
|
||||||
|
arg.ExactName,
|
||||||
|
pq.Array(arg.IDs),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []Template
|
||||||
|
for rows.Next() {
|
||||||
|
var i Template
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
&i.OrganizationID,
|
||||||
|
&i.Deleted,
|
||||||
|
&i.Name,
|
||||||
|
&i.Provisioner,
|
||||||
|
&i.ActiveVersionID,
|
||||||
|
&i.Description,
|
||||||
|
&i.DefaultTTL,
|
||||||
|
&i.CreatedBy,
|
||||||
|
&i.Icon,
|
||||||
|
&i.UserACL,
|
||||||
|
&i.GroupACL,
|
||||||
|
&i.DisplayName,
|
||||||
|
&i.AllowUserCancelWorkspaceJobs,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
type TemplateUser struct {
|
type TemplateUser struct {
|
||||||
User
|
User
|
||||||
Actions Actions `db:"actions"`
|
Actions Actions `db:"actions"`
|
||||||
|
@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
|
||||||
}
|
}
|
||||||
|
|
||||||
type workspaceQuerier interface {
|
type workspaceQuerier interface {
|
||||||
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error)
|
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
|
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
|
||||||
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
|
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
|
||||||
// clause.
|
// clause.
|
||||||
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) {
|
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) {
|
||||||
|
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
|
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
|
||||||
// authorizedFilter between the end of the where clause and those statements.
|
// authorizedFilter between the end of the where clause and those statements.
|
||||||
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
filtered, err := insertAuthorizedFilter(getWorkspaces, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// The name comment is for metric tracking
|
// The name comment is for metric tracking
|
||||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter)
|
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered)
|
||||||
rows, err := q.db.QueryContext(ctx, query,
|
rows, err := q.db.QueryContext(ctx, query,
|
||||||
arg.Deleted,
|
arg.Deleted,
|
||||||
arg.Status,
|
arg.Status,
|
||||||
|
@ -172,12 +245,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuerier interface {
|
type userQuerier interface {
|
||||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
|
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||||
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
|
||||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
|
if err != nil {
|
||||||
|
return -1, xerrors.Errorf("compile authorized filter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||||
|
if err != nil {
|
||||||
|
return -1, xerrors.Errorf("insert authorized filter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
|
||||||
row := q.db.QueryRowContext(ctx, query,
|
row := q.db.QueryRowContext(ctx, query,
|
||||||
arg.Deleted,
|
arg.Deleted,
|
||||||
arg.Search,
|
arg.Search,
|
||||||
|
@ -185,6 +267,14 @@ func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFiltered
|
||||||
pq.Array(arg.RbacRole),
|
pq.Array(arg.RbacRole),
|
||||||
)
|
)
|
||||||
var count int64
|
var count int64
|
||||||
err := row.Scan(&count)
|
err = row.Scan(&count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||||
|
if !strings.Contains(query, authorizedQueryPlaceholder) {
|
||||||
|
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
|
||||||
|
}
|
||||||
|
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
|
||||||
|
return filtered, nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsAuthorizedQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
query := `SELECT true;`
|
||||||
|
_, err := insertAuthorizedFilter(query, "")
|
||||||
|
require.ErrorContains(t, err, "does not contain authorized replace string", "ensure replace string")
|
||||||
|
}
|
|
@ -3197,6 +3197,8 @@ WHERE
|
||||||
id = ANY($4)
|
id = ANY($4)
|
||||||
ELSE true
|
ELSE true
|
||||||
END
|
END
|
||||||
|
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
|
||||||
|
-- @authorize_filter
|
||||||
ORDER BY (name, id) ASC
|
ORDER BY (name, id) ASC
|
||||||
`
|
`
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,8 @@ WHERE
|
||||||
id = ANY(@ids)
|
id = ANY(@ids)
|
||||||
ELSE true
|
ELSE true
|
||||||
END
|
END
|
||||||
|
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
|
||||||
|
-- @authorize_filter
|
||||||
ORDER BY (name, id) ASC
|
ORDER BY (name, id) ASC
|
||||||
;
|
;
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql"
|
||||||
"github.com/coder/coder/coderd/tracing"
|
"github.com/coder/coder/coderd/tracing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -20,7 +21,7 @@ type Authorizer interface {
|
||||||
|
|
||||||
type PreparedAuthorized interface {
|
type PreparedAuthorized interface {
|
||||||
Authorize(ctx context.Context, object Object) error
|
Authorize(ctx context.Context, object Object) error
|
||||||
Compile() (AuthorizeFilter, error)
|
CompileToSQL(cfg regosql.ConvertConfig) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter takes in a list of objects, and will filter the list removing all
|
// Filter takes in a list of objects, and will filter the list removing all
|
||||||
|
|
|
@ -847,7 +847,7 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
|
||||||
|
|
||||||
// Ensure the partial can compile to a SQL clause.
|
// Ensure the partial can compile to a SQL clause.
|
||||||
// This does not guarantee that the clause is valid SQL.
|
// This does not guarantee that the clause is valid SQL.
|
||||||
_, err = Compile(partialAuthz)
|
_, err = Compile(ConfigWithACL(), partialAuthz)
|
||||||
require.NoError(t, err, "compile prepared authorizer")
|
require.NoError(t, err, "compile prepared authorizer")
|
||||||
|
|
||||||
// Also check the rego policy can form a valid partial query result.
|
// Also check the rego policy can form a valid partial query result.
|
||||||
|
|
|
@ -3,11 +3,11 @@ package rbac
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"golang.org/x/xerrors"
|
|
||||||
|
|
||||||
"github.com/open-policy-agent/opa/ast"
|
"github.com/open-policy-agent/opa/ast"
|
||||||
"github.com/open-policy-agent/opa/rego"
|
"github.com/open-policy-agent/opa/rego"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql"
|
||||||
"github.com/coder/coder/coderd/tracing"
|
"github.com/coder/coder/coderd/tracing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,12 +28,12 @@ type PartialAuthorizer struct {
|
||||||
|
|
||||||
var _ PreparedAuthorized = (*PartialAuthorizer)(nil)
|
var _ PreparedAuthorized = (*PartialAuthorizer)(nil)
|
||||||
|
|
||||||
func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) {
|
func (pa *PartialAuthorizer) CompileToSQL(cfg regosql.ConvertConfig) (string, error) {
|
||||||
filter, err := Compile(pa)
|
filter, err := Compile(cfg, pa)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("compile: %w", err)
|
return "", xerrors.Errorf("compile: %w", err)
|
||||||
}
|
}
|
||||||
return filter, nil
|
return filter.SQLString(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error {
|
func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error {
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
package partial
|
|
@ -2,635 +2,63 @@ package rbac
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/open-policy-agent/opa/ast"
|
"github.com/coder/coder/coderd/rbac/regosql"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||||
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TermType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
VarTypeJsonbTextArray TermType = "jsonb-text-array"
|
|
||||||
VarTypeText TermType = "text"
|
|
||||||
VarTypeBoolean TermType = "boolean"
|
|
||||||
// VarTypeSkip means this variable does not exist to use.
|
|
||||||
VarTypeSkip TermType = "skip"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SQLColumn struct {
|
|
||||||
// RegoMatch matches the original variable string.
|
|
||||||
// If it is a match, then this variable config will apply.
|
|
||||||
RegoMatch *regexp.Regexp
|
|
||||||
// ColumnSelect is the name of the postgres column to select.
|
|
||||||
// Can use capture groups from RegoMatch with $1, $2, etc.
|
|
||||||
ColumnSelect string
|
|
||||||
|
|
||||||
// Type indicates the postgres type of the column. Some expressions will
|
|
||||||
// need to know this in order to determine what SQL to produce.
|
|
||||||
// An example is if the variable is a jsonb array, the "contains" SQL
|
|
||||||
// query is `variable ? 'value'` instead of `'value' = ANY(variable)`.
|
|
||||||
// This type is only needed to be provided
|
|
||||||
Type TermType
|
|
||||||
}
|
|
||||||
|
|
||||||
type SQLConfig struct {
|
|
||||||
// Variables is a map of rego variable names to SQL columns.
|
|
||||||
// Example:
|
|
||||||
// "input\.object\.org_owner": SQLColumn{
|
|
||||||
// ColumnSelect: "organization_id",
|
|
||||||
// Type: VarTypeUUID
|
|
||||||
// }
|
|
||||||
// "input\.object\.owner": SQLColumn{
|
|
||||||
// ColumnSelect: "owner_id",
|
|
||||||
// Type: VarTypeUUID
|
|
||||||
// }
|
|
||||||
// "input\.object\.group_acl\.(.*)": SQLColumn{
|
|
||||||
// ColumnSelect: "group_acl->$1",
|
|
||||||
// Type: VarTypeJsonbTextArray
|
|
||||||
// }
|
|
||||||
Variables []SQLColumn
|
|
||||||
}
|
|
||||||
|
|
||||||
func DefaultConfig() SQLConfig {
|
|
||||||
return SQLConfig{
|
|
||||||
Variables: []SQLColumn{
|
|
||||||
{
|
|
||||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
|
|
||||||
ColumnSelect: "group_acl->$1",
|
|
||||||
Type: VarTypeJsonbTextArray,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
|
|
||||||
ColumnSelect: "user_acl->$1",
|
|
||||||
Type: VarTypeJsonbTextArray,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
|
|
||||||
ColumnSelect: "organization_id :: text",
|
|
||||||
Type: VarTypeText,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
|
|
||||||
ColumnSelect: "owner_id :: text",
|
|
||||||
Type: VarTypeText,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NoACLConfig() SQLConfig {
|
|
||||||
return SQLConfig{
|
|
||||||
Variables: []SQLColumn{
|
|
||||||
{
|
|
||||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
|
|
||||||
ColumnSelect: "",
|
|
||||||
Type: VarTypeSkip,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
|
|
||||||
ColumnSelect: "",
|
|
||||||
Type: VarTypeSkip,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
|
|
||||||
ColumnSelect: "organization_id :: text",
|
|
||||||
Type: VarTypeText,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
|
|
||||||
ColumnSelect: "owner_id :: text",
|
|
||||||
Type: VarTypeText,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthorizeFilter interface {
|
type AuthorizeFilter interface {
|
||||||
Expression
|
SQLString() string
|
||||||
// Eval is required for the fake in memory database to work. The in memory
|
|
||||||
// database can use this function to filter the results.
|
|
||||||
Eval(object Object) bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// expressionTop handles Eval(object Object) for in memory expressions
|
type authorizedSQLFilter struct {
|
||||||
type expressionTop struct {
|
sqlString string
|
||||||
Expression
|
auth *PartialAuthorizer
|
||||||
Auth *PartialAuthorizer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e expressionTop) Eval(object Object) bool {
|
func ConfigWithACL() regosql.ConvertConfig {
|
||||||
return e.Auth.Authorize(context.Background(), object) == nil
|
return regosql.ConvertConfig{
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile will convert a rego query AST into our custom types. The output is
|
func ConfigWithoutACL() regosql.ConvertConfig {
|
||||||
// an AST that can be used to generate SQL.
|
return regosql.ConvertConfig{
|
||||||
func Compile(pa *PartialAuthorizer) (AuthorizeFilter, error) {
|
VariableConverter: regosql.NoACLConverter(),
|
||||||
partialQueries := pa.partialQueries
|
}
|
||||||
if len(partialQueries.Support) > 0 {
|
}
|
||||||
return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support))
|
|
||||||
|
func Compile(cfg regosql.ConvertConfig, pa *PartialAuthorizer) (AuthorizeFilter, error) {
|
||||||
|
root, err := regosql.ConvertRegoAst(cfg, pa.partialQueries)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("convert rego ast: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 0 queries means the result is "undefined". This is the same as "false".
|
// Generate the SQL
|
||||||
if len(partialQueries.Queries) == 0 {
|
gen := sqltypes.NewSQLGenerator()
|
||||||
return &termBoolean{
|
sqlString := root.SQLString(gen)
|
||||||
base: base{Rego: "false"},
|
if len(gen.Errors()) > 0 {
|
||||||
Value: false,
|
var errStrings []string
|
||||||
}, nil
|
for _, err := range gen.Errors() {
|
||||||
|
errStrings = append(errStrings, err.Error())
|
||||||
|
}
|
||||||
|
return nil, xerrors.Errorf("sql generation errors: %v", strings.Join(errStrings, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Abort early if any of the "OR"'d expressions are the empty string.
|
return &authorizedSQLFilter{
|
||||||
// This is the same as "true".
|
sqlString: sqlString,
|
||||||
for _, query := range partialQueries.Queries {
|
auth: pa,
|
||||||
if query.String() == "" {
|
|
||||||
return &termBoolean{
|
|
||||||
base: base{Rego: "true"},
|
|
||||||
Value: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]Expression, 0, len(partialQueries.Queries))
|
|
||||||
var builder strings.Builder
|
|
||||||
for i := range partialQueries.Queries {
|
|
||||||
query, err := processQuery(partialQueries.Queries[i])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
result = append(result, query)
|
|
||||||
if i != 0 {
|
|
||||||
builder.WriteString("\n")
|
|
||||||
}
|
|
||||||
builder.WriteString(partialQueries.Queries[i].String())
|
|
||||||
}
|
|
||||||
exp := expOr{
|
|
||||||
base: base{
|
|
||||||
Rego: builder.String(),
|
|
||||||
},
|
|
||||||
Expressions: result,
|
|
||||||
}
|
|
||||||
return expressionTop{
|
|
||||||
Expression: &exp,
|
|
||||||
Auth: pa,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// processQuery processes an entire set of expressions and joins them with
|
func (a *authorizedSQLFilter) Eval(object Object) bool {
|
||||||
// "AND".
|
return a.auth.Authorize(context.Background(), object) == nil
|
||||||
func processQuery(query ast.Body) (Expression, error) {
|
|
||||||
expressions := make([]Expression, 0, len(query))
|
|
||||||
for _, astExpr := range query {
|
|
||||||
expr, err := processExpression(astExpr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
expressions = append(expressions, expr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return expAnd{
|
|
||||||
base: base{
|
|
||||||
Rego: query.String(),
|
|
||||||
},
|
|
||||||
Expressions: expressions,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func processExpression(expr *ast.Expr) (Expression, error) {
|
func (a *authorizedSQLFilter) SQLString() string {
|
||||||
if !expr.IsCall() {
|
return a.sqlString
|
||||||
// This could be a single term that is a valid expression.
|
|
||||||
if term, ok := expr.Terms.(*ast.Term); ok {
|
|
||||||
value, err := processTerm(term)
|
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("single term expression: %w", err)
|
|
||||||
}
|
|
||||||
if boolExp, ok := value.(Expression); ok {
|
|
||||||
return boolExp, nil
|
|
||||||
}
|
|
||||||
// Default to error.
|
|
||||||
}
|
|
||||||
return nil, xerrors.Errorf("invalid expression: single non-boolean terms not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
op := expr.Operator().String()
|
|
||||||
base := base{Rego: op}
|
|
||||||
switch op {
|
|
||||||
case "neq", "eq", "equal":
|
|
||||||
terms, err := processTerms(2, expr.Operands())
|
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
|
|
||||||
}
|
|
||||||
return &opEqual{
|
|
||||||
base: base,
|
|
||||||
Terms: [2]Term{terms[0], terms[1]},
|
|
||||||
Not: op == "neq",
|
|
||||||
}, nil
|
|
||||||
case "internal.member_2":
|
|
||||||
terms, err := processTerms(2, expr.Operands())
|
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("invalid '%s' expression: %w", op, err)
|
|
||||||
}
|
|
||||||
return &opInternalMember2{
|
|
||||||
base: base,
|
|
||||||
Needle: terms[0],
|
|
||||||
Haystack: terms[1],
|
|
||||||
}, nil
|
|
||||||
default:
|
|
||||||
return nil, xerrors.Errorf("invalid expression: operator %s not supported", op)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func processTerms(expected int, terms []*ast.Term) ([]Term, error) {
|
|
||||||
if len(terms) != expected {
|
|
||||||
return nil, xerrors.Errorf("too many arguments, expect %d found %d", expected, len(terms))
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]Term, 0, len(terms))
|
|
||||||
for _, term := range terms {
|
|
||||||
processed, err := processTerm(term)
|
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("invalid term: %w", err)
|
|
||||||
}
|
|
||||||
result = append(result, processed)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func processTerm(term *ast.Term) (Term, error) {
|
|
||||||
termBase := base{Rego: term.String()}
|
|
||||||
switch v := term.Value.(type) {
|
|
||||||
case ast.Boolean:
|
|
||||||
return &termBoolean{
|
|
||||||
base: termBase,
|
|
||||||
Value: bool(v),
|
|
||||||
}, nil
|
|
||||||
case ast.Ref:
|
|
||||||
obj := &termObject{
|
|
||||||
base: termBase,
|
|
||||||
Path: []Term{},
|
|
||||||
}
|
|
||||||
var idx int
|
|
||||||
// A ref is a set of terms. If the first term is a var, then the
|
|
||||||
// following terms are the path to the value.
|
|
||||||
isRef := true
|
|
||||||
var builder strings.Builder
|
|
||||||
for _, term := range v {
|
|
||||||
if idx == 0 {
|
|
||||||
if _, ok := v[0].Value.(ast.Var); !ok {
|
|
||||||
return nil, xerrors.Errorf("invalid term (%s): ref must start with a var, started with %T", v[0].String(), v[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, newRef := term.Value.(ast.Ref)
|
|
||||||
if newRef ||
|
|
||||||
// This is an unfortunate hack. To fix this, we need to rewrite
|
|
||||||
// our SQL config as a path ([]string{"input", "object", "acl_group"}).
|
|
||||||
// In the rego AST, there is no difference between selecting
|
|
||||||
// a field by a variable, and selecting a field by a literal (string).
|
|
||||||
// This was a misunderstanding.
|
|
||||||
// Example (these are equivalent by AST):
|
|
||||||
// input.object.acl_group_list['4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75']
|
|
||||||
// input.object.acl_group_list.organization_id
|
|
||||||
//
|
|
||||||
// This is not equivalent
|
|
||||||
// input.object.acl_group_list[input.object.organization_id]
|
|
||||||
//
|
|
||||||
// If this becomes even more hairy, we should fix the sql config.
|
|
||||||
builder.String() == "input.object.acl_group_list" ||
|
|
||||||
builder.String() == "input.object.acl_user_list" {
|
|
||||||
if !newRef {
|
|
||||||
isRef = false
|
|
||||||
}
|
|
||||||
// New obj
|
|
||||||
obj.Path = append(obj.Path, termVariable{
|
|
||||||
base: base{
|
|
||||||
Rego: builder.String(),
|
|
||||||
},
|
|
||||||
Name: builder.String(),
|
|
||||||
})
|
|
||||||
builder.Reset()
|
|
||||||
idx = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if builder.Len() != 0 {
|
|
||||||
builder.WriteString(".")
|
|
||||||
}
|
|
||||||
builder.WriteString(trimQuotes(term.String()))
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
|
|
||||||
if isRef {
|
|
||||||
obj.Path = append(obj.Path, termVariable{
|
|
||||||
base: base{
|
|
||||||
Rego: builder.String(),
|
|
||||||
},
|
|
||||||
Name: builder.String(),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
obj.Path = append(obj.Path, termString{
|
|
||||||
base: base{
|
|
||||||
Rego: fmt.Sprintf("%q", builder.String()),
|
|
||||||
},
|
|
||||||
Value: builder.String(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return obj, nil
|
|
||||||
case ast.Var:
|
|
||||||
return &termVariable{
|
|
||||||
Name: trimQuotes(v.String()),
|
|
||||||
base: termBase,
|
|
||||||
}, nil
|
|
||||||
case ast.String:
|
|
||||||
return &termString{
|
|
||||||
Value: trimQuotes(v.String()),
|
|
||||||
base: termBase,
|
|
||||||
}, nil
|
|
||||||
case ast.Set:
|
|
||||||
slice := v.Slice()
|
|
||||||
set := make([]Term, 0, len(slice))
|
|
||||||
for _, elem := range slice {
|
|
||||||
processed, err := processTerm(elem)
|
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("invalid set term %s: %w", elem.String(), err)
|
|
||||||
}
|
|
||||||
set = append(set, processed)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &termSet{
|
|
||||||
Value: set,
|
|
||||||
base: termBase,
|
|
||||||
}, nil
|
|
||||||
default:
|
|
||||||
return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type base struct {
|
|
||||||
// Rego is the original rego string
|
|
||||||
Rego string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b base) RegoString() string {
|
|
||||||
return b.Rego
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expression comprises a set of terms, operators, and functions. All
|
|
||||||
// expressions return a boolean value.
|
|
||||||
//
|
|
||||||
// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo"
|
|
||||||
type Expression interface {
|
|
||||||
// RegoString is used in debugging to see the original rego expression.
|
|
||||||
RegoString() string
|
|
||||||
// SQLString returns the SQL expression that can be used in a WHERE clause.
|
|
||||||
SQLString(cfg SQLConfig) string
|
|
||||||
}
|
|
||||||
|
|
||||||
type expAnd struct {
|
|
||||||
base
|
|
||||||
Expressions []Expression
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t expAnd) SQLString(cfg SQLConfig) string {
|
|
||||||
if len(t.Expressions) == 1 {
|
|
||||||
return t.Expressions[0].SQLString(cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs := make([]string, 0, len(t.Expressions))
|
|
||||||
for _, expr := range t.Expressions {
|
|
||||||
exprs = append(exprs, expr.SQLString(cfg))
|
|
||||||
}
|
|
||||||
return "(" + strings.Join(exprs, " AND ") + ")"
|
|
||||||
}
|
|
||||||
|
|
||||||
type expOr struct {
|
|
||||||
base
|
|
||||||
Expressions []Expression
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t expOr) SQLString(cfg SQLConfig) string {
|
|
||||||
if len(t.Expressions) == 1 {
|
|
||||||
return t.Expressions[0].SQLString(cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs := make([]string, 0, len(t.Expressions))
|
|
||||||
for _, expr := range t.Expressions {
|
|
||||||
exprs = append(exprs, expr.SQLString(cfg))
|
|
||||||
}
|
|
||||||
return "(" + strings.Join(exprs, " OR ") + ")"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Operator joins terms together to form an expression.
|
|
||||||
// Operators are also expressions.
|
|
||||||
//
|
|
||||||
// Eg: "=", "neq", "internal.member_2", etc.
|
|
||||||
type Operator interface {
|
|
||||||
Expression
|
|
||||||
}
|
|
||||||
|
|
||||||
type opEqual struct {
|
|
||||||
base
|
|
||||||
Terms [2]Term
|
|
||||||
// For NotEqual
|
|
||||||
Not bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t opEqual) SQLString(cfg SQLConfig) string {
|
|
||||||
op := "="
|
|
||||||
if t.Not {
|
|
||||||
op = "!="
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg))
|
|
||||||
}
|
|
||||||
|
|
||||||
// opInternalMember2 is checking if the first term is a member of the second term.
|
|
||||||
// The second term is a set or list.
|
|
||||||
type opInternalMember2 struct {
|
|
||||||
base
|
|
||||||
Needle Term
|
|
||||||
Haystack Term
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t opInternalMember2) SQLString(cfg SQLConfig) string {
|
|
||||||
if haystack, ok := t.Haystack.(*termObject); ok {
|
|
||||||
// This is a special case where the haystack is a jsonb array.
|
|
||||||
// There is a more general way to solve this, but that requires a lot
|
|
||||||
// more code to cover a lot more cases that we do not care about.
|
|
||||||
// To handle this more generally we should implement "Array" as a type.
|
|
||||||
// Then have the `contains` function on the Array type. This would defer
|
|
||||||
// knowing the element type to the Array and cover more cases without
|
|
||||||
// having to add more "if" branches here.
|
|
||||||
// But until we need more cases, our basic type system is ok, and
|
|
||||||
// this is the only case we need to handle.
|
|
||||||
sqlType := haystack.SQLType(cfg)
|
|
||||||
if sqlType == VarTypeJsonbTextArray {
|
|
||||||
return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg))
|
|
||||||
}
|
|
||||||
|
|
||||||
if sqlType == VarTypeSkip {
|
|
||||||
return "false"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Term is a single value in an expression. Terms can be variables or constants.
|
|
||||||
//
|
|
||||||
// Eg: "f9d6fb75-b59b-4363-ab6b-ae9d26b679d7", "input.object.org_owner",
|
|
||||||
// "{"f9d6fb75-b59b-4363-ab6b-ae9d26b679d7"}"
|
|
||||||
type Term interface {
|
|
||||||
RegoString() string
|
|
||||||
SQLString(cfg SQLConfig) string
|
|
||||||
SQLType(cfg SQLConfig) TermType
|
|
||||||
}
|
|
||||||
|
|
||||||
type termString struct {
|
|
||||||
base
|
|
||||||
Value string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termString) SQLString(_ SQLConfig) string {
|
|
||||||
return "'" + t.Value + "'"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (termString) SQLType(_ SQLConfig) TermType {
|
|
||||||
return VarTypeText
|
|
||||||
}
|
|
||||||
|
|
||||||
// termObject is a variable that can be dereferenced. We count some rego objects
|
|
||||||
// as single variables, eg: input.object.org_owner. In reality, it is a nested
|
|
||||||
// object.
|
|
||||||
// In rego, we can dereference the object with the "." operator, which we can
|
|
||||||
// handle with regex.
|
|
||||||
// Or we can dereference the object with the "[]", which we can handle with this
|
|
||||||
// term type.
|
|
||||||
type termObject struct {
|
|
||||||
base
|
|
||||||
Path []Term
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termObject) SQLType(cfg SQLConfig) TermType {
|
|
||||||
// Without a full type system, let's just assume the type of the first var
|
|
||||||
// is the resulting type. This is correct for our use case.
|
|
||||||
// Solving this more generally requires a full type system, which is
|
|
||||||
// excessive for our mostly static policy.
|
|
||||||
return t.Path[0].SQLType(cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termObject) SQLString(cfg SQLConfig) string {
|
|
||||||
if len(t.Path) == 1 {
|
|
||||||
return t.Path[0].SQLString(cfg)
|
|
||||||
}
|
|
||||||
// Combine the last 2 variables into 1 variable.
|
|
||||||
end := t.Path[len(t.Path)-1]
|
|
||||||
before := t.Path[len(t.Path)-2]
|
|
||||||
|
|
||||||
// Recursively solve the SQLString by removing the last nested reference.
|
|
||||||
// This continues until we have a single variable.
|
|
||||||
return termObject{
|
|
||||||
base: t.base,
|
|
||||||
Path: append(
|
|
||||||
t.Path[:len(t.Path)-2],
|
|
||||||
termVariable{
|
|
||||||
base: base{
|
|
||||||
Rego: before.RegoString() + "[" + end.RegoString() + "]",
|
|
||||||
},
|
|
||||||
// Convert the end to SQL string. We evaluate each term
|
|
||||||
// one at a time.
|
|
||||||
Name: before.RegoString() + "." + end.SQLString(cfg),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}.SQLString(cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
type termVariable struct {
|
|
||||||
base
|
|
||||||
Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termVariable) SQLType(cfg SQLConfig) TermType {
|
|
||||||
if col := t.ColumnConfig(cfg); col != nil {
|
|
||||||
return col.Type
|
|
||||||
}
|
|
||||||
return VarTypeText
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termVariable) SQLString(cfg SQLConfig) string {
|
|
||||||
if col := t.ColumnConfig(cfg); col != nil {
|
|
||||||
matches := col.RegoMatch.FindStringSubmatch(t.Name)
|
|
||||||
if len(matches) > 0 {
|
|
||||||
// This config matches this variable.
|
|
||||||
replace := make([]string, 0, len(matches)*2)
|
|
||||||
for i, m := range matches {
|
|
||||||
replace = append(replace, fmt.Sprintf("$%d", i))
|
|
||||||
replace = append(replace, m)
|
|
||||||
}
|
|
||||||
replacer := strings.NewReplacer(replace...)
|
|
||||||
return replacer.Replace(col.ColumnSelect)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColumnConfig returns the correct SQLColumn settings for the
|
|
||||||
// term. If there is no configured column, it will return nil.
|
|
||||||
func (t termVariable) ColumnConfig(cfg SQLConfig) *SQLColumn {
|
|
||||||
for _, col := range cfg.Variables {
|
|
||||||
matches := col.RegoMatch.MatchString(t.Name)
|
|
||||||
if matches {
|
|
||||||
return &col
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// termSet is a set of unique terms.
|
|
||||||
type termSet struct {
|
|
||||||
base
|
|
||||||
Value []Term
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termSet) SQLType(cfg SQLConfig) TermType {
|
|
||||||
if len(t.Value) == 0 {
|
|
||||||
return VarTypeText
|
|
||||||
}
|
|
||||||
// Without a full type system, let's just assume the type of the first var
|
|
||||||
// is the resulting type. This is correct for our use case.
|
|
||||||
// Solving this more generally requires a full type system, which is
|
|
||||||
// excessive for our mostly static policy.
|
|
||||||
return t.Value[0].SQLType(cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termSet) SQLString(cfg SQLConfig) string {
|
|
||||||
elems := make([]string, 0, len(t.Value))
|
|
||||||
for _, v := range t.Value {
|
|
||||||
elems = append(elems, v.SQLString(cfg))
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("ARRAY [%s]", strings.Join(elems, ","))
|
|
||||||
}
|
|
||||||
|
|
||||||
type termBoolean struct {
|
|
||||||
base
|
|
||||||
Value bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (termBoolean) SQLType(SQLConfig) TermType {
|
|
||||||
return VarTypeBoolean
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termBoolean) Eval(_ Object) bool {
|
|
||||||
return t.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t termBoolean) SQLString(_ SQLConfig) string {
|
|
||||||
return strconv.FormatBool(t.Value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func trimQuotes(s string) string {
|
|
||||||
return strings.Trim(s, "\"")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,150 +0,0 @@
|
||||||
package rbac
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/open-policy-agent/opa/ast"
|
|
||||||
"github.com/open-policy-agent/opa/rego"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCompileQuery(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("EmptyQuery", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
expression, err := Compile(partialQueries(t, ""))
|
|
||||||
require.NoError(t, err, "compile empty")
|
|
||||||
|
|
||||||
require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'")
|
|
||||||
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "empty query is sql 'true'")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TrueQuery", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
expression, err := Compile(partialQueries(t, "true"))
|
|
||||||
require.NoError(t, err, "compile")
|
|
||||||
|
|
||||||
require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'")
|
|
||||||
require.Equal(t, "true", expression.SQLString(SQLConfig{}), "true query is sql 'true'")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ACLIn", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
expression, err := Compile(partialQueries(t, `"*" in input.object.acl_group_list.allUsers`))
|
|
||||||
require.NoError(t, err, "compile")
|
|
||||||
|
|
||||||
require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member")
|
|
||||||
require.Equal(t, `group_acl->'allUsers' ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Complex", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
expression, err := Compile(partialQueries(t,
|
|
||||||
`input.object.org_owner != ""`,
|
|
||||||
`input.object.org_owner in {"a", "b", "c"}`,
|
|
||||||
`input.object.org_owner != ""`,
|
|
||||||
`"read" in input.object.acl_group_list.allUsers`,
|
|
||||||
`"read" in input.object.acl_user_list.me`,
|
|
||||||
))
|
|
||||||
require.NoError(t, err, "compile")
|
|
||||||
require.Equal(t, `(organization_id :: text != '' OR `+
|
|
||||||
`organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+
|
|
||||||
`organization_id :: text != '' OR `+
|
|
||||||
`group_acl->'allUsers' ? 'read' OR `+
|
|
||||||
`user_acl->'me' ? 'read')`,
|
|
||||||
expression.SQLString(DefaultConfig()), "complex")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("SetDereference", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
expression, err := Compile(partialQueries(t,
|
|
||||||
`"*" in input.object.acl_group_list[input.object.org_owner]`,
|
|
||||||
))
|
|
||||||
require.NoError(t, err, "compile")
|
|
||||||
require.Equal(t, `group_acl->organization_id :: text ? '*'`,
|
|
||||||
expression.SQLString(DefaultConfig()), "set dereference")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("JsonbLiteralDereference", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
expression, err := Compile(partialQueries(t,
|
|
||||||
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
|
|
||||||
))
|
|
||||||
require.NoError(t, err, "compile")
|
|
||||||
require.Equal(t, `group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'`,
|
|
||||||
expression.SQLString(DefaultConfig()), "literal dereference")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("NoACLColumns", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
expression, err := Compile(partialQueries(t,
|
|
||||||
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
|
|
||||||
))
|
|
||||||
require.NoError(t, err, "compile")
|
|
||||||
require.Equal(t, `false`,
|
|
||||||
expression.SQLString(NoACLConfig()), "literal dereference")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEvalQuery(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("GroupACL", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
expression, err := Compile(partialQueries(t,
|
|
||||||
`"read" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
|
|
||||||
))
|
|
||||||
require.NoError(t, err, "compile")
|
|
||||||
|
|
||||||
result := expression.Eval(Object{
|
|
||||||
Owner: "not-me",
|
|
||||||
OrgID: "random",
|
|
||||||
Type: "workspace",
|
|
||||||
ACLUserList: map[string][]Action{},
|
|
||||||
ACLGroupList: map[string][]Action{
|
|
||||||
"4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75": {"read"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.True(t, result, "eval")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func partialQueries(t *testing.T, queries ...string) *PartialAuthorizer {
|
|
||||||
opts := ast.ParserOptions{
|
|
||||||
AllFutureKeywords: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
astQueries := make([]ast.Body, 0, len(queries))
|
|
||||||
for _, q := range queries {
|
|
||||||
astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts))
|
|
||||||
}
|
|
||||||
|
|
||||||
prepareQueries := make([]rego.PreparedEvalQuery, 0, len(queries))
|
|
||||||
for _, q := range astQueries {
|
|
||||||
var prepped rego.PreparedEvalQuery
|
|
||||||
var err error
|
|
||||||
if q.String() == "" {
|
|
||||||
prepped, err = rego.New(
|
|
||||||
rego.Query("true"),
|
|
||||||
).PrepareForEval(context.Background())
|
|
||||||
} else {
|
|
||||||
prepped, err = rego.New(
|
|
||||||
rego.ParsedQuery(q),
|
|
||||||
).PrepareForEval(context.Background())
|
|
||||||
}
|
|
||||||
require.NoError(t, err, "prepare query")
|
|
||||||
prepareQueries = append(prepareQueries, prepped)
|
|
||||||
}
|
|
||||||
return &PartialAuthorizer{
|
|
||||||
partialQueries: ®o.PartialQueries{
|
|
||||||
Queries: astQueries,
|
|
||||||
Support: []*ast.Module{},
|
|
||||||
},
|
|
||||||
preparedQueries: prepareQueries,
|
|
||||||
input: nil,
|
|
||||||
alwaysTrue: false,
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
package regosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/open-policy-agent/opa/ast"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ sqltypes.VariableMatcher = ACLGroupVar{}
|
||||||
|
var _ sqltypes.Node = ACLGroupVar{}
|
||||||
|
|
||||||
|
// ACLGroupVar is a variable matcher that handles group_acl and user_acl.
|
||||||
|
// The sql type is a jsonb object with the following structure:
|
||||||
|
//
|
||||||
|
// "group_acl": {
|
||||||
|
// "<group_name>": ["<actions>"]
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// This is a custom variable matcher as json objects have arbitrary complexity.
|
||||||
|
type ACLGroupVar struct {
|
||||||
|
StructSQL string
|
||||||
|
// input.object.group_acl -> ["input", "object", "group_acl"]
|
||||||
|
StructPath []string
|
||||||
|
|
||||||
|
// FieldReference handles referencing the subfields, which could be
|
||||||
|
// more variables. We pass one in as the global one might not be correctly
|
||||||
|
// scoped.
|
||||||
|
FieldReference sqltypes.VariableMatcher
|
||||||
|
|
||||||
|
// Instance fields
|
||||||
|
Source sqltypes.RegoSource
|
||||||
|
GroupNode sqltypes.Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func ACLGroupMatcher(fieldReference sqltypes.VariableMatcher, structSQL string, structPath []string) ACLGroupVar {
|
||||||
|
return ACLGroupVar{StructSQL: structSQL, StructPath: structPath, FieldReference: fieldReference}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ACLGroupVar) UseAs() sqltypes.Node { return ACLGroupVar{} }
|
||||||
|
|
||||||
|
func (g ACLGroupVar) ConvertVariable(rego ast.Ref) (sqltypes.Node, bool) {
|
||||||
|
// "left" will be a map of group names to actions in rego.
|
||||||
|
// {
|
||||||
|
// "all_users": ["read"]
|
||||||
|
// }
|
||||||
|
left, err := sqltypes.RegoVarPath(g.StructPath, rego)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
aclGrp := ACLGroupVar{
|
||||||
|
StructSQL: g.StructSQL,
|
||||||
|
StructPath: g.StructPath,
|
||||||
|
FieldReference: g.FieldReference,
|
||||||
|
|
||||||
|
Source: sqltypes.RegoSource(rego.String()),
|
||||||
|
}
|
||||||
|
|
||||||
|
// We expect 1 more term. Either a ref or a string.
|
||||||
|
if len(left) != 1 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the remaining is a variable, then we need to convert it.
|
||||||
|
// Assuming we support variable fields.
|
||||||
|
ref, ok := left[0].Value.(ast.Ref)
|
||||||
|
if ok && g.FieldReference != nil {
|
||||||
|
groupNode, ok := g.FieldReference.ConvertVariable(ref)
|
||||||
|
if ok {
|
||||||
|
aclGrp.GroupNode = groupNode
|
||||||
|
return aclGrp, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it is a string, we assume it is a literal
|
||||||
|
groupName, ok := left[0].Value.(ast.String)
|
||||||
|
if ok {
|
||||||
|
aclGrp.GroupNode = sqltypes.String(string(groupName))
|
||||||
|
return aclGrp, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have not matched it yet, then it is something we do not recognize.
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g ACLGroupVar) SQLString(cfg *sqltypes.SQLGenerator) string {
|
||||||
|
return fmt.Sprintf("%s->%s", g.StructSQL, g.GroupNode.SQLString(cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g ACLGroupVar) ContainsSQL(cfg *sqltypes.SQLGenerator, other sqltypes.Node) (string, error) {
|
||||||
|
switch other.UseAs().(type) {
|
||||||
|
// Only supports containing other strings.
|
||||||
|
case sqltypes.AstString:
|
||||||
|
return fmt.Sprintf("%s ? %s", g.SQLString(cfg), other.SQLString(cfg)), nil
|
||||||
|
default:
|
||||||
|
return "", xerrors.Errorf("unsupported acl group contains %T", other)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,230 @@
|
||||||
|
package regosql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/open-policy-agent/opa/ast"
|
||||||
|
"github.com/open-policy-agent/opa/rego"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertConfig is required to generate SQL from the rego queries.
|
||||||
|
type ConvertConfig struct {
|
||||||
|
// VariableConverter is called each time a var is encountered. This creates
|
||||||
|
// the SQL ast for the variable. Without this, the SQL generator does not
|
||||||
|
// know how to convert rego variables into SQL columns.
|
||||||
|
VariableConverter sqltypes.VariableMatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertRegoAst converts partial rego queries into a single SQL where
|
||||||
|
// clause. If the query equates to "true" then the user should have access.
|
||||||
|
func ConvertRegoAst(cfg ConvertConfig, partial *rego.PartialQueries) (sqltypes.BooleanNode, error) {
|
||||||
|
if len(partial.Queries) == 0 {
|
||||||
|
// Always deny if there are no queries. This means there is no possible
|
||||||
|
// way this user can access these resources.
|
||||||
|
return sqltypes.Bool(false), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, q := range partial.Queries {
|
||||||
|
// An empty query in rego means "true". If any query in the set is
|
||||||
|
// empty, then the user should have access.
|
||||||
|
if len(q) == 0 {
|
||||||
|
// Always allow
|
||||||
|
return sqltypes.Bool(true), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var queries []sqltypes.BooleanNode
|
||||||
|
var builder strings.Builder
|
||||||
|
for i, q := range partial.Queries {
|
||||||
|
converted, err := convertQuery(cfg, q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("query %s: %w", q.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if i != 0 {
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
builder.WriteString(q.String())
|
||||||
|
queries = append(queries, converted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All queries are OR'd together. This means that if any query is true,
|
||||||
|
// then the user should have access.
|
||||||
|
sqlClause := sqltypes.Or(sqltypes.RegoSource(builder.String()), queries...)
|
||||||
|
// Always wrap in parens to ensure the correct precedence when combining with other
|
||||||
|
// SQL clauses.
|
||||||
|
return sqltypes.BoolParenthesis(sqlClause), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertQuery(cfg ConvertConfig, q ast.Body) (sqltypes.BooleanNode, error) {
|
||||||
|
var expressions []sqltypes.BooleanNode
|
||||||
|
for _, e := range q {
|
||||||
|
exp, err := convertExpression(cfg, e)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("expression %s: %w", e.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expressions = append(expressions, exp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All expressions in a single query are AND'd together. This means that
|
||||||
|
// all expressions must be true for the user to have access.
|
||||||
|
return sqltypes.And(sqltypes.RegoSource(q.String()), expressions...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertExpression(cfg ConvertConfig, e *ast.Expr) (sqltypes.BooleanNode, error) {
|
||||||
|
if e.IsCall() {
|
||||||
|
n, err := convertCall(cfg, e.Terms.([]*ast.Term))
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
boolN, ok := n.(sqltypes.BooleanNode)
|
||||||
|
if !ok {
|
||||||
|
return nil, xerrors.Errorf("call %q: not a boolean expression", e.String())
|
||||||
|
}
|
||||||
|
return boolN, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's not a call, it is a single term
|
||||||
|
if term, ok := e.Terms.(*ast.Term); ok {
|
||||||
|
ty, err := convertTerm(cfg, term)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("convert term %s: %w", term.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tyBool, ok := ty.(sqltypes.BooleanNode)
|
||||||
|
if !ok {
|
||||||
|
return nil, xerrors.Errorf("convert term %s is not a boolean: %w", term.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tyBool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, xerrors.Errorf("expression %s not supported", e.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertCall converts a function call to a SQL expression.
|
||||||
|
func convertCall(cfg ConvertConfig, call ast.Call) (sqltypes.Node, error) {
|
||||||
|
if len(call) == 0 {
|
||||||
|
return nil, xerrors.Errorf("empty call")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operator is the first term
|
||||||
|
op := call[0]
|
||||||
|
var args []*ast.Term
|
||||||
|
if len(call) > 1 {
|
||||||
|
args = call[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
opString := op.String()
|
||||||
|
// Supported operators.
|
||||||
|
switch op.String() {
|
||||||
|
case "neq", "eq", "equals", "equal":
|
||||||
|
args, err := convertTerms(cfg, args, 2)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("arguments: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
not := false
|
||||||
|
if opString == "neq" || opString == "notequals" || opString == "notequal" {
|
||||||
|
not = true
|
||||||
|
}
|
||||||
|
|
||||||
|
equality := sqltypes.Equality(not, args[0], args[1])
|
||||||
|
return sqltypes.BoolParenthesis(equality), nil
|
||||||
|
case "internal.member_2":
|
||||||
|
args, err := convertTerms(cfg, args, 2)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("arguments: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
member := sqltypes.MemberOf(args[0], args[1])
|
||||||
|
return sqltypes.BoolParenthesis(member), nil
|
||||||
|
default:
|
||||||
|
return nil, xerrors.Errorf("operator %s not supported", op)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertTerms(cfg ConvertConfig, terms []*ast.Term, expected int) ([]sqltypes.Node, error) {
|
||||||
|
if len(terms) != expected {
|
||||||
|
return nil, xerrors.Errorf("expected %d terms, got %d", expected, len(terms))
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]sqltypes.Node, 0, len(terms))
|
||||||
|
for _, t := range terms {
|
||||||
|
term, err := convertTerm(cfg, t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("term: %w", err)
|
||||||
|
}
|
||||||
|
result = append(result, term)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertTerm(cfg ConvertConfig, term *ast.Term) (sqltypes.Node, error) {
|
||||||
|
source := sqltypes.RegoSource(term.String())
|
||||||
|
switch t := term.Value.(type) {
|
||||||
|
case ast.Var:
|
||||||
|
// All vars should be contained in ast.Ref's.
|
||||||
|
return nil, xerrors.New("var not yet supported")
|
||||||
|
case ast.Ref:
|
||||||
|
if len(t) == 0 {
|
||||||
|
// A reference with no text is a variable with no name?
|
||||||
|
// This makes no sense.
|
||||||
|
return nil, xerrors.New("empty ref not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.VariableConverter == nil {
|
||||||
|
return nil, xerrors.New("no variable converter provided to handle variables")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The structure of references is as follows:
|
||||||
|
// 1. All variables start with a regoAst.Var as the first term.
|
||||||
|
// 2. The next term is either a regoAst.String or a regoAst.Var.
|
||||||
|
// - regoAst.String if a static field name or index.
|
||||||
|
// - regoAst.Var if the field reference is a variable itself. Such as
|
||||||
|
// the wildcard "[_]"
|
||||||
|
// 3. Repeat 1-2 until the end of the reference.
|
||||||
|
node, ok := cfg.VariableConverter.ConvertVariable(t)
|
||||||
|
if !ok {
|
||||||
|
return nil, xerrors.Errorf("variable %q cannot be converted", t.String())
|
||||||
|
}
|
||||||
|
return node, nil
|
||||||
|
case ast.String:
|
||||||
|
return sqltypes.String(string(t)), nil
|
||||||
|
case ast.Number:
|
||||||
|
return sqltypes.Number(source, json.Number(t)), nil
|
||||||
|
case ast.Boolean:
|
||||||
|
return sqltypes.Bool(bool(t)), nil
|
||||||
|
case *ast.Array:
|
||||||
|
elems := make([]sqltypes.Node, 0, t.Len())
|
||||||
|
for i := 0; i < t.Len(); i++ {
|
||||||
|
value, err := convertTerm(cfg, t.Elem(i))
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("array element %d in %q: %w", i, t.String(), err)
|
||||||
|
}
|
||||||
|
elems = append(elems, value)
|
||||||
|
}
|
||||||
|
return sqltypes.Array(source, elems...)
|
||||||
|
case ast.Object:
|
||||||
|
return nil, xerrors.New("object not yet supported")
|
||||||
|
case ast.Set:
|
||||||
|
// Just treat a set like an array for now.
|
||||||
|
arr := t.Sorted()
|
||||||
|
return convertTerm(cfg, &ast.Term{
|
||||||
|
Value: arr,
|
||||||
|
Location: term.Location,
|
||||||
|
})
|
||||||
|
case ast.Call:
|
||||||
|
// This is a function call
|
||||||
|
return convertCall(cfg, t)
|
||||||
|
default:
|
||||||
|
return nil, xerrors.Errorf("%T not yet supported", t)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,307 @@
|
||||||
|
package regosql_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/open-policy-agent/opa/ast"
|
||||||
|
"github.com/open-policy-agent/opa/rego"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql"
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestRegoQueriesNoVariables handles cases without variables. These should be
|
||||||
|
// very simple and straight forward.
|
||||||
|
func TestRegoQueries(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
p := func(v string) string {
|
||||||
|
return "(" + v + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
Name string
|
||||||
|
Queries []string
|
||||||
|
ExpectedSQL string
|
||||||
|
ExpectError bool
|
||||||
|
ExpectedSQLGenError bool
|
||||||
|
|
||||||
|
VariableConverter sqltypes.VariableMatcher
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "Empty",
|
||||||
|
Queries: []string{``},
|
||||||
|
ExpectedSQL: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "True",
|
||||||
|
Queries: []string{`true`},
|
||||||
|
ExpectedSQL: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "False",
|
||||||
|
Queries: []string{`false`},
|
||||||
|
ExpectedSQL: "false",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "MultipleBool",
|
||||||
|
Queries: []string{"true", "false"},
|
||||||
|
ExpectedSQL: "(true OR false)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Numbers",
|
||||||
|
Queries: []string{
|
||||||
|
"(1 != 2) = true",
|
||||||
|
"5 == 5",
|
||||||
|
},
|
||||||
|
ExpectedSQL: p("((1 != 2) = true) OR (5 = 5)"),
|
||||||
|
},
|
||||||
|
// Variables
|
||||||
|
{
|
||||||
|
// Always return a constant string for all variables.
|
||||||
|
Name: "V_Basic",
|
||||||
|
Queries: []string{
|
||||||
|
`input.x = "hello_world"`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: p("only_var = 'hello_world'"),
|
||||||
|
VariableConverter: sqltypes.NewVariableConverter().RegisterMatcher(
|
||||||
|
sqltypes.StringVarMatcher("only_var", []string{
|
||||||
|
"input", "x",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
// Coder Variables
|
||||||
|
{
|
||||||
|
// Always return a constant string for all variables.
|
||||||
|
Name: "GroupACL",
|
||||||
|
Queries: []string{
|
||||||
|
`"read" in input.object.acl_group_list.allUsers`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: "(group_acl->'allUsers' ? 'read')",
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "GroupWildcard",
|
||||||
|
Queries: []string{`"*" in input.object.acl_group_list.allUsers`},
|
||||||
|
ExpectedSQL: "(group_acl->'allUsers' ? '*')",
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Always return a constant string for all variables.
|
||||||
|
Name: "GroupACLWithVarField",
|
||||||
|
Queries: []string{
|
||||||
|
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: "(group_acl->organization_id :: text ? 'read')",
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "VarInArray",
|
||||||
|
Queries: []string{
|
||||||
|
`input.object.org_owner in {"a", "b", "c"}`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: p("organization_id :: text = ANY(ARRAY ['a','b','c'])"),
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "SetDereference",
|
||||||
|
Queries: []string{`"*" in input.object.acl_group_list[input.object.org_owner]`},
|
||||||
|
ExpectedSQL: p("group_acl->organization_id :: text ? '*'"),
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "JsonbLiteralDereference",
|
||||||
|
Queries: []string{`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`},
|
||||||
|
ExpectedSQL: p("group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'"),
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Complex",
|
||||||
|
Queries: []string{
|
||||||
|
`input.object.org_owner != ""`,
|
||||||
|
`input.object.org_owner in {"a", "b", "c"}`,
|
||||||
|
`input.object.org_owner != ""`,
|
||||||
|
`"read" in input.object.acl_group_list.allUsers`,
|
||||||
|
`"read" in input.object.acl_user_list.me`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: `((organization_id :: text != '') OR ` +
|
||||||
|
`(organization_id :: text = ANY(ARRAY ['a','b','c'])) OR ` +
|
||||||
|
`(organization_id :: text != '') OR ` +
|
||||||
|
`(group_acl->'allUsers' ? 'read') OR ` +
|
||||||
|
`(user_acl->'me' ? 'read'))`,
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "NoACLs",
|
||||||
|
Queries: []string{
|
||||||
|
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||||
|
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
|
||||||
|
},
|
||||||
|
// Special case where the bool is wrapped
|
||||||
|
ExpectedSQL: p("(false) OR (false)"),
|
||||||
|
VariableConverter: regosql.NoACLConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "TwoExpressions",
|
||||||
|
Queries: []string{
|
||||||
|
`true; true`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: p("true AND true"),
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
|
||||||
|
// Actual vectors from production
|
||||||
|
{
|
||||||
|
Name: "FromOwner",
|
||||||
|
Queries: []string{
|
||||||
|
``,
|
||||||
|
`"05f58202-4bfc-43ce-9ba4-5ff6e0174a71" = input.object.org_owner`,
|
||||||
|
`"read" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: "true",
|
||||||
|
VariableConverter: regosql.NoACLConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "OrgAdmin",
|
||||||
|
Queries: []string{
|
||||||
|
`input.object.org_owner != "";
|
||||||
|
input.object.org_owner in {"05f58202-4bfc-43ce-9ba4-5ff6e0174a71"};
|
||||||
|
input.object.owner != "";
|
||||||
|
"d5389ccc-57a4-4b13-8c3f-31747bcdc9f1" = input.object.owner`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: "((organization_id :: text != '') AND " +
|
||||||
|
"(organization_id :: text = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'])) AND " +
|
||||||
|
"(owner_id :: text != '') AND " +
|
||||||
|
"('d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' = owner_id :: text))",
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "UserACLAllow",
|
||||||
|
Queries: []string{
|
||||||
|
`"read" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
|
||||||
|
`"*" in input.object.acl_user_list["d5389ccc-57a4-4b13-8c3f-31747bcdc9f1"]`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: "((user_acl->'d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' ? 'read') OR " +
|
||||||
|
"(user_acl->'d5389ccc-57a4-4b13-8c3f-31747bcdc9f1' ? '*'))",
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "NoACLConfig",
|
||||||
|
Queries: []string{
|
||||||
|
`input.object.org_owner != "";
|
||||||
|
input.object.org_owner in {"05f58202-4bfc-43ce-9ba4-5ff6e0174a71"};
|
||||||
|
"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: "((organization_id :: text != '') AND (organization_id :: text = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'])) AND (false))",
|
||||||
|
VariableConverter: regosql.NoACLConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "EmptyACLListNoACLs",
|
||||||
|
Queries: []string{
|
||||||
|
`input.object.org_owner != "";
|
||||||
|
input.object.org_owner in set();
|
||||||
|
"create" in input.object.acl_group_list[input.object.org_owner]`,
|
||||||
|
|
||||||
|
`input.object.org_owner != "";
|
||||||
|
input.object.org_owner in set();
|
||||||
|
"*" in input.object.acl_group_list[input.object.org_owner]`,
|
||||||
|
|
||||||
|
`"create" in input.object.acl_user_list.me`,
|
||||||
|
|
||||||
|
`"*" in input.object.acl_user_list.me`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: p(p("(organization_id :: text != '') AND (false) AND (group_acl->organization_id :: text ? 'create')") + " OR " +
|
||||||
|
p("(organization_id :: text != '') AND (false) AND (group_acl->organization_id :: text ? '*')") + " OR " +
|
||||||
|
p("user_acl->'me' ? 'create'") + " OR " +
|
||||||
|
p("user_acl->'me' ? '*'")),
|
||||||
|
VariableConverter: regosql.DefaultVariableConverter(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "TemplateOwner",
|
||||||
|
Queries: []string{
|
||||||
|
`neq(input.object.org_owner, "");
|
||||||
|
internal.member_2(input.object.org_owner, {"3bf82434-e40b-44ae-b3d8-d0115bba9bad", "5630fda3-26ab-462c-9014-a88a62d7a415", "c304877a-bc0d-4e9b-9623-a38eae412929"});
|
||||||
|
neq(input.object.owner, "");
|
||||||
|
"806dd721-775f-4c85-9ce3-63fbbd975954" = input.object.owner`,
|
||||||
|
},
|
||||||
|
ExpectedSQL: p(p("organization_id :: text != ''") + " AND " +
|
||||||
|
p("organization_id :: text = ANY(ARRAY ['3bf82434-e40b-44ae-b3d8-d0115bba9bad','5630fda3-26ab-462c-9014-a88a62d7a415','c304877a-bc0d-4e9b-9623-a38eae412929'])") + " AND " +
|
||||||
|
p("false") + " AND " +
|
||||||
|
p("false")),
|
||||||
|
VariableConverter: regosql.TemplateConverter(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
part := partialQueries(tc.Queries...)
|
||||||
|
|
||||||
|
cfg := regosql.ConvertConfig{
|
||||||
|
VariableConverter: tc.VariableConverter,
|
||||||
|
}
|
||||||
|
|
||||||
|
requireConvert(t, convertTestCase{
|
||||||
|
part: part,
|
||||||
|
cfg: cfg,
|
||||||
|
expectSQL: tc.ExpectedSQL,
|
||||||
|
expectConvertError: tc.ExpectError,
|
||||||
|
expectSQLGenError: tc.ExpectedSQLGenError,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type convertTestCase struct {
|
||||||
|
part *rego.PartialQueries
|
||||||
|
cfg regosql.ConvertConfig
|
||||||
|
|
||||||
|
expectConvertError bool
|
||||||
|
expectSQL string
|
||||||
|
expectSQLGenError bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireConvert(t *testing.T, tc convertTestCase) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for i, q := range tc.part.Queries {
|
||||||
|
t.Logf("Query %d: %s", i, q.String())
|
||||||
|
}
|
||||||
|
for i, s := range tc.part.Support {
|
||||||
|
t.Logf("Support %d: %s", i, s.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
root, err := regosql.ConvertRegoAst(tc.cfg, tc.part)
|
||||||
|
if tc.expectConvertError {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err, "compile")
|
||||||
|
|
||||||
|
gen := sqltypes.NewSQLGenerator()
|
||||||
|
sqlString := root.SQLString(gen)
|
||||||
|
if tc.expectSQLGenError {
|
||||||
|
require.True(t, len(gen.Errors()) > 0, "expected SQL generation error")
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err, "sql gen")
|
||||||
|
require.Equal(t, tc.expectSQL, sqlString, "sql match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func partialQueries(queries ...string) *rego.PartialQueries {
|
||||||
|
opts := ast.ParserOptions{
|
||||||
|
AllFutureKeywords: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
astQueries := make([]ast.Body, 0, len(queries))
|
||||||
|
for _, q := range queries {
|
||||||
|
astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
return ®o.PartialQueries{
|
||||||
|
Queries: astQueries,
|
||||||
|
Support: []*ast.Module{},
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
package regosql
|
||||||
|
|
||||||
|
import "github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||||
|
|
||||||
|
func organizationOwnerMatcher() sqltypes.VariableMatcher {
|
||||||
|
return sqltypes.StringVarMatcher("organization_id :: text", []string{"input", "object", "org_owner"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func userOwnerMatcher() sqltypes.VariableMatcher {
|
||||||
|
return sqltypes.StringVarMatcher("owner_id :: text", []string{"input", "object", "owner"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func groupACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher {
|
||||||
|
return ACLGroupMatcher(m, "group_acl", []string{"input", "object", "acl_group_list"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func userACLMatcher(m sqltypes.VariableMatcher) sqltypes.VariableMatcher {
|
||||||
|
return ACLGroupMatcher(m, "user_acl", []string{"input", "object", "acl_user_list"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TemplateConverter() *sqltypes.VariableConverter {
|
||||||
|
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||||
|
organizationOwnerMatcher(),
|
||||||
|
// Templates have no user owner, only owner by an organization.
|
||||||
|
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||||
|
)
|
||||||
|
matcher.RegisterMatcher(
|
||||||
|
groupACLMatcher(matcher),
|
||||||
|
userACLMatcher(matcher),
|
||||||
|
)
|
||||||
|
return matcher
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoACLConverter should be used when the target SQL table does not contain
|
||||||
|
// group or user ACL columns.
|
||||||
|
func NoACLConverter() *sqltypes.VariableConverter {
|
||||||
|
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||||
|
organizationOwnerMatcher(),
|
||||||
|
userOwnerMatcher(),
|
||||||
|
)
|
||||||
|
matcher.RegisterMatcher(
|
||||||
|
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
|
||||||
|
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return matcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultVariableConverter() *sqltypes.VariableConverter {
|
||||||
|
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||||
|
organizationOwnerMatcher(),
|
||||||
|
userOwnerMatcher(),
|
||||||
|
)
|
||||||
|
matcher.RegisterMatcher(
|
||||||
|
groupACLMatcher(matcher),
|
||||||
|
userACLMatcher(matcher),
|
||||||
|
)
|
||||||
|
|
||||||
|
return matcher
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
// Package regosql converts rego queries into SQL WHERE clauses. This is so
|
||||||
|
// the rego queries can be used to filter the results of a SQL query.
|
||||||
|
package regosql
|
|
@ -0,0 +1,61 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/open-policy-agent/opa/ast"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ Node = alwaysFalse{}
|
||||||
|
var _ VariableMatcher = alwaysFalse{}
|
||||||
|
|
||||||
|
type alwaysFalse struct {
|
||||||
|
Matcher VariableMatcher
|
||||||
|
|
||||||
|
InnerNode Node
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlwaysFalse overrides the inner node with a constant "false".
|
||||||
|
func AlwaysFalse(m VariableMatcher) VariableMatcher {
|
||||||
|
return alwaysFalse{
|
||||||
|
Matcher: m,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlwaysFalseNode is mainly used for unit testing to make a Node immediately.
|
||||||
|
func AlwaysFalseNode(n Node) Node {
|
||||||
|
return alwaysFalse{
|
||||||
|
InnerNode: n,
|
||||||
|
Matcher: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseAs uses a type no one supports to always override with false.
|
||||||
|
func (alwaysFalse) UseAs() Node { return alwaysFalse{} }
|
||||||
|
func (f alwaysFalse) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||||
|
if f.Matcher != nil {
|
||||||
|
n, ok := f.Matcher.ConvertVariable(rego)
|
||||||
|
if ok {
|
||||||
|
return alwaysFalse{
|
||||||
|
Matcher: f.Matcher,
|
||||||
|
InnerNode: n,
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alwaysFalse) SQLString(_ *SQLGenerator) string {
|
||||||
|
return "false"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alwaysFalse) ContainsSQL(_ *SQLGenerator, _ Node) (string, error) {
|
||||||
|
return "false", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alwaysFalse) ContainedInSQL(_ *SQLGenerator, _ Node) (string, error) {
|
||||||
|
return "false", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (alwaysFalse) EqualsSQLString(_ *SQLGenerator, _ bool, _ Node) (string, error) {
|
||||||
|
return "false", nil
|
||||||
|
}
|
|
@ -0,0 +1,69 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ASTArray struct {
|
||||||
|
Source RegoSource
|
||||||
|
Value []Node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Array is typed to whatever the first element is. If there is not first
|
||||||
|
// element, the array element type is invalid.
|
||||||
|
func Array(source RegoSource, nodes ...Node) (Node, error) {
|
||||||
|
for i := 1; i < len(nodes); i++ {
|
||||||
|
if reflect.TypeOf(nodes[0]) != reflect.TypeOf(nodes[i]) {
|
||||||
|
// Do not allow mixed types in arrays
|
||||||
|
return nil, xerrors.Errorf("array element %d in %q: type mismatch", i, source)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ASTArray{Value: nodes, Source: source}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ASTArray) UseAs() Node { return ASTArray{} }
|
||||||
|
|
||||||
|
func (a ASTArray) ContainsSQL(cfg *SQLGenerator, needle Node) (string, error) {
|
||||||
|
// If we have no elements in our set, then our needle is never in the set.
|
||||||
|
if len(a.Value) == 0 {
|
||||||
|
return "false", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This condition supports any contains function if the needle type is
|
||||||
|
// the same as the ASTArray element type.
|
||||||
|
if reflect.TypeOf(a.MyType().UseAs()) != reflect.TypeOf(needle.UseAs()) {
|
||||||
|
return "ArrayContainsError", xerrors.Errorf("array contains %q: type mismatch (%T, %T)",
|
||||||
|
a.Source, a.MyType(), needle)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s = ANY(%s)", needle.SQLString(cfg), a.SQLString(cfg)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a ASTArray) SQLString(cfg *SQLGenerator) string {
|
||||||
|
switch a.MyType().UseAs().(type) {
|
||||||
|
case invalidNode:
|
||||||
|
cfg.AddError(xerrors.Errorf("array %q: empty array", a.Source))
|
||||||
|
return "ArrayError"
|
||||||
|
case AstNumber, AstString, AstBoolean:
|
||||||
|
// Primitive types
|
||||||
|
values := make([]string, 0, len(a.Value))
|
||||||
|
for _, v := range a.Value {
|
||||||
|
values = append(values, v.SQLString(cfg))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("ARRAY [%s]", strings.Join(values, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.AddError(xerrors.Errorf("array %q: unsupported type %T", a.Source, a.MyType()))
|
||||||
|
return "ArrayError"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a ASTArray) MyType() Node {
|
||||||
|
if len(a.Value) == 0 {
|
||||||
|
return invalidNode{}
|
||||||
|
}
|
||||||
|
return a.Value[0]
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type binaryOperator int
|
||||||
|
|
||||||
|
const (
|
||||||
|
_ binaryOperator = iota
|
||||||
|
binaryOpOR
|
||||||
|
binaryOpAND
|
||||||
|
)
|
||||||
|
|
||||||
|
type binaryOp struct {
|
||||||
|
source RegoSource
|
||||||
|
op binaryOperator
|
||||||
|
|
||||||
|
Terms []BooleanNode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (binaryOp) UseAs() Node { return binaryOp{} }
|
||||||
|
func (binaryOp) IsBooleanNode() {}
|
||||||
|
|
||||||
|
func Or(source RegoSource, terms ...BooleanNode) BooleanNode {
|
||||||
|
return newBinaryOp(source, binaryOpOR, terms...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func And(source RegoSource, terms ...BooleanNode) BooleanNode {
|
||||||
|
return newBinaryOp(source, binaryOpAND, terms...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBinaryOp(source RegoSource, op binaryOperator, terms ...BooleanNode) BooleanNode {
|
||||||
|
if len(terms) == 0 {
|
||||||
|
// TODO: How to handle 0 terms?
|
||||||
|
return Bool(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
opTerms := make([]BooleanNode, 0, len(terms))
|
||||||
|
for i := range terms {
|
||||||
|
// Always wrap terms in parentheses to be safe.
|
||||||
|
opTerms = append(opTerms, BoolParenthesis(terms[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opTerms) == 1 {
|
||||||
|
return opTerms[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return binaryOp{
|
||||||
|
Terms: opTerms,
|
||||||
|
op: op,
|
||||||
|
source: source,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b binaryOp) SQLString(cfg *SQLGenerator) string {
|
||||||
|
sqlOp := ""
|
||||||
|
switch b.op {
|
||||||
|
case binaryOpOR:
|
||||||
|
sqlOp = "OR"
|
||||||
|
case binaryOpAND:
|
||||||
|
sqlOp = "AND"
|
||||||
|
default:
|
||||||
|
cfg.AddError(xerrors.Errorf("unsupported binary operator: %s (%d)", b.source, b.op))
|
||||||
|
return "BinaryOpError"
|
||||||
|
}
|
||||||
|
|
||||||
|
terms := make([]string, 0, len(b.Terms))
|
||||||
|
for _, term := range b.Terms {
|
||||||
|
termSQL := term.SQLString(cfg)
|
||||||
|
terms = append(terms, termSQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(terms, " "+sqlOp+" ")
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AstBoolean is a literal true/false value.
|
||||||
|
type AstBoolean struct {
|
||||||
|
Source RegoSource
|
||||||
|
Value bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func Bool(t bool) BooleanNode {
|
||||||
|
return AstBoolean{Value: t, Source: RegoSource(strconv.FormatBool(t))}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (AstBoolean) IsBooleanNode() {}
|
||||||
|
func (AstBoolean) UseAs() Node { return AstBoolean{} }
|
||||||
|
|
||||||
|
func (b AstBoolean) SQLString(_ *SQLGenerator) string {
|
||||||
|
return strconv.FormatBool(b.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b AstBoolean) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||||
|
return boolEqualsSQLString(cfg, b, not, other)
|
||||||
|
}
|
|
@ -0,0 +1,5 @@
|
||||||
|
// Package sqltypes contains the types used to convert rego queries into SQL.
|
||||||
|
// The rego ast is converted into these types to better control the SQL
|
||||||
|
// generation. It allows writing the SQL generation for types in an easier to
|
||||||
|
// read way.
|
||||||
|
package sqltypes
|
|
@ -0,0 +1,102 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SupportsEquality is an interface that can be implemented by types that
|
||||||
|
// support equality with other types. We defer to other types to implement this
|
||||||
|
// as it is much easier to implement this in the context of the type.
|
||||||
|
type SupportsEquality interface {
|
||||||
|
// EqualsSQLString intentionally returns an error. This is so if
|
||||||
|
// left = right is not supported, we can try right = left.
|
||||||
|
EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ BooleanNode = equality{}
|
||||||
|
var _ Node = equality{}
|
||||||
|
var _ SupportsEquality = equality{}
|
||||||
|
|
||||||
|
type equality struct {
|
||||||
|
Left Node
|
||||||
|
Right Node
|
||||||
|
|
||||||
|
// Not just inverses the result of the comparison. We could implement this
|
||||||
|
// as a Not node wrapping the equality, but this is more efficient.
|
||||||
|
Not bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func Equality(notEquals bool, a, b Node) BooleanNode {
|
||||||
|
return equality{
|
||||||
|
Left: a,
|
||||||
|
Right: b,
|
||||||
|
Not: notEquals,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (equality) IsBooleanNode() {}
|
||||||
|
|
||||||
|
// UseAs returns an ASTBoolean as equalities resolve to boolean values
|
||||||
|
func (equality) UseAs() Node { return AstBoolean{} }
|
||||||
|
|
||||||
|
func (e equality) SQLString(cfg *SQLGenerator) string {
|
||||||
|
// Equalities can be flipped without changing the result, so we can
|
||||||
|
// try both left = right and right = left.
|
||||||
|
if eq, ok := e.Left.(SupportsEquality); ok {
|
||||||
|
v, err := eq.EqualsSQLString(cfg, e.Not, e.Right)
|
||||||
|
if err == nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eq, ok := e.Right.(SupportsEquality); ok {
|
||||||
|
v, err := eq.EqualsSQLString(cfg, e.Not, e.Left)
|
||||||
|
if err == nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.AddError(xerrors.Errorf("unsupported equality: %T %s %T", e.Left, equalsOp(e.Not), e.Right))
|
||||||
|
return "EqualityError"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e equality) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||||
|
return boolEqualsSQLString(cfg, e, not, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolEqualsSQLString(cfg *SQLGenerator, a BooleanNode, not bool, other Node) (string, error) {
|
||||||
|
switch other.UseAs().(type) {
|
||||||
|
case BooleanNode:
|
||||||
|
bn, ok := other.(BooleanNode)
|
||||||
|
if !ok {
|
||||||
|
return "", xerrors.Errorf("not a boolean node: %T", other)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always wrap both sides in parens to ensure the correct precedence.
|
||||||
|
return fmt.Sprintf("%s %s %s",
|
||||||
|
BoolParenthesis(a).SQLString(cfg),
|
||||||
|
equalsOp(not),
|
||||||
|
BoolParenthesis(bn).SQLString(cfg),
|
||||||
|
), nil
|
||||||
|
default:
|
||||||
|
return "", xerrors.Errorf("unsupported equality: %T %s %T", a, equalsOp(not), other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:revive
|
||||||
|
func equalsOp(not bool) string {
|
||||||
|
if not {
|
||||||
|
return "!="
|
||||||
|
}
|
||||||
|
return "="
|
||||||
|
}
|
||||||
|
|
||||||
|
func basicSQLEquality(cfg *SQLGenerator, not bool, a, b Node) string {
|
||||||
|
return fmt.Sprintf("%s %s %s",
|
||||||
|
a.SQLString(cfg),
|
||||||
|
equalsOp(not),
|
||||||
|
b.SQLString(cfg),
|
||||||
|
)
|
||||||
|
}
|
|
@ -0,0 +1,130 @@
|
||||||
|
package sqltypes_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEquality(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := []struct {
|
||||||
|
Name string
|
||||||
|
Equality sqltypes.Node
|
||||||
|
ExpectedSQL string
|
||||||
|
ExpectedErrors int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "String=String",
|
||||||
|
Equality: sqltypes.Equality(false,
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
sqltypes.String("bar"),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "'foo' = 'bar'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Number=Number",
|
||||||
|
Equality: sqltypes.Equality(false,
|
||||||
|
sqltypes.Number("", json.Number("5")),
|
||||||
|
sqltypes.Number("", json.Number("22")),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "5 = 22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Bool=Bool",
|
||||||
|
Equality: sqltypes.Equality(false,
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
sqltypes.Bool(false),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "true = false",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Bool=Equality",
|
||||||
|
Equality: sqltypes.Equality(false,
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
sqltypes.Equality(true,
|
||||||
|
sqltypes.Equality(true,
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
sqltypes.String("bar"),
|
||||||
|
),
|
||||||
|
sqltypes.Bool(false),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "true = (('foo' != 'bar') != false)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Equality=Equality",
|
||||||
|
Equality: sqltypes.Equality(false,
|
||||||
|
sqltypes.Equality(true,
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
sqltypes.Bool(false),
|
||||||
|
),
|
||||||
|
sqltypes.Equality(false,
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "(true != false) = ('foo' = 'foo')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Membership=Membership",
|
||||||
|
Equality: sqltypes.Equality(false,
|
||||||
|
sqltypes.Equality(true,
|
||||||
|
sqltypes.MemberOf(
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
must(sqltypes.Array("",
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
sqltypes.String("bar"),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
sqltypes.Bool(false),
|
||||||
|
),
|
||||||
|
sqltypes.Equality(false,
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
sqltypes.MemberOf(
|
||||||
|
sqltypes.Number("", "2"),
|
||||||
|
must(sqltypes.Array("",
|
||||||
|
sqltypes.Number("", "5"),
|
||||||
|
sqltypes.Number("", "2"),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "(('foo' = ANY(ARRAY ['foo','bar'])) != false) = (true = (2 = ANY(ARRAY [5,2])))",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "AlwaysFalse=String",
|
||||||
|
Equality: sqltypes.Equality(false,
|
||||||
|
sqltypes.AlwaysFalseNode(sqltypes.String("foo")),
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "false",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "String=AlwaysFalse",
|
||||||
|
Equality: sqltypes.Equality(false,
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
sqltypes.AlwaysFalseNode(sqltypes.String("foo")),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "false",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gen := sqltypes.NewSQLGenerator()
|
||||||
|
found := tc.Equality.SQLString(gen)
|
||||||
|
if tc.ExpectedErrors > 0 {
|
||||||
|
require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected AstNumber of errors")
|
||||||
|
} else {
|
||||||
|
require.Equal(t, tc.ExpectedSQL, found, "expected sql")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
type SQLGenerator struct {
|
||||||
|
errors []error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSQLGenerator() *SQLGenerator {
|
||||||
|
return &SQLGenerator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *SQLGenerator) AddError(err error) {
|
||||||
|
if err != nil {
|
||||||
|
g.errors = append(g.errors, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *SQLGenerator) Errors() []error {
|
||||||
|
return g.errors
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SupportsContains is an interface that can be implemented by types that
|
||||||
|
// support "me.Contains(other)". This is `internal_member2` in the rego.
|
||||||
|
type SupportsContains interface {
|
||||||
|
ContainsSQL(cfg *SQLGenerator, other Node) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportsContainedIn is the inverse of SupportsContains. It is implemented
|
||||||
|
// from the "needle" rather than the haystack.
|
||||||
|
type SupportsContainedIn interface {
|
||||||
|
ContainedInSQL(cfg *SQLGenerator, other Node) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ BooleanNode = memberOf{}
|
||||||
|
var _ Node = memberOf{}
|
||||||
|
var _ SupportsEquality = memberOf{}
|
||||||
|
|
||||||
|
type memberOf struct {
|
||||||
|
Needle Node
|
||||||
|
Haystack Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func MemberOf(needle, haystack Node) BooleanNode {
|
||||||
|
return memberOf{
|
||||||
|
Needle: needle,
|
||||||
|
Haystack: haystack,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (memberOf) IsBooleanNode() {}
|
||||||
|
func (memberOf) UseAs() Node { return AstBoolean{} }
|
||||||
|
|
||||||
|
func (e memberOf) SQLString(cfg *SQLGenerator) string {
|
||||||
|
// Equalities can be flipped without changing the result, so we can
|
||||||
|
// try both left = right and right = left.
|
||||||
|
if sc, ok := e.Haystack.(SupportsContains); ok {
|
||||||
|
v, err := sc.ContainsSQL(cfg, e.Needle)
|
||||||
|
if err == nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sc, ok := e.Needle.(SupportsContainedIn); ok {
|
||||||
|
v, err := sc.ContainedInSQL(cfg, e.Haystack)
|
||||||
|
if err == nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.AddError(xerrors.Errorf("unsupported contains: %T contains %T", e.Haystack, e.Needle))
|
||||||
|
return "MemberOfError"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e memberOf) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||||
|
return boolEqualsSQLString(cfg, e, not, other)
|
||||||
|
}
|
|
@ -0,0 +1,116 @@
|
||||||
|
package sqltypes_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMembership(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCases := []struct {
|
||||||
|
Name string
|
||||||
|
Membership sqltypes.Node
|
||||||
|
ExpectedSQL string
|
||||||
|
ExpectedErrors int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "StringArray",
|
||||||
|
Membership: sqltypes.MemberOf(
|
||||||
|
sqltypes.String("foo"),
|
||||||
|
must(sqltypes.Array("",
|
||||||
|
sqltypes.String("bar"),
|
||||||
|
sqltypes.String("buzz"),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "'foo' = ANY(ARRAY ['bar','buzz'])",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "NumberArray",
|
||||||
|
Membership: sqltypes.MemberOf(
|
||||||
|
sqltypes.Number("", "5"),
|
||||||
|
must(sqltypes.Array("",
|
||||||
|
sqltypes.Number("", "2"),
|
||||||
|
sqltypes.Number("", "5"),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "5 = ANY(ARRAY [2,5])",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "BoolArray",
|
||||||
|
Membership: sqltypes.MemberOf(
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
must(sqltypes.Array("",
|
||||||
|
sqltypes.Bool(false),
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "true = ANY(ARRAY [false,true])",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "EmptyArray",
|
||||||
|
Membership: sqltypes.MemberOf(
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
must(sqltypes.Array("")),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "false",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "AlwaysFalseMember",
|
||||||
|
Membership: sqltypes.MemberOf(
|
||||||
|
sqltypes.AlwaysFalseNode(sqltypes.Bool(true)),
|
||||||
|
must(sqltypes.Array("",
|
||||||
|
sqltypes.Bool(false),
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "false",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "AlwaysFalseArray",
|
||||||
|
Membership: sqltypes.MemberOf(
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
sqltypes.AlwaysFalseNode(must(sqltypes.Array("",
|
||||||
|
sqltypes.Bool(false),
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
))),
|
||||||
|
),
|
||||||
|
ExpectedSQL: "false",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Errors
|
||||||
|
{
|
||||||
|
Name: "Unsupported",
|
||||||
|
Membership: sqltypes.MemberOf(
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
sqltypes.Bool(true),
|
||||||
|
),
|
||||||
|
ExpectedErrors: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gen := sqltypes.NewSQLGenerator()
|
||||||
|
found := tc.Membership.SQLString(gen)
|
||||||
|
if tc.ExpectedErrors > 0 {
|
||||||
|
require.Equal(t, tc.ExpectedErrors, len(gen.Errors()), "expected some errors")
|
||||||
|
} else {
|
||||||
|
require.Equal(t, tc.ExpectedSQL, found, "expected sql")
|
||||||
|
require.Equal(t, tc.ExpectedErrors, 0, "expected no errors")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func must[V any](v V, err error) V {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Node interface {
|
||||||
|
SQLString(cfg *SQLGenerator) string
|
||||||
|
// UseAs is a helper function to allow a node to be used as a different
|
||||||
|
// Node in operators. For example, a variable is really just a "string", so
|
||||||
|
// having the Equality operator check for "String" or "StringVar" is just
|
||||||
|
// excessive. Instead, we can just have the variable implement this function.
|
||||||
|
UseAs() Node
|
||||||
|
}
|
||||||
|
|
||||||
|
// BooleanNode is a node that returns a true/false when evaluated.
|
||||||
|
type BooleanNode interface {
|
||||||
|
Node
|
||||||
|
IsBooleanNode()
|
||||||
|
}
|
||||||
|
|
||||||
|
type RegoSource string
|
||||||
|
|
||||||
|
type invalidNode struct{}
|
||||||
|
|
||||||
|
func (invalidNode) UseAs() Node { return invalidNode{} }
|
||||||
|
|
||||||
|
func (invalidNode) SQLString(cfg *SQLGenerator) string {
|
||||||
|
cfg.AddError(xerrors.Errorf("invalid node called"))
|
||||||
|
return "invalid_type"
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsPrimitive(n Node) bool {
|
||||||
|
switch n.(type) {
|
||||||
|
case AstBoolean, AstString, AstNumber:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AstNumber struct {
|
||||||
|
Source RegoSource
|
||||||
|
// Value is intentionally vague as to if it's an integer or a float.
|
||||||
|
// This defers that decision to the user. Rego keeps all numbers in this
|
||||||
|
// type. If we were to source the type from something other than Rego,
|
||||||
|
// we might want to make a Float and Int type which keep the original
|
||||||
|
// precision.
|
||||||
|
Value json.Number
|
||||||
|
}
|
||||||
|
|
||||||
|
func Number(source RegoSource, v json.Number) Node {
|
||||||
|
return AstNumber{Value: v, Source: source}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (AstNumber) UseAs() Node { return AstNumber{} }
|
||||||
|
|
||||||
|
func (n AstNumber) SQLString(_ *SQLGenerator) string {
|
||||||
|
return n.Value.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n AstNumber) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||||
|
switch other.UseAs().(type) {
|
||||||
|
case AstNumber:
|
||||||
|
return basicSQLEquality(cfg, not, n, other), nil
|
||||||
|
default:
|
||||||
|
return "", xerrors.Errorf("unsupported equality: %T %s %T", n, equalsOp(not), other)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type astParenthesis struct {
|
||||||
|
Value BooleanNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// BoolParenthesis wraps the given boolean node in parens.
|
||||||
|
// This is useful for grouping and avoiding ambiguity. This does not work for
|
||||||
|
// mathematical parenthesis to change order of operations.
|
||||||
|
func BoolParenthesis(value BooleanNode) BooleanNode {
|
||||||
|
// Wrapping primitives is useless.
|
||||||
|
if IsPrimitive(value) {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap any existing parens. Do not add excess parens.
|
||||||
|
if p, ok := value.(astParenthesis); ok {
|
||||||
|
return BoolParenthesis(p.Value)
|
||||||
|
}
|
||||||
|
return astParenthesis{Value: value}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (astParenthesis) IsBooleanNode() {}
|
||||||
|
func (p astParenthesis) UseAs() Node { return p.Value.UseAs() }
|
||||||
|
func (p astParenthesis) SQLString(cfg *SQLGenerator) string {
|
||||||
|
return "(" + p.Value.SQLString(cfg) + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p astParenthesis) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||||
|
if supp, ok := p.Value.(SupportsEquality); ok {
|
||||||
|
return supp.EqualsSQLString(cfg, not, other)
|
||||||
|
}
|
||||||
|
return "", xerrors.Errorf("unsupported equality: %T %s %T", p.Value, equalsOp(not), other)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p astParenthesis) ContainsSQL(cfg *SQLGenerator, other Node) (string, error) {
|
||||||
|
if supp, ok := p.Value.(SupportsContains); ok {
|
||||||
|
return supp.ContainsSQL(cfg, other)
|
||||||
|
}
|
||||||
|
return "", xerrors.Errorf("unsupported contains: %T %T", p.Value, other)
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AstString struct {
|
||||||
|
Source RegoSource
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func String(v string) Node {
|
||||||
|
return AstString{Value: v, Source: RegoSource(v)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (AstString) UseAs() Node { return AstString{} }
|
||||||
|
|
||||||
|
func (s AstString) SQLString(_ *SQLGenerator) string {
|
||||||
|
return "'" + s.Value + "'"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AstString) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||||
|
switch other.UseAs().(type) {
|
||||||
|
case AstString:
|
||||||
|
return basicSQLEquality(cfg, not, s, other), nil
|
||||||
|
default:
|
||||||
|
return "", xerrors.Errorf("unsupported equality: %T %s %T", s, equalsOp(not), other)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,113 @@
|
||||||
|
package sqltypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/open-policy-agent/opa/ast"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VariableMatcher interface {
|
||||||
|
ConvertVariable(rego ast.Ref) (Node, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type VariableConverter struct {
|
||||||
|
converters []VariableMatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewVariableConverter() *VariableConverter {
|
||||||
|
return &VariableConverter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vc *VariableConverter) RegisterMatcher(m ...VariableMatcher) *VariableConverter {
|
||||||
|
vc.converters = append(vc.converters, m...)
|
||||||
|
// Returns the VariableConverter for easier instantiation
|
||||||
|
return vc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vc *VariableConverter) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||||
|
for _, c := range vc.converters {
|
||||||
|
if n, ok := c.ConvertVariable(rego); ok {
|
||||||
|
return n, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegoVarPath will consume the following terms from the given rego Ref and
|
||||||
|
// return the remaining terms. If the path does not fully match, an error is
|
||||||
|
// returned. The first term must always be a Var.
|
||||||
|
func RegoVarPath(path []string, terms []*ast.Term) ([]*ast.Term, error) {
|
||||||
|
if len(terms) < len(path) {
|
||||||
|
return nil, xerrors.Errorf("path %s longer than rego path %s", path, terms)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(terms) == 0 || len(path) == 0 {
|
||||||
|
return nil, xerrors.Errorf("path %s and rego path %s must not be empty", path, terms)
|
||||||
|
}
|
||||||
|
|
||||||
|
varTerm, ok := terms[0].Value.(ast.Var)
|
||||||
|
if !ok {
|
||||||
|
return nil, xerrors.Errorf("expected var, got %T", terms[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(varTerm) != path[0] {
|
||||||
|
return nil, xerrors.Errorf("expected var %s, got %s", path[0], varTerm)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < len(path); i++ {
|
||||||
|
nextTerm, ok := terms[i].Value.(ast.String)
|
||||||
|
if !ok {
|
||||||
|
return nil, xerrors.Errorf("expected ast.string, got %T", terms[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(nextTerm) != path[i] {
|
||||||
|
return nil, xerrors.Errorf("expected string %s, got %s", path[i], nextTerm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return terms[len(path):], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ VariableMatcher = astStringVar{}
|
||||||
|
var _ Node = astStringVar{}
|
||||||
|
|
||||||
|
// astStringVar is any variable that represents a string.
|
||||||
|
type astStringVar struct {
|
||||||
|
Source RegoSource
|
||||||
|
FieldPath []string
|
||||||
|
ColumnString string
|
||||||
|
}
|
||||||
|
|
||||||
|
func StringVarMatcher(sqlString string, regoPath []string) VariableMatcher {
|
||||||
|
return astStringVar{FieldPath: regoPath, ColumnString: sqlString}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (astStringVar) UseAs() Node { return AstString{} }
|
||||||
|
|
||||||
|
// ConvertVariable will return a new astStringVar Node if the given rego Ref
|
||||||
|
// matches this astStringVar.
|
||||||
|
func (s astStringVar) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||||
|
left, err := RegoVarPath(s.FieldPath, rego)
|
||||||
|
if err == nil && len(left) == 0 {
|
||||||
|
return astStringVar{
|
||||||
|
Source: RegoSource(rego.String()),
|
||||||
|
FieldPath: s.FieldPath,
|
||||||
|
ColumnString: s.ColumnString,
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s astStringVar) SQLString(_ *SQLGenerator) string {
|
||||||
|
return s.ColumnString
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s astStringVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||||
|
switch other.UseAs().(type) {
|
||||||
|
case AstString:
|
||||||
|
return basicSQLEquality(cfg, not, s, other), nil
|
||||||
|
default:
|
||||||
|
return "", xerrors.Errorf("unsupported equality: %T %s %T", s, equalsOp(not), other)
|
||||||
|
}
|
||||||
|
}
|
|
@ -316,25 +316,27 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
|
||||||
func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
organization := httpmw.OrganizationParam(r)
|
organization := httpmw.OrganizationParam(r)
|
||||||
templates, err := api.Database.GetTemplatesWithFilter(ctx, database.GetTemplatesWithFilterParams{
|
|
||||||
OrganizationID: organization.ID,
|
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceTemplate.Type)
|
||||||
})
|
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: "Internal error fetching templates in organization.",
|
Message: "Internal error preparing sql filter.",
|
||||||
Detail: err.Error(),
|
Detail: err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter templates based on rbac permissions
|
// Filter templates based on rbac permissions
|
||||||
templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates)
|
templates, err := api.Database.GetAuthorizedTemplates(ctx, database.GetTemplatesWithFilterParams{
|
||||||
|
OrganizationID: organization.ID,
|
||||||
|
}, prepared)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: "Internal error fetching templates.",
|
Message: "Internal error fetching templates in organization.",
|
||||||
Detail: err.Error(),
|
Detail: err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
|
||||||
"github.com/coder/coder/agent"
|
"github.com/coder/coder/agent"
|
||||||
"github.com/coder/coder/coderd/audit"
|
"github.com/coder/coder/coderd/audit"
|
||||||
"github.com/coder/coder/coderd/coderdtest"
|
"github.com/coder/coder/coderd/coderdtest"
|
||||||
|
|
|
@ -118,7 +118,8 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
|
||||||
filter.OwnerUsername = ""
|
filter.OwnerUsername = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
|
// Workspaces do not have ACL columns.
|
||||||
|
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceWorkspace.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: "Internal error preparing sql filter.",
|
Message: "Internal error preparing sql filter.",
|
||||||
|
@ -127,7 +128,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, sqlFilter)
|
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, prepared)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: "Internal error fetching workspaces.",
|
Message: "Internal error fetching workspaces.",
|
||||||
|
|
|
@ -31,7 +31,7 @@ func (api *API) postGroupByOrganization(rw http.ResponseWriter, r *http.Request)
|
||||||
)
|
)
|
||||||
defer commitAudit()
|
defer commitAudit()
|
||||||
|
|
||||||
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceGroup) {
|
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceGroup.InOrg(org.ID)) {
|
||||||
http.NotFound(rw, r)
|
http.NotFound(rw, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,9 @@ import (
|
||||||
"github.com/coder/coder/coderd/audit"
|
"github.com/coder/coder/coderd/audit"
|
||||||
"github.com/coder/coder/coderd/coderdtest"
|
"github.com/coder/coder/coderd/coderdtest"
|
||||||
"github.com/coder/coder/coderd/database"
|
"github.com/coder/coder/coderd/database"
|
||||||
|
"github.com/coder/coder/coderd/rbac"
|
||||||
"github.com/coder/coder/codersdk"
|
"github.com/coder/coder/codersdk"
|
||||||
|
"github.com/coder/coder/cryptorand"
|
||||||
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
||||||
"github.com/coder/coder/provisioner/echo"
|
"github.com/coder/coder/provisioner/echo"
|
||||||
"github.com/coder/coder/testutil"
|
"github.com/coder/coder/testutil"
|
||||||
|
@ -747,3 +749,210 @@ func TestUpdateTemplateACL(t *testing.T) {
|
||||||
require.Equal(t, http.StatusNotFound, cerr.StatusCode())
|
require.Equal(t, http.StatusNotFound, cerr.StatusCode())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTemplateAccess tests the rego -> sql conversion. We need to implement
|
||||||
|
// this test on at least 1 table type to ensure that the conversion is correct.
|
||||||
|
// The rbac tests only assert against static SQL queries.
|
||||||
|
// This is a full rbac test of many of the common role combinations.
|
||||||
|
//
|
||||||
|
//nolint:tparallel
|
||||||
|
func TestTemplateAccess(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
ownerClient := coderdenttest.New(t, nil)
|
||||||
|
owner := coderdtest.CreateFirstUser(t, ownerClient)
|
||||||
|
_ = coderdenttest.AddLicense(t, ownerClient, coderdenttest.LicenseOptions{
|
||||||
|
TemplateRBAC: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
type coderUser struct {
|
||||||
|
*codersdk.Client
|
||||||
|
User codersdk.User
|
||||||
|
}
|
||||||
|
|
||||||
|
type orgSetup struct {
|
||||||
|
Admin coderUser
|
||||||
|
MemberInGroup coderUser
|
||||||
|
MemberNoGroup coderUser
|
||||||
|
|
||||||
|
DefaultTemplate codersdk.Template
|
||||||
|
AllRead codersdk.Template
|
||||||
|
UserACL codersdk.Template
|
||||||
|
GroupACL codersdk.Template
|
||||||
|
|
||||||
|
Group codersdk.Group
|
||||||
|
Org codersdk.Organization
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the following users
|
||||||
|
// - owner: Site wide owner
|
||||||
|
// - template-admin
|
||||||
|
// - org-admin (org 1)
|
||||||
|
// - org-admin (org 2)
|
||||||
|
// - org-member (org 1)
|
||||||
|
// - org-member (org 2)
|
||||||
|
|
||||||
|
// Create the following templates in each org
|
||||||
|
// - template 1, default acls
|
||||||
|
// - template 2, all_user read
|
||||||
|
// - template 3, user_acl read for member
|
||||||
|
// - template 4, group_acl read for groupMember
|
||||||
|
|
||||||
|
templateAdmin := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin())
|
||||||
|
|
||||||
|
makeTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, acl codersdk.UpdateTemplateACL) codersdk.Template {
|
||||||
|
version := coderdtest.CreateTemplateVersion(t, client, orgID, nil)
|
||||||
|
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
|
||||||
|
|
||||||
|
err := client.UpdateTemplateACL(ctx, template.ID, acl)
|
||||||
|
require.NoError(t, err, "failed to update template acl")
|
||||||
|
|
||||||
|
return template
|
||||||
|
}
|
||||||
|
|
||||||
|
makeOrg := func(t *testing.T) orgSetup {
|
||||||
|
// Make org
|
||||||
|
orgName, err := cryptorand.String(5)
|
||||||
|
require.NoError(t, err, "org name")
|
||||||
|
|
||||||
|
// Make users
|
||||||
|
newOrg, err := ownerClient.CreateOrganization(ctx, codersdk.CreateOrganizationRequest{Name: orgName})
|
||||||
|
require.NoError(t, err, "failed to create org")
|
||||||
|
|
||||||
|
adminCli, adminUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgAdmin(newOrg.ID))
|
||||||
|
groupMemCli, groupMemUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgMember(newOrg.ID))
|
||||||
|
memberCli, memberUsr := coderdtest.CreateAnotherUserWithUser(t, ownerClient, newOrg.ID, rbac.RoleOrgMember(newOrg.ID))
|
||||||
|
|
||||||
|
// Make group
|
||||||
|
group, err := adminCli.CreateGroup(ctx, newOrg.ID, codersdk.CreateGroupRequest{
|
||||||
|
Name: "SingleUser",
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "failed to create group")
|
||||||
|
|
||||||
|
group, err = adminCli.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{
|
||||||
|
AddUsers: []string{groupMemUsr.ID.String()},
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "failed to add user to group")
|
||||||
|
|
||||||
|
// Make templates
|
||||||
|
|
||||||
|
return orgSetup{
|
||||||
|
Admin: coderUser{Client: adminCli, User: adminUsr},
|
||||||
|
MemberInGroup: coderUser{Client: groupMemCli, User: groupMemUsr},
|
||||||
|
MemberNoGroup: coderUser{Client: memberCli, User: memberUsr},
|
||||||
|
Org: newOrg,
|
||||||
|
Group: group,
|
||||||
|
|
||||||
|
DefaultTemplate: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
|
||||||
|
GroupPerms: map[string]codersdk.TemplateRole{
|
||||||
|
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
AllRead: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
|
||||||
|
GroupPerms: map[string]codersdk.TemplateRole{
|
||||||
|
newOrg.ID.String(): codersdk.TemplateRoleUse,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
UserACL: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
|
||||||
|
GroupPerms: map[string]codersdk.TemplateRole{
|
||||||
|
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
|
||||||
|
},
|
||||||
|
UserPerms: map[string]codersdk.TemplateRole{
|
||||||
|
memberUsr.ID.String(): codersdk.TemplateRoleUse,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
GroupACL: makeTemplate(t, adminCli, newOrg.ID, codersdk.UpdateTemplateACL{
|
||||||
|
GroupPerms: map[string]codersdk.TemplateRole{
|
||||||
|
group.ID.String(): codersdk.TemplateRoleUse,
|
||||||
|
newOrg.ID.String(): codersdk.TemplateRoleDeleted,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make 2 organizations
|
||||||
|
orgs := []orgSetup{
|
||||||
|
makeOrg(t),
|
||||||
|
makeOrg(t),
|
||||||
|
}
|
||||||
|
|
||||||
|
testTemplateRead := func(t *testing.T, org orgSetup, usr *codersdk.Client, read []codersdk.Template) {
|
||||||
|
found, err := usr.TemplatesByOrganization(ctx, org.Org.ID)
|
||||||
|
require.NoError(t, err, "failed to get templates")
|
||||||
|
|
||||||
|
exp := make(map[uuid.UUID]codersdk.Template)
|
||||||
|
for _, tmpl := range read {
|
||||||
|
exp[tmpl.ID] = tmpl
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range found {
|
||||||
|
if _, ok := exp[f.ID]; !ok {
|
||||||
|
t.Errorf("found unexpected template %q", f.Name)
|
||||||
|
}
|
||||||
|
delete(exp, f.ID)
|
||||||
|
}
|
||||||
|
require.Len(t, exp, 0, "expected templates not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:paralleltest
|
||||||
|
t.Run("OwnerReadAll", func(t *testing.T) {
|
||||||
|
for _, o := range orgs {
|
||||||
|
// Owners can read all templates in all orgs
|
||||||
|
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
|
||||||
|
testTemplateRead(t, o, ownerClient, exp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// nolint:paralleltest
|
||||||
|
t.Run("TemplateAdminReadAll", func(t *testing.T) {
|
||||||
|
for _, o := range orgs {
|
||||||
|
// Template Admins can read all templates in all orgs
|
||||||
|
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
|
||||||
|
testTemplateRead(t, o, templateAdmin, exp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// nolint:paralleltest
|
||||||
|
t.Run("OrgAdminReadAllTheirs", func(t *testing.T) {
|
||||||
|
for i, o := range orgs {
|
||||||
|
cli := o.Admin.Client
|
||||||
|
// Only read their own org
|
||||||
|
exp := []codersdk.Template{o.DefaultTemplate, o.AllRead, o.UserACL, o.GroupACL}
|
||||||
|
testTemplateRead(t, o, cli, exp)
|
||||||
|
|
||||||
|
other := orgs[(i+1)%len(orgs)]
|
||||||
|
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
|
||||||
|
testTemplateRead(t, other, cli, []codersdk.Template{})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// nolint:paralleltest
|
||||||
|
t.Run("TestMemberNoGroup", func(t *testing.T) {
|
||||||
|
for i, o := range orgs {
|
||||||
|
cli := o.MemberNoGroup.Client
|
||||||
|
// Only read their own org
|
||||||
|
exp := []codersdk.Template{o.AllRead, o.UserACL}
|
||||||
|
testTemplateRead(t, o, cli, exp)
|
||||||
|
|
||||||
|
other := orgs[(i+1)%len(orgs)]
|
||||||
|
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
|
||||||
|
testTemplateRead(t, other, cli, []codersdk.Template{})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// nolint:paralleltest
|
||||||
|
t.Run("TestMemberInGroup", func(t *testing.T) {
|
||||||
|
for i, o := range orgs {
|
||||||
|
cli := o.MemberInGroup.Client
|
||||||
|
// Only read their own org
|
||||||
|
exp := []codersdk.Template{o.AllRead, o.GroupACL}
|
||||||
|
testTemplateRead(t, o, cli, exp)
|
||||||
|
|
||||||
|
other := orgs[(i+1)%len(orgs)]
|
||||||
|
require.NotEqual(t, other.Org.ID, o.Org.ID, "this test needs at least 2 orgs")
|
||||||
|
testTemplateRead(t, other, cli, []codersdk.Template{})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue