tavern/storage/acl.go

209 lines
6.6 KiB
Go

package storage
import (
"context"
"fmt"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/errors"
)
// ACLStorage records and looks up access-control entries (ACLs).
//
// Each ACL is owned by a scope (the row ID of an actor, passed as
// actorRowID), applies to a target string, and carries an ACLAction. The wild
// flag is what ListACLsByScopesAndWildTargets matches on in addition to
// exact targets.
type ACLStorage interface {
	// RecordACL records an ACL for actorRowID and returns its row ID.
	RecordACL(ctx context.Context, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error)
	// RecordACLAll is RecordACL with a caller-supplied row ID and timestamps.
	RecordACLAll(ctx context.Context, rowID uuid.UUID, createdAt, updatedAt time.Time, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error)

	// Lookups keyed by owning scope.
	ListACLsByScope(ctx context.Context, actorRowID uuid.UUID) (ACLs, error)
	ListACLsByScopes(ctx context.Context, actorRowIDs []uuid.UUID) (ACLs, error)
	ListACLsByScopesAndWildTargets(ctx context.Context, actorRowIDs []uuid.UUID, targets []string) (ACLs, error)

	// Lookups keyed by target.
	ListACLsByTarget(ctx context.Context, target string) (ACLs, error)
	ListACLsByTargets(ctx context.Context, targets []string) (ACLs, error)
}
// ACL is a single access-control entry as loaded from the acls table.
//
// Scope is the row ID of the owning actor, Target is the string the entry
// applies to (ToMatchSet compiles it with common.NewAllowDenyMatcher) and
// Action says whether a match allows or denies. Priority is not filled in by
// queryACLs and is zero on rows loaded from storage.
type ACL struct {
	ID                   uuid.UUID
	CreatedAt, UpdatedAt time.Time
	Scope                uuid.UUID
	Target               string
	Action               ACLAction
	Priority             int
}
// aclSorter adapts a slice of ACLs to sort.Interface, ordering elements with
// the caller-supplied less function `by`. Within this file it is referenced
// only by the commented-out WithPriority helper.
type aclSorter struct {
	acls []ACL
	by   func(a1, a2 *ACL) bool
}

// Len reports how many ACLs are being sorted.
func (s *aclSorter) Len() int {
	n := len(s.acls)
	return n
}

// Swap exchanges the ACLs at positions i and j in place.
func (s *aclSorter) Swap(i, j int) {
	items := s.acls
	items[i], items[j] = items[j], items[i]
}

// Less delegates to the configured comparison, passing pointers into the
// backing slice so large ACL values are not copied.
func (s *aclSorter) Less(i, j int) bool {
	left, right := &s.acls[i], &s.acls[j]
	return s.by(left, right)
}
// ACLs is a list of ACL entries as returned by the storage queries.
type ACLs []ACL

// ToMatchSet partitions the ACLs into a common.MatcherSet.
//
// An ACL is "high" when its Scope is tracked in highScopes and "low"
// otherwise. Within each tier it goes to the allow or the deny list
// according to its Action. ACLs whose Action is neither AllowACLAction nor
// DenyACLAction are ignored. defaultAllow is copied into the set's Default.
// A nil highScopes is treated as an empty set, so every ACL is low.
func (l ACLs) ToMatchSet(highScopes *common.UniqueUUIDs, defaultAllow bool) common.MatcherSet {
	highAllow := make([]common.AllowDenyMatcher, 0)
	highDeny := make([]common.AllowDenyMatcher, 0)
	lowAllow := make([]common.AllowDenyMatcher, 0)
	lowDeny := make([]common.AllowDenyMatcher, 0)

	for _, acl := range l {
		// The matcher is built for every entry, as before, so the number of
		// NewAllowDenyMatcher calls does not depend on Action.
		matcher := common.NewAllowDenyMatcher(acl.Target)

		isHigh := false
		if highScopes != nil {
			_, isHigh = highScopes.Tracking[acl.Scope]
		}

		switch acl.Action {
		case AllowACLAction:
			if isHigh {
				highAllow = append(highAllow, matcher)
			} else {
				lowAllow = append(lowAllow, matcher)
			}
		case DenyACLAction:
			if isHigh {
				highDeny = append(highDeny, matcher)
			} else {
				lowDeny = append(lowDeny, matcher)
			}
		}
	}

	return common.MatcherSet{
		Default:   defaultAllow,
		HighAllow: highAllow,
		HighDeny:  highDeny,
		LowAllow:  lowAllow,
		LowDeny:   lowDeny,
	}
}
//
// func (l ACLs) WithPriority(serverScope uuid.UUID) ACLs {
// results := make(ACLs, len(l))
// for i, val := range l {
// priority := 1
// if val.Scope == serverScope {
// priority++
// }
// if val.Action == AllowACLAction {
// priority++
// }
// results[i] = ACL{
// ID: val.ID,
// CreatedAt: val.CreatedAt,
// UpdatedAt: val.UpdatedAt,
// Scope: val.Scope,
// Target: val.Target,
// Action: val.Action,
// Priority: priority,
// }
// }
//
// ps := &aclSorter{
// acls: results,
// by: func(a1, a2 *ACL) bool {
// return a1.Priority < a2.Priority
// },
// }
// sort.Sort(ps)
//
// for _, acl := range l {
// fmt.Println(acl.Scope, acl.Target, acl.Action)
// }
//
// return results
// }
// ACLAction says whether a matching ACL allows or denies. The numeric values
// are written to and scanned back from the acls.action column, so they must
// stay fixed.
type ACLAction int8

// Known ACLAction values.
const (
	AllowACLAction = ACLAction(0)
	DenyACLAction  = ACLAction(1)
)
// aclsFields lists the acls columns selected by the List* queries. queryACLs
// scans result columns positionally in exactly this order, so the two must
// be kept in sync.
var aclsFields = []string{"id", "created_at", "updated_at", "scope", "target", "action"}
// RecordACL records an ACL for actorRowID under a newly generated row ID.
// Both created_at and updated_at are set from a single s.now() reading. The
// write itself is delegated to RecordACLAll.
func (s pgStorage) RecordACL(ctx context.Context, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) {
	id, stamp := NewV4(), s.now()
	return s.RecordACLAll(ctx, id, stamp, stamp, actorRowID, target, action, wild)
}
// RecordACLAll upserts an ACL row using caller-supplied identity and
// timestamps, and returns the id reported by the statement's RETURNING clause.
//
// On a conflict with acls_scope_action_uindex the existing row's updated_at
// and action are overwritten. Its target and wild columns are left as they
// were.
//
// NOTE(review): the error is wrapped with WrapGroupUpsertFailedError, which by
// name is the group helper. Confirm whether an ACL-specific wrapper exists.
func (s pgStorage) RecordACLAll(ctx context.Context, rowID uuid.UUID, createdAt, updatedAt time.Time, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) {
	const upsert = `INSERT INTO acls (id, created_at, updated_at, scope, target, action, wild) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT ON CONSTRAINT acls_scope_action_uindex DO UPDATE SET updated_at = $3, action = $6 RETURNING id`

	var storedID uuid.UUID
	row := s.db.QueryRowContext(ctx, upsert, rowID, createdAt, updatedAt, actorRowID, target, action, wild)
	scanErr := row.Scan(&storedID)
	return storedID, errors.WrapGroupUpsertFailedError(scanErr)
}
// ListACLsByScope returns every ACL whose scope column equals actorRowID.
func (s pgStorage) ListACLsByScope(ctx context.Context, actorRowID uuid.UUID) (ACLs, error) {
	columns := strings.Join(aclsFields, ",")
	return s.queryACLs(ctx, fmt.Sprintf(`SELECT %s FROM acls WHERE scope = $1`, columns), actorRowID)
}
// ListACLsByScopes returns every ACL owned by any of actorRowIDs. An empty
// input yields an empty (non-nil) list without querying, because an empty
// SQL IN () list is invalid.
func (s pgStorage) ListACLsByScopes(ctx context.Context, actorRowIDs []uuid.UUID) (ACLs, error) {
	count := len(actorRowIDs)
	if count == 0 {
		return ACLs{}, nil
	}
	columns := strings.Join(aclsFields, ",")
	placeholders := strings.Join(common.DollarForEach(count), ",")
	query := fmt.Sprintf("SELECT %s FROM acls WHERE scope IN (%s)", columns, placeholders)
	args := common.UUIDsToInterfaces(actorRowIDs)
	return s.queryACLs(ctx, query, args...)
}
// ListACLsByTarget returns every ACL whose target column equals target
// exactly.
func (s pgStorage) ListACLsByTarget(ctx context.Context, target string) (ACLs, error) {
	columns := strings.Join(aclsFields, ",")
	query := fmt.Sprintf(`SELECT %s FROM acls WHERE target = $1`, columns)
	return s.queryACLs(ctx, query, target)
}
// ListACLsByTargets returns every ACL whose target is one of targets. An
// empty input yields an empty (non-nil) list without querying.
func (s pgStorage) ListACLsByTargets(ctx context.Context, targets []string) (ACLs, error) {
	if n := len(targets); n > 0 {
		columns := strings.Join(aclsFields, ",")
		placeholders := strings.Join(common.DollarForEach(n), ",")
		query := fmt.Sprintf(`SELECT %s FROM acls WHERE target IN (%s)`, columns, placeholders)
		return s.queryACLs(ctx, query, common.StringsToInterfaces(targets)...)
	}
	// An empty SQL IN () list is invalid, so short-circuit instead.
	return []ACL{}, nil
}
// ListACLsByScopesAndWildTargets returns the ACLs owned by any of actorRowIDs
// whose target is one of targets, or which are flagged wild.
//
// With no scopes nothing can match, so an empty (non-nil) list is returned
// without querying. With no targets, only the wild ACLs of those scopes can
// match. In that case the target clause is omitted rather than emitting an
// invalid empty IN () list. Previously this case returned nothing, which
// dropped the wild matches.
func (s pgStorage) ListACLsByScopesAndWildTargets(ctx context.Context, actorRowIDs []uuid.UUID, targets []string) (ACLs, error) {
	if len(actorRowIDs) == 0 {
		return []ACL{}, nil
	}

	columns := strings.Join(aclsFields, ",")
	scopePlaceholders := strings.Join(common.DollarForEach(len(actorRowIDs)), ",")

	// Arguments are positional: scopes first, then targets, then the wild flag.
	args := make([]interface{}, 0, len(actorRowIDs)+len(targets)+1)
	for _, actorRowID := range actorRowIDs {
		args = append(args, actorRowID)
	}

	var query string
	if len(targets) == 0 {
		query = fmt.Sprintf(`SELECT %s FROM acls WHERE scope IN (%s) AND wild = $%d`, columns, scopePlaceholders, len(actorRowIDs)+1)
	} else {
		targetPlaceholders := strings.Join(common.DollarForEachFrom(len(actorRowIDs)+1, len(actorRowIDs)+len(targets)), ",")
		query = fmt.Sprintf(`SELECT %s FROM acls WHERE scope IN (%s) AND (target IN (%s) OR wild = $%d)`, columns, scopePlaceholders, targetPlaceholders, len(actorRowIDs)+len(targets)+1)
		for _, target := range targets {
			args = append(args, target)
		}
	}
	args = append(args, true)

	return s.queryACLs(ctx, query, args...)
}
// queryACLs runs query with args and scans each row into an ACL. The query
// must select exactly the aclsFields columns, in that order.
//
// A query failure, or an error surfaced while iterating rows, is wrapped with
// errors.NewObjectSelectFailedError. A row that fails to scan is wrapped with
// errors.NewInvalidObjectError. On success the result is non-nil, possibly
// empty.
func (s pgStorage) queryACLs(ctx context.Context, query string, args ...interface{}) (ACLs, error) {
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, errors.NewObjectSelectFailedError(err)
	}
	defer rows.Close()

	results := make([]ACL, 0)
	for rows.Next() {
		var acl ACL
		if err := rows.Scan(&acl.ID, &acl.CreatedAt, &acl.UpdatedAt, &acl.Scope, &acl.Target, &acl.Action); err != nil {
			return nil, errors.NewInvalidObjectError(err)
		}
		results = append(results, acl)
	}
	// rows.Next returns false both when rows are exhausted and when
	// iteration fails. Without this check a mid-stream error would be
	// reported as a successful, truncated result.
	if err := rows.Err(); err != nil {
		return nil, errors.NewObjectSelectFailedError(err)
	}
	return results, nil
}