mirror of https://gitlab.com/ngerakines/tavern.git
209 lines
6.6 KiB
Go
209 lines
6.6 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofrs/uuid"
|
|
|
|
"github.com/ngerakines/tavern/common"
|
|
"github.com/ngerakines/tavern/errors"
|
|
)
|
|
|
|
type ACLStorage interface {
|
|
RecordACL(ctx context.Context, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error)
|
|
RecordACLAll(ctx context.Context, rowID uuid.UUID, createdAt, updatedAt time.Time, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error)
|
|
ListACLsByScope(ctx context.Context, actorRowID uuid.UUID) (ACLs, error)
|
|
ListACLsByScopes(ctx context.Context, actorRowIDs []uuid.UUID) (ACLs, error)
|
|
ListACLsByTarget(ctx context.Context, target string) (ACLs, error)
|
|
ListACLsByTargets(ctx context.Context, targets []string) (ACLs, error)
|
|
ListACLsByScopesAndWildTargets(ctx context.Context, actorRowIDs []uuid.UUID, targets []string) (ACLs, error)
|
|
}
|
|
|
|
type ACL struct {
|
|
ID uuid.UUID
|
|
CreatedAt time.Time
|
|
UpdatedAt time.Time
|
|
Scope uuid.UUID
|
|
Target string
|
|
Action ACLAction
|
|
Priority int
|
|
}
|
|
|
|
type aclSorter struct {
|
|
acls []ACL
|
|
by func(a1, a2 *ACL) bool
|
|
}
|
|
|
|
func (s *aclSorter) Len() int {
|
|
return len(s.acls)
|
|
}
|
|
|
|
func (s *aclSorter) Swap(i, j int) {
|
|
s.acls[i], s.acls[j] = s.acls[j], s.acls[i]
|
|
}
|
|
|
|
func (s *aclSorter) Less(i, j int) bool {
|
|
return s.by(&s.acls[i], &s.acls[j])
|
|
}
|
|
|
|
// ACLs is a list of access-control rules, as returned by the storage list
// methods.
type ACLs []ACL
|
|
|
|
func (l ACLs) ToMatchSet(highScopes *common.UniqueUUIDs, defaultAllow bool) common.MatcherSet {
|
|
highAllow := make([]common.AllowDenyMatcher, 0)
|
|
highDeny := make([]common.AllowDenyMatcher, 0)
|
|
lowAllow := make([]common.AllowDenyMatcher, 0)
|
|
lowDeny := make([]common.AllowDenyMatcher, 0)
|
|
fmt.Println(l)
|
|
for _, acl := range l {
|
|
matcher := common.NewAllowDenyMatcher(acl.Target)
|
|
_, isHigh := highScopes.Tracking[acl.Scope]
|
|
fmt.Println("placing matcher", isHigh, acl.Action)
|
|
if isHigh && acl.Action == AllowACLAction {
|
|
highAllow = append(highAllow, matcher)
|
|
} else if isHigh && acl.Action == DenyACLAction {
|
|
highDeny = append(highDeny, matcher)
|
|
} else if !isHigh && acl.Action == AllowACLAction {
|
|
lowAllow = append(lowAllow, matcher)
|
|
} else if !isHigh && acl.Action == DenyACLAction {
|
|
lowDeny = append(lowDeny, matcher)
|
|
}
|
|
}
|
|
return common.MatcherSet{
|
|
Default: defaultAllow,
|
|
HighAllow: highAllow,
|
|
HighDeny: highDeny,
|
|
LowAllow: lowAllow,
|
|
LowDeny: lowDeny,
|
|
}
|
|
}
|
|
|
|
//
|
|
// func (l ACLs) WithPriority(serverScope uuid.UUID) ACLs {
|
|
// results := make(ACLs, len(l))
|
|
// for i, val := range l {
|
|
// priority := 1
|
|
// if val.Scope == serverScope {
|
|
// priority++
|
|
// }
|
|
// if val.Action == AllowACLAction {
|
|
// priority++
|
|
// }
|
|
// results[i] = ACL{
|
|
// ID: val.ID,
|
|
// CreatedAt: val.CreatedAt,
|
|
// UpdatedAt: val.UpdatedAt,
|
|
// Scope: val.Scope,
|
|
// Target: val.Target,
|
|
// Action: val.Action,
|
|
// Priority: priority,
|
|
// }
|
|
// }
|
|
//
|
|
// ps := &aclSorter{
|
|
// acls: results,
|
|
// by: func(a1, a2 *ACL) bool {
|
|
// return a1.Priority < a2.Priority
|
|
// },
|
|
// }
|
|
// sort.Sort(ps)
|
|
//
|
|
// for _, acl := range l {
|
|
// fmt.Println(acl.Scope, acl.Target, acl.Action)
|
|
// }
|
|
//
|
|
// return results
|
|
// }
|
|
|
|
// ACLAction says whether an access-control rule allows or denies its target.
// RecordACLAll passes the value straight through as the "action" column, so
// the numeric values below are part of the stored data.
type ACLAction int8

const (
	// AllowACLAction marks a rule that allows its target. Stored as 0.
	AllowACLAction = ACLAction(0)
	// DenyACLAction marks a rule that denies its target. Stored as 1.
	DenyACLAction = ACLAction(1)
)
|
|
|
|
// aclsFields names the acls columns selected by the list queries. The order
// must stay in step with the rows.Scan call in queryACLs, which reads the
// columns positionally.
var aclsFields = []string{"id", "created_at", "updated_at", "scope", "target", "action"}
|
|
|
|
func (s pgStorage) RecordACL(ctx context.Context, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) {
|
|
rowID := NewV4()
|
|
now := s.now()
|
|
return s.RecordACLAll(ctx, rowID, now, now, actorRowID, target, action, wild)
|
|
}
|
|
|
|
func (s pgStorage) RecordACLAll(ctx context.Context, rowID uuid.UUID, createdAt, updatedAt time.Time, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) {
|
|
query := `INSERT INTO acls (id, created_at, updated_at, scope, target, action, wild) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT ON CONSTRAINT acls_scope_action_uindex DO UPDATE SET updated_at = $3, action = $6 RETURNING id`
|
|
var id uuid.UUID
|
|
err := s.db.QueryRowContext(ctx, query, rowID, createdAt, updatedAt, actorRowID, target, action, wild).Scan(&id)
|
|
return id, errors.WrapGroupUpsertFailedError(err)
|
|
}
|
|
|
|
func (s pgStorage) ListACLsByScope(ctx context.Context, actorRowID uuid.UUID) (ACLs, error) {
|
|
query := fmt.Sprintf(`SELECT %s FROM acls WHERE scope = $1`, strings.Join(aclsFields, ","))
|
|
return s.queryACLs(ctx, query, actorRowID)
|
|
}
|
|
|
|
func (s pgStorage) ListACLsByScopes(ctx context.Context, actorRowIDs []uuid.UUID) (ACLs, error) {
|
|
if len(actorRowIDs) == 0 {
|
|
return []ACL{}, nil
|
|
}
|
|
query := fmt.Sprintf("SELECT %s FROM acls WHERE scope IN (%s)", strings.Join(aclsFields, ","), strings.Join(common.DollarForEach(len(actorRowIDs)), ","))
|
|
return s.queryACLs(ctx, query, common.UUIDsToInterfaces(actorRowIDs)...)
|
|
}
|
|
|
|
func (s pgStorage) ListACLsByTarget(ctx context.Context, target string) (ACLs, error) {
|
|
query := fmt.Sprintf(`SELECT %s FROM acls WHERE target = $1`, strings.Join(aclsFields, ","))
|
|
return s.queryACLs(ctx, query, target)
|
|
}
|
|
|
|
func (s pgStorage) ListACLsByTargets(ctx context.Context, targets []string) (ACLs, error) {
|
|
if len(targets) == 0 {
|
|
return []ACL{}, nil
|
|
}
|
|
query := fmt.Sprintf(`SELECT %s FROM acls WHERE target IN (%s)`, strings.Join(aclsFields, ","), strings.Join(common.DollarForEach(len(targets)), ","))
|
|
return s.queryACLs(ctx, query, common.StringsToInterfaces(targets)...)
|
|
}
|
|
|
|
func (s pgStorage) ListACLsByScopesAndWildTargets(ctx context.Context, actorRowIDs []uuid.UUID, targets []string) (ACLs, error) {
|
|
if len(actorRowIDs) == 0 || len(targets) == 0 {
|
|
return []ACL{}, nil
|
|
}
|
|
scopePlaceholders := strings.Join(common.DollarForEach(len(actorRowIDs)), ",")
|
|
targetPlaceholders := strings.Join(common.DollarForEachFrom(len(actorRowIDs)+1, len(actorRowIDs)+len(targets)), ",")
|
|
query := fmt.Sprintf(`SELECT %s FROM acls WHERE scope IN (%s) AND (target IN (%s) OR wild = $%d)`, strings.Join(aclsFields, ","), scopePlaceholders, targetPlaceholders, len(actorRowIDs)+len(targets)+1)
|
|
args := make([]interface{}, 0)
|
|
for _, actorRowID := range actorRowIDs {
|
|
args = append(args, actorRowID)
|
|
}
|
|
for _, target := range targets {
|
|
args = append(args, target)
|
|
}
|
|
args = append(args, true)
|
|
return s.queryACLs(ctx, query, args...)
|
|
}
|
|
|
|
func (s pgStorage) queryACLs(ctx context.Context, query string, args ...interface{}) (ACLs, error) {
|
|
results := make([]ACL, 0)
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, errors.NewObjectSelectFailedError(err)
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var acl ACL
|
|
if err := rows.Scan(&acl.ID, &acl.CreatedAt, &acl.UpdatedAt, &acl.Scope, &acl.Target, &acl.Action); err != nil {
|
|
return nil, errors.NewInvalidObjectError(err)
|
|
}
|
|
results = append(results, acl)
|
|
}
|
|
return results, nil
|
|
}
|