Implemented ACLs, added db checks to varchar columns, and cleaned up code.

This commit is contained in:
Nick Gerakines 2020-04-06 10:20:00 -04:00
parent c42b82f2a8
commit b3f78820d2
No known key found for this signature in database
GPG Key ID: 33D43D854F96B2E4
24 changed files with 890 additions and 66 deletions

3
.gitignore vendored
View File

@ -52,4 +52,5 @@ dist/
# local development environment files
common.env
tavern-town.env
tavern-town.env
*.pem

View File

@ -15,11 +15,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
Groups are special actors that share the same namespace as users, have their own profiles ([#68](https://gitlab.com/ngerakines/tavern/-/issues/68)), and can be browsed via a directory ([#67](https://gitlab.com/ngerakines/tavern/-/issues/67)).
* Group invitations allow group members (manager level) to invite others to the group.
* Allow/Deny matching for inbox activity with server and local scope. This allows the server to implement blank allow/deny matching, as well as user or group allow/deny matching.
* Added migrations to create non-empty checks for varchar fields.
* Add init-service command to create a service actor specific to the instance.
* Added a start-up check for the service actor.
**Important**: Without the service actor record, the server will not start up.
### Changed
* Changed link in user agent.
* Added metrics collector storage driver to expose query metrics.
* Renamed the init command to init-admin to make room for initializing the service actor.
* Updated documentation reflecting the addition of the service actor.
### Fixed

View File

@ -12,12 +12,13 @@ Milestones: https://gitlab.com/ngerakines/tavern/-/milestones
The quickest way to get up and running is with docker-compose.
If you are building from source, be sure to run `docker-compose build`.
If you are building from source, be sure to run `docker-compose build`. These instructions assume all of the required configuration is set.
1. docker-compose up -d db svger
2. docker-compose run web migrate
3. docker-compose run web init --admin-email=nick.gerakines@gmail.com --admin-password=password --admin-name=nick
4. docker-compose up -d
3. docker-compose run web init-service --pem /service.pem
4. docker-compose run web init-admin --admin-email=nick.gerakines@gmail.com --admin-password=password --admin-name=nick
5. docker-compose up -d
# Contributing

View File

@ -1,6 +1,7 @@
package common
import (
"fmt"
"strings"
)
@ -12,6 +13,14 @@ type AllowDenyMatcher struct {
type AllowDenyMatcherType int16
type MatcherSet struct {
Default bool
HighAllow []AllowDenyMatcher
HighDeny []AllowDenyMatcher
LowAllow []AllowDenyMatcher
LowDeny []AllowDenyMatcher
}
const (
ExactAllowDenyMatcher = iota
PrefixAllowDenyMatcher
@ -70,3 +79,30 @@ func (m AllowDenyMatcher) Match(test string) bool {
return false
}
}
func (s MatcherSet) Allow(tests ...string) bool {
for _, test := range tests {
fmt.Println("testing", test)
for _, matcher := range s.HighAllow {
if matcher.Match(test) {
return true
}
}
for _, matcher := range s.HighDeny {
if matcher.Match(test) {
return false
}
}
for _, matcher := range s.LowAllow {
if matcher.Match(test) {
return true
}
}
for _, matcher := range s.LowDeny {
if matcher.Match(test) {
return false
}
}
}
return s.Default
}

View File

@ -58,6 +58,10 @@ func DollarForEach(max int) []string {
return MapStrings(StringIntRange(1, max), Dollar)
}
func DollarForEachFrom(min, max int) []string {
return MapStrings(StringIntRange(min, max), Dollar)
}
func StringsToInterfaces(input []string) []interface{} {
results := make([]interface{}, len(input))
for i, value := range input {
@ -141,21 +145,43 @@ func NewUniqueUUIDs() *UniqueUUIDs {
func (u *UniqueStrings) Add(values ...string) {
for _, value := range values {
if _, ok := u.Tracking[value]; ok {
return
continue
}
u.Tracking[value] = struct{}{}
u.Values = append(u.Values, value)
}
}
func (u *UniqueUUIDs) Add(values ...uuid.UUID) {
func (u *UniqueStrings) With(values ...string) *UniqueStrings {
for _, value := range values {
if _, ok := u.Tracking[value]; ok {
return
continue
}
u.Tracking[value] = struct{}{}
u.Values = append(u.Values, value)
}
return u
}
func (u *UniqueUUIDs) Add(values ...uuid.UUID) {
for _, value := range values {
if _, ok := u.Tracking[value]; ok {
continue
}
u.Tracking[value] = struct{}{}
u.Values = append(u.Values, value)
}
}
func (u *UniqueUUIDs) With(values ...uuid.UUID) *UniqueUUIDs {
for _, value := range values {
if _, ok := u.Tracking[value]; ok {
continue
}
u.Tracking[value] = struct{}{}
u.Values = append(u.Values, value)
}
return u
}
func StringAppendIfMissing(slice []string, i string) []string {

View File

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT.
// This file was generated by herr at 2020-04-05 12:40:16.825087707 -0400 EDT m=+0.009599340
// This file was generated by herr at 2020-04-06 09:40:11.129122444 -0400 EDT m=+0.009597985
package errors
import (
@ -905,6 +905,11 @@ type InvalidUserIDError struct {
Stack *stack
}
type AccessDeniedError struct {
Err error
Stack *stack
}
var _ CodedError = NotFoundError{}
var _ CodedError = EncryptFailedError{}
var _ CodedError = DecryptFailedError{}
@ -1082,6 +1087,7 @@ var _ CodedError = AuthenticationRequiredError{}
var _ CodedError = TranslatorNotFoundError{}
var _ CodedError = TranslationNotFoundError{}
var _ CodedError = InvalidUserIDError{}
var _ CodedError = AccessDeniedError{}
// ErrorFromCode returns the CodedError for a serialized coded error string.
func ErrorFromCode(code string) (bool, error) {
@ -1440,6 +1446,8 @@ func ErrorFromCode(code string) (bool, error) {
return true, TranslationNotFoundError{}
case "TAVWEBAAAAAAAI":
return true, InvalidUserIDError{}
case "TAVWEBAAAAAAAJ":
return true, AccessDeniedError{}
default:
return false, fmt.Errorf("unknown error code: %s", code)
}
@ -3746,6 +3754,19 @@ func WrapInvalidUserIDError(err error) error {
return NewInvalidUserIDError(err)
}
func NewAccessDeniedError(err error) error {
return AccessDeniedError{ Err: err, Stack: callers() }
}
func WrapAccessDeniedError(err error) error {
if err == nil {
return nil
}
return NewAccessDeniedError(err)
}
func (e NotFoundError) Error() string {
return "TAVAAAAAAAB"
}
@ -12242,6 +12263,54 @@ func (e InvalidUserIDError) Format(s fmt.State, verb rune) {
}
}
func (e AccessDeniedError) Error() string {
return "TAVWEBAAAAAAAJ"
}
func (e AccessDeniedError) Unwrap() error {
return e.Err
}
func (e AccessDeniedError) Is(target error) bool {
t, ok := target.(AccessDeniedError)
if !ok {
return false
}
return t.Prefix() == "TAVWEB" && t.Code() == 9
}
func (e AccessDeniedError) Code() int {
return 9
}
func (e AccessDeniedError) Description() string {
return "Access denied."
}
func (e AccessDeniedError) Prefix() string {
return "TAVWEB"
}
func (e AccessDeniedError) String() string {
return "TAVWEBAAAAAAAJ Access denied."
}
func (e AccessDeniedError) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v", e.Unwrap())
e.Stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, "TAVWEBAAAAAAAJ")
case 'q':
fmt.Fprintf(s, "%q", "TAVWEBAAAAAAAJ")
}
}
// Frame represents a program counter inside a stack frame.
// For historical reasons if Frame is interpreted as a uintptr

View File

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT.
// This file was generated by herr at 2020-04-05 12:40:16.85622214 -0400 EDT m=+0.040733718
// This file was generated by herr at 2020-04-06 09:40:11.159683339 -0400 EDT m=+0.040158828
package errors
import (
@ -8583,4 +8583,50 @@ func TestInvalidUserID (t *testing.T) {
}
}
func TestAccessDenied (t *testing.T) {
err1 := NewAccessDeniedError(nil)
{
err1, ok := err1.(AccessDeniedError)
if !ok {
t.Errorf("Assertion failed on AccessDenied: %T is not AccessDeniedError", err1)
}
if err1.Prefix() != "TAVWEB" {
t.Errorf("Assertion failed on AccessDenied: %s != TAVWEB", err1.Prefix())
}
if err1.Code() != 9 {
t.Errorf("Assertion failed on AccessDenied: %d != 9", err1.Code())
}
if err1.Description() != "Access denied." {
t.Errorf("Assertion failed on AccessDenied: %s != Access denied.", err1.Description())
}
}
errNotFound := fmt.Errorf("not found")
errThingNotFound := fmt.Errorf("thing: %w", errNotFound)
err2 := NewAccessDeniedError(errThingNotFound)
{
err2, ok := err2.(AccessDeniedError)
if !ok {
t.Errorf("Assertion failed on AccessDenied: %T is not AccessDeniedError", err2)
}
errNestErr2 := fmt.Errorf("oh snap: %w", err2)
if err2.Code() != 9 {
t.Errorf("Assertion failed on AccessDenied: %d != 9", err2.Code())
}
if !errors.Is(err2, errNotFound) {
t.Errorf("Assertion failed on AccessDenied: errNotFound not unwrapped correctly")
}
if !errors.Is(err2, errThingNotFound) {
t.Errorf("Assertion failed on AccessDenied: errThingNotFound not unwrapped correctly")
}
if !errors.Is(err2, AccessDeniedError{}) {
t.Errorf("Assertion failed on AccessDenied: AccessDeniedError{} not identified correctly")
}
if !errors.Is(errNestErr2, AccessDeniedError{}) {
t.Errorf("Assertion failed on AccessDenied: AccessDeniedError{} not identified correctly")
}
}
}

View File

@ -5,4 +5,5 @@
5,TAVWEB,AuthenticationRequired ,Authentication required
6,TAVWEB,TranslatorNotFound ,Translator not found # Used when the underlying translation system isn't found.
7,TAVWEB,TranslationNotFound ,Translation not found
8,TAVWEB,InvalidUserID ,The user id is invalid.
8,TAVWEB,InvalidUserID ,The user id is invalid.
9,TAVWEB,AccessDenied ,Access denied.
1 1 TAVWEB InvalidEmailVerification Invalid verification.
5 5 TAVWEB AuthenticationRequired Authentication required
6 6 TAVWEB TranslatorNotFound Translator not found # Used when the underlying translation system isn't found.
7 7 TAVWEB TranslationNotFound Translation not found
8 8 TAVWEB InvalidUserID The user id is invalid.
9 9 TAVWEB AccessDenied Access denied.

View File

@ -48,7 +48,8 @@ func main() {
app.Copyright = "(c) 2020 Nick Gerakines"
app.Commands = []*cli.Command{
&start.Command,
&start.InitAdminCommand,
&start.InitServiceCommand,
&web.Command,
&asset.Command,
&migrations.Command,

View File

@ -0,0 +1 @@
DROP TABLE acls;

View File

@ -0,0 +1,44 @@
create table if not exists public.acls
(
id uuid not null
constraint acls_pk primary key,
created_at timestamp with time zone not null,
updated_at timestamp with time zone not null,
scope uuid not null,
target varchar not null,
action integer not null,
wild bool default false not null,
constraint acls_scope_action_uindex
unique (scope, target)
);
UPDATE actor_keys AS ak
SET key_id = a.payload -> 'publicKey' ->> 'id'
FROM actors AS a
WHERE ak.actor_id = a.id
AND ak.key_id = '';
ALTER TABLE acls
ADD CONSTRAINT acls_checks CHECK (target <> '' AND action >= 0 AND action <= 2);
ALTER TABLE actor_keys
ADD CONSTRAINT actor_keys_checks CHECK (key_id <> '' AND pem <> '');
ALTER TABLE actor_aliases
ADD CONSTRAINT actor_aliases_checks CHECK (alias <> '');
ALTER TABLE actors
ADD CONSTRAINT actors_checks CHECK (actor_id <> '');
ALTER TABLE groups
ADD CONSTRAINT groups_checks CHECK (name <> '' AND public_key <> '' AND private_key <> '' AND display_name <> '');
ALTER TABLE objects
ADD CONSTRAINT objects_checks CHECK (object_id <> '');
ALTER TABLE object_events
ADD CONSTRAINT object_events_checks CHECK (activity_id <> '');
ALTER TABLE users
ADD CONSTRAINT users_checks CHECK (email <> '' AND password <> '' AND name <> '' AND public_key <> '' AND
private_key <> '' AND display_name <> '');

View File

@ -21,8 +21,8 @@ import (
"github.com/ngerakines/tavern/storage"
)
var Command = cli.Command{
Name: "init",
var InitAdminCommand = cli.Command{
Name: "init-admin",
Usage: "Initialize the server",
Flags: []cli.Flag{
&config.EnvironmentFlag,
@ -59,10 +59,10 @@ var Command = cli.Command{
Usage: "The 'about me' of the admin user",
},
},
Action: serverCommandAction,
Action: initAdminCommandAction,
}
func serverCommandAction(cliCtx *cli.Context) error {
func initAdminCommandAction(cliCtx *cli.Context) error {
logger, err := config.Logger(cliCtx)
if err != nil {
return err

157
start/service.go Normal file
View File

@ -0,0 +1,157 @@
package start
import (
"bytes"
"context"
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"runtime"
"strings"
"github.com/getsentry/sentry-go"
"github.com/urfave/cli/v2"
"go.uber.org/zap"
"github.com/ngerakines/tavern/config"
"github.com/ngerakines/tavern/errors"
"github.com/ngerakines/tavern/g"
"github.com/ngerakines/tavern/storage"
)
var InitServiceCommand = cli.Command{
Name: "init-service",
Usage: "Initialize the service",
Flags: []cli.Flag{
&config.DomainFlag,
&config.DatabaseFlag,
&cli.StringFlag{
Name: "pem",
Usage: "The path to the private key PEM file.",
Required: true,
},
},
Action: initServiceCommandAction,
}
func initServiceCommandAction(cliCtx *cli.Context) error {
logger, err := config.Logger(cliCtx)
if err != nil {
return err
}
domain := cliCtx.String("domain")
siteBase := fmt.Sprintf("https://%s", domain)
logger.Info("Starting",
zap.String("command", cliCtx.Command.Name),
zap.String("GOOS", runtime.GOOS),
zap.String("site", siteBase),
zap.String("env", cliCtx.String("environment")))
sentryConfig, err := config.NewSentryConfig(cliCtx)
if err != nil {
return err
}
if sentryConfig.Enabled {
err = sentry.Init(sentry.ClientOptions{
Dsn: sentryConfig.Key,
Environment: cliCtx.String("environment"),
Release: fmt.Sprintf("%s-%s", g.Release, g.GitCommit),
})
if err != nil {
return err
}
sentry.ConfigureScope(func(scope *sentry.Scope) {
scope.SetTags(map[string]string{"container": "server"})
})
defer sentry.Recover()
}
db, dbClose, err := config.DB(cliCtx, logger)
if err != nil {
return err
}
defer dbClose()
actorID := fmt.Sprintf("https://%s/server", domain)
actor := storage.EmptyPayload()
actor["@context"] = "https://www.w3.org/ns/activitystreams"
actor["id"] = actorID
actor["type"] = "Service"
actor["inbox"] = fmt.Sprintf("%s/inbox", actorID)
actor["outbox"] = fmt.Sprintf("%s/outbox", actorID)
actor["name"] = domain
actor["summary"] = domain
actor["preferredUsername"] = domain
actor["url"] = fmt.Sprintf("https://%s/", domain)
ctx := context.Background()
pemBytes, err := ioutil.ReadFile(cliCtx.String("pem"))
if err != nil {
return err
}
block, _ := pem.Decode(pemBytes)
keyID, ok := block.Headers["id"]
if !ok {
return fmt.Errorf("missing header: id")
}
if len(keyID) == 0 || !strings.HasPrefix(keyID, fmt.Sprintf("https://%s/", domain)) {
return fmt.Errorf("invalid key id: %s", keyID)
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return fmt.Errorf("unable to parse private key bytes: %w", err)
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(privateKey.Public())
if err != nil {
return fmt.Errorf("unable to marshal public key: %w", err)
}
var publicKeyBuffer bytes.Buffer
if err = pem.Encode(&publicKeyBuffer, &pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
}); err != nil {
return err
}
publicKey := string(publicKeyBuffer.Bytes())
txErr := storage.TransactionalStorage(ctx, storage.DefaultStorage(storage.LoggingSQLDriver{Driver: db, Logger: logger}), func(s storage.Storage) error {
keyPayload := storage.EmptyPayload()
keyPayload["id"] = keyID
keyPayload["owner"] = actorID
keyPayload["publicKeyPem"] = publicKey
actor["publicKey"] = keyPayload
err = s.CreateActor(ctx, actorID, actor)
if err != nil {
return err
}
actorRowID, err := s.ActorRowIDForActorID(ctx, actorID)
if err != nil {
return err
}
err = s.RecordActorKey(ctx, actorRowID, keyID, publicKey)
if err != nil {
return err
}
if err = s.RecordActorAlias(ctx, actorRowID, actorID, storage.ActorAliasSelf); err != nil {
return err
}
return nil
})
if txErr != nil {
logger.Error("error creating service", zap.Error(err), zap.Strings("error_chain", errors.ErrorChain(err)))
}
return txErr
}

208
storage/acl.go Normal file
View File

@ -0,0 +1,208 @@
package storage
import (
"context"
"fmt"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/errors"
)
type ACLStorage interface {
RecordACL(ctx context.Context, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error)
RecordACLAll(ctx context.Context, rowID uuid.UUID, createdAt, updatedAt time.Time, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error)
ListACLsByScope(ctx context.Context, actorRowID uuid.UUID) (ACLs, error)
ListACLsByScopes(ctx context.Context, actorRowIDs []uuid.UUID) (ACLs, error)
ListACLsByTarget(ctx context.Context, target string) (ACLs, error)
ListACLsByTargets(ctx context.Context, targets []string) (ACLs, error)
ListACLsByScopesAndWildTargets(ctx context.Context, actorRowIDs []uuid.UUID, targets []string) (ACLs, error)
}
type ACL struct {
ID uuid.UUID
CreatedAt time.Time
UpdatedAt time.Time
Scope uuid.UUID
Target string
Action ACLAction
Priority int
}
type aclSorter struct {
acls []ACL
by func(a1, a2 *ACL) bool
}
func (s *aclSorter) Len() int {
return len(s.acls)
}
func (s *aclSorter) Swap(i, j int) {
s.acls[i], s.acls[j] = s.acls[j], s.acls[i]
}
func (s *aclSorter) Less(i, j int) bool {
return s.by(&s.acls[i], &s.acls[j])
}
type ACLs []ACL
func (l ACLs) ToMatchSet(highScopes *common.UniqueUUIDs, defaultAllow bool) common.MatcherSet {
highAllow := make([]common.AllowDenyMatcher, 0)
highDeny := make([]common.AllowDenyMatcher, 0)
lowAllow := make([]common.AllowDenyMatcher, 0)
lowDeny := make([]common.AllowDenyMatcher, 0)
fmt.Println(l)
for _, acl := range l {
matcher := common.NewAllowDenyMatcher(acl.Target)
_, isHigh := highScopes.Tracking[acl.Scope]
fmt.Println("placing matcher", isHigh, acl.Action)
if isHigh && acl.Action == AllowACLAction {
highAllow = append(highAllow, matcher)
} else if isHigh && acl.Action == DenyACLAction {
highDeny = append(highDeny, matcher)
} else if !isHigh && acl.Action == AllowACLAction {
lowAllow = append(lowAllow, matcher)
} else if !isHigh && acl.Action == DenyACLAction {
lowDeny = append(lowDeny, matcher)
}
}
return common.MatcherSet{
Default: defaultAllow,
HighAllow: highAllow,
HighDeny: highDeny,
LowAllow: lowAllow,
LowDeny: lowDeny,
}
}
//
// func (l ACLs) WithPriority(serverScope uuid.UUID) ACLs {
// results := make(ACLs, len(l))
// for i, val := range l {
// priority := 1
// if val.Scope == serverScope {
// priority++
// }
// if val.Action == AllowACLAction {
// priority++
// }
// results[i] = ACL{
// ID: val.ID,
// CreatedAt: val.CreatedAt,
// UpdatedAt: val.UpdatedAt,
// Scope: val.Scope,
// Target: val.Target,
// Action: val.Action,
// Priority: priority,
// }
// }
//
// ps := &aclSorter{
// acls: results,
// by: func(a1, a2 *ACL) bool {
// return a1.Priority < a2.Priority
// },
// }
// sort.Sort(ps)
//
// for _, acl := range l {
// fmt.Println(acl.Scope, acl.Target, acl.Action)
// }
//
// return results
// }
type ACLAction int8
const (
AllowACLAction ACLAction = 0
DenyACLAction ACLAction = 1
)
var aclsFields = []string{
"id",
"created_at",
"updated_at",
"scope",
"target",
"action",
}
func (s pgStorage) RecordACL(ctx context.Context, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) {
rowID := NewV4()
now := s.now()
return s.RecordACLAll(ctx, rowID, now, now, actorRowID, target, action, wild)
}
func (s pgStorage) RecordACLAll(ctx context.Context, rowID uuid.UUID, createdAt, updatedAt time.Time, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) {
query := `INSERT INTO acls (id, created_at, updated_at, scope, target, action, wild) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT ON CONSTRAINT acls_scope_action_uindex DO UPDATE SET updated_at = $3, action = $6 RETURNING id`
var id uuid.UUID
err := s.db.QueryRowContext(ctx, query, rowID, createdAt, updatedAt, actorRowID, target, action, wild).Scan(&id)
return id, errors.WrapGroupUpsertFailedError(err)
}
func (s pgStorage) ListACLsByScope(ctx context.Context, actorRowID uuid.UUID) (ACLs, error) {
query := fmt.Sprintf(`SELECT %s FROM acls WHERE scope = $1`, strings.Join(aclsFields, ","))
return s.queryACLs(ctx, query, actorRowID)
}
func (s pgStorage) ListACLsByScopes(ctx context.Context, actorRowIDs []uuid.UUID) (ACLs, error) {
if len(actorRowIDs) == 0 {
return []ACL{}, nil
}
query := fmt.Sprintf("SELECT %s FROM acls WHERE scope IN (%s)", strings.Join(aclsFields, ","), strings.Join(common.DollarForEach(len(actorRowIDs)), ","))
return s.queryACLs(ctx, query, common.UUIDsToInterfaces(actorRowIDs)...)
}
func (s pgStorage) ListACLsByTarget(ctx context.Context, target string) (ACLs, error) {
query := fmt.Sprintf(`SELECT %s FROM acls WHERE target = $1`, strings.Join(aclsFields, ","))
return s.queryACLs(ctx, query, target)
}
func (s pgStorage) ListACLsByTargets(ctx context.Context, targets []string) (ACLs, error) {
if len(targets) == 0 {
return []ACL{}, nil
}
query := fmt.Sprintf(`SELECT %s FROM acls WHERE target IN (%s)`, strings.Join(aclsFields, ","), strings.Join(common.DollarForEach(len(targets)), ","))
return s.queryACLs(ctx, query, common.StringsToInterfaces(targets)...)
}
func (s pgStorage) ListACLsByScopesAndWildTargets(ctx context.Context, actorRowIDs []uuid.UUID, targets []string) (ACLs, error) {
if len(actorRowIDs) == 0 || len(targets) == 0 {
return []ACL{}, nil
}
scopePlaceholders := strings.Join(common.DollarForEach(len(actorRowIDs)), ",")
targetPlaceholders := strings.Join(common.DollarForEachFrom(len(actorRowIDs)+1, len(actorRowIDs)+len(targets)), ",")
query := fmt.Sprintf(`SELECT %s FROM acls WHERE scope IN (%s) AND (target IN (%s) OR wild = $%d)`, strings.Join(aclsFields, ","), scopePlaceholders, targetPlaceholders, len(actorRowIDs)+len(targets)+1)
args := make([]interface{}, 0)
for _, actorRowID := range actorRowIDs {
args = append(args, actorRowID)
}
for _, target := range targets {
args = append(args, target)
}
args = append(args, true)
return s.queryACLs(ctx, query, args...)
}
func (s pgStorage) queryACLs(ctx context.Context, query string, args ...interface{}) (ACLs, error) {
results := make([]ACL, 0)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.NewObjectSelectFailedError(err)
}
defer rows.Close()
for rows.Next() {
var acl ACL
if err := rows.Scan(&acl.ID, &acl.CreatedAt, &acl.UpdatedAt, &acl.Scope, &acl.Target, &acl.Action); err != nil {
return nil, errors.NewInvalidObjectError(err)
}
results = append(results, acl)
}
return results, nil
}

View File

@ -30,6 +30,7 @@ type ActorStorage interface {
ActorSubjects(ctx context.Context, actors []uuid.UUID) ([]ActorAlias, error)
ActorAliasSubjectExists(ctx context.Context, alias string) (bool, error)
FilterGroupsByActorID(ctx context.Context, actorIDs []string) ([]string, error)
ActorIDForActorRowID(ctx context.Context, actorRowID uuid.UUID) (string, error)
UpdateActorPayload(ctx context.Context, actorRowID uuid.UUID, payload Payload) error
}
@ -260,6 +261,18 @@ func (s pgStorage) ActorRowIDForActorID(ctx context.Context, actorID string) (uu
return actorRowID, nil
}
func (s pgStorage) ActorIDForActorRowID(ctx context.Context, actorRowID uuid.UUID) (string, error) {
var actorID string
err := s.db.QueryRowContext(ctx, `SELECT actor_id FROM actors WHERE id = $1`, actorRowID).Scan(&actorID)
if err != nil {
if err == sql.ErrNoRows {
return "", errors.NewActorNotFoundError(err)
}
return "", errors.NewActorSelectFailedError(err)
}
return actorID, nil
}
func (s pgStorage) ActorSubjects(ctx context.Context, actors []uuid.UUID) ([]ActorAlias, error) {
if len(actors) == 0 {
return nil, nil

View File

@ -22,6 +22,7 @@ type Storage interface {
AssetStorage
InstanceStatsStorage
GroupStorage
ACLStorage
GetExecutor() QueryExecute
}
@ -259,4 +260,4 @@ func CountMap(counts []Count) map[string]int {
results[c.Key] = c.Count
}
return results
}
}

View File

@ -9,8 +9,8 @@
<div class="form-check">
<input class="form-check-input" type="checkbox" value="checked"
id="userReplyCollectionUpdates" name="reply_collection_updates"
{{ if .user_config.reply_collection_updates }}checked="checked"{{ end }}
{{ if eq .fed_config.reply_collection_updates false }}disabled{{ end }}
{{ if .user_config.reply_collection_updates }} checked="checked"{{ end }}
{{ if eq .fed_config.reply_collection_updates false }} disabled{{ end }}
>
<label class="form-check-label" for="userReplyCollectionUpdates">
Enable Reply Collection Updates
@ -67,4 +67,63 @@
</ul>
</div>
</div>
<div class="row pt-3">
<div class="col">
<h3>ACLs</h3>
<form method="POST" action="{{ url "" }}configure/acl">
<div class="form-group">
<label for="testACLTarget">Target</label>
<input type="text" class="form-control" id="testACLTarget" name="target"
placeholder="https://lookatme.ceo/users/richy.rich" required>
<small id="createACLTargetHelp" class="form-text text-muted">
Enter an actor URL or domain.
</small>
</div>
<div class="form-group">
<label for="createACLAction">Action</label>
<select id="createACLAction" class="form-control" name="action" required>
<option value="1" selected>Deny</option>
<option value="0">Allow</option>
</select>
</div>
<input type="submit" class="btn btn-dark" name="submit" value="Submit"/>
</form>
</div>
</div>
<div class="row pt-3">
<div class="col">
<table class="table table-borderless table-hover">
<thead class="table-dark">
<tr>
<th>Scope</th>
<th>Target</th>
<th>Action</th>
</tr>
</thead>
<tbody>
{{ range $acl := .acls }}
<tr>
<td>{{ $acl.Scope }}</td>
<td>{{ $acl.Target }}</td>
<td>{{ if eq $acl.Action 0 }}Allow{{ else }}Deny{{ end }}</td>
</tr>
{{ end }}
</tbody>
</table>
</div>
</div>
<div class="row pt-3">
<div class="col">
<h3>ACL Test</h3>
<form method="POST" action="{{ url "" }}configure/acl">
<input type="hidden" name="validate" value="validate"/>
<div class="form-group">
<label for="testACLTarget">Target</label>
<input type="text" class="form-control" id="testACLTarget" name="target"
placeholder="https://lookatme.ceo/users/richy.rich" required>
</div>
<input type="submit" class="btn btn-dark" name="submit" value="Test"/>
</form>
</div>
</div>
{{end}}

View File

@ -883,5 +883,10 @@
"locale": "en",
"key": "TAVWEBAAAAAAAI",
"trans": "The user id is invalid."
},
{
"locale": "en",
"key": "TAVWEBAAAAAAAJ",
"trans": "Access denied."
}
]

View File

@ -151,6 +151,11 @@ func serverCommandAction(cliCtx *cli.Context) error {
s := storage.DefaultStorage(storage.WrapMetricDriver(fact, "tavern", "web", storage.LoggingSQLDriver{Driver: db, Logger: logger}))
serverActorRowID, err := s.ActorRowIDForActorID(context.Background(), fmt.Sprintf("https://%s/server", domain))
if err != nil {
return err
}
utrans, err := config.Trans(cliCtx)
if err != nil {
return err
@ -269,22 +274,23 @@ func serverCommandAction(cliCtx *cli.Context) error {
}
h := handler{
storage: s,
logger: logger,
domain: domain,
sentryConfig: sentryConfig,
fedConfig: fedConfig,
groupConfig: groupConfig,
webFingerQueue: webFingerQueue,
crawlQueue: crawlQueue,
assetQueue: assetQueue,
adminUser: cliCtx.String("admin-name"),
url: tmplUrlGen(siteBase),
svgConverter: svgConv,
assetStorage: assetStorage,
httpClient: httpClient,
metricFactory: promauto.With(registry),
publisherClient: newPublisherClient(logger, httpClient, publisherConfig),
storage: s,
logger: logger,
domain: domain,
sentryConfig: sentryConfig,
fedConfig: fedConfig,
groupConfig: groupConfig,
webFingerQueue: webFingerQueue,
crawlQueue: crawlQueue,
assetQueue: assetQueue,
adminUser: cliCtx.String("admin-name"),
url: tmplUrlGen(siteBase),
svgConverter: svgConv,
assetStorage: assetStorage,
httpClient: httpClient,
metricFactory: promauto.With(registry),
publisherClient: newPublisherClient(logger, httpClient, publisherConfig),
serverActorRowID: serverActorRowID,
}
configI18nMiddleware(sentryConfig, logger, utrans, domain, r)
@ -395,6 +401,8 @@ func serverCommandAction(cliCtx *cli.Context) error {
authenticated.GET("/configure", h.configure)
authenticated.POST("/configure/user", h.saveUserSettings)
authenticated.POST("/configure/acl", h.createACL)
authenticated.POST("/configure/acl-validate", h.validateACL)
authenticated.GET("/notifications", h.notifications)

View File

@ -8,6 +8,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.uber.org/zap"
@ -37,7 +38,8 @@ type handler struct {
httpClient common.HTTPClient
metricFactory promauto.Factory
publisherClient *publisherClient
publisherClient *publisherClient
serverActorRowID uuid.UUID
}
func (h handler) hardFail(ctx *gin.Context, err error, fields ...zap.Field) {

View File

@ -1,9 +1,15 @@
package web
import (
"context"
"strings"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"go.uber.org/zap"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/errors"
"github.com/ngerakines/tavern/storage"
)
@ -51,3 +57,28 @@ func (h handler) loggedInAPI(c *gin.Context, requireUser bool) (*storage.User, s
}
return user, session, true
}
func allow(ctx context.Context, logger *zap.Logger, s storage.Storage, serverScope, localScope uuid.UUID, target string) bool {
var matchSet common.MatcherSet
domain := strings.TrimPrefix(target, "https://")
if i := strings.IndexRune(domain, '/'); i > -1 {
domain = domain[:i]
}
targets := common.NewUniqueStrings().With(target, domain)
txErr := storage.TransactionalStorage(ctx, s, func(tx storage.Storage) error {
acls, err := tx.ListACLsByScopesAndWildTargets(ctx, []uuid.UUID{serverScope, localScope}, targets.Values)
if err != nil {
return err
}
matchSet = acls.ToMatchSet(common.NewUniqueUUIDs().With(serverScope), true)
return nil
})
if txErr != nil {
logger.Warn("unable to verify allow deny", zap.Error(txErr), zap.String("target", target))
return false
}
return matchSet.Allow(domain, target)
}

View File

@ -1,12 +1,17 @@
package web
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"go.uber.org/zap"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/errors"
"github.com/ngerakines/tavern/storage"
)
@ -52,6 +57,19 @@ func (h handler) configure(c *gin.Context) {
"auto_accept_followers": user.AcceptFollowers,
}
serverActorRowID, err := h.storage.ActorRowIDForActorID(ctx, fmt.Sprintf("https://%s/server", h.domain))
if err != nil {
h.hardFail(c, err)
return
}
acls, err := h.storage.ListACLsByScopes(ctx, []uuid.UUID{user.ActorID, serverActorRowID})
if err != nil {
h.hardFail(c, err)
return
}
data["acls"] = acls
if err = session.Save(); err != nil {
h.hardFail(c, err)
return
@ -74,12 +92,12 @@ func (h handler) saveUserSettings(c *gin.Context) {
zap.String("reply_collection_updates", c.PostForm("reply_collection_updates")),
)
err = storage.TransactionalStorage(ctx, h.storage, func(storage storage.Storage) error {
txErr := storage.UpdateUserAutoAcceptFollowers(ctx, user.ID, c.PostForm("auto_accept_followers") == "checked")
err = storage.TransactionalStorage(ctx, h.storage, func(tx storage.Storage) error {
txErr := tx.UpdateUserAutoAcceptFollowers(ctx, user.ID, c.PostForm("auto_accept_followers") == "checked")
if txErr != nil {
return txErr
}
return storage.UpdateUserReplyCollectionUpdates(ctx, user.ID, c.PostForm("reply_collection_updates") == "checked")
return tx.UpdateUserReplyCollectionUpdates(ctx, user.ID, c.PostForm("reply_collection_updates") == "checked")
})
if err != nil {
h.hardFail(c, err)
@ -88,3 +106,73 @@ func (h handler) saveUserSettings(c *gin.Context) {
c.Redirect(http.StatusFound, h.url("configure"))
}
func (h handler) createACL(c *gin.Context) {
if c.PostForm("validate") == "validate" {
h.validateACL(c)
return
}
session := sessions.Default(c)
ctx := c.Request.Context()
user, err := h.storage.GetUserBySession(ctx, session)
if err != nil {
h.hardFail(c, err)
return
}
wild := strings.ContainsRune(c.PostForm("target"), '*')
txErr := storage.TransactionalStorage(ctx, h.storage, func(tx storage.Storage) error {
action, err := strconv.Atoi(c.PostForm("action"))
if err != nil {
return err
}
_, err = tx.RecordACL(ctx, user.ActorID, c.PostForm("target"), storage.ACLAction(action), wild)
return err
})
if txErr != nil {
h.hardFail(c, txErr)
return
}
h.flashSuccessOrFail(c, h.url("configure"), "ACL created")
}
func (h handler) validateACL(c *gin.Context) {
session := sessions.Default(c)
ctx := c.Request.Context()
user, err := h.storage.GetUserBySession(ctx, session)
if err != nil {
h.hardFail(c, err)
return
}
target := c.PostForm("target")
domain := strings.TrimPrefix(target, "https://")
if i := strings.IndexRune(domain, '/'); i > -1 {
domain = domain[:i]
}
var matchSet common.MatcherSet
txErr := storage.TransactionalStorage(ctx, h.storage, func(tx storage.Storage) error {
acls, err := tx.ListACLsByScopesAndWildTargets(ctx, []uuid.UUID{user.ActorID, h.serverActorRowID}, []string{domain, target})
if err != nil {
return err
}
matchSet = acls.ToMatchSet(common.NewUniqueUUIDs().With(h.serverActorRowID), true)
return nil
})
if txErr != nil {
h.hardFail(c, txErr)
return
}
if matchSet.Allow(domain, target) {
h.flashSuccessOrFail(c, h.url("configure"), "ACL allow")
return
}
h.flashErrorOrFail(c, h.url("configure"), fmt.Errorf("ACL deny"))
}

View File

@ -63,11 +63,11 @@ func (h handler) groupActorInbox(c *gin.Context) {
payloadType, _ := storage.JSONString(payload, "type")
actor, _ := storage.JSONString(payload, "actor")
if len(actor) > 0 {
err = h.webFingerQueue.Add(actor)
if err != nil {
h.logger.Error("unable to add actor to web finger queue", zap.String("actor", actor))
}
if len(actor) > 0 && !allow(c.Request.Context(), h.logger, h.storage, h.serverActorRowID, group.ActorID, actor) {
h.logger.Debug("actor denied", zap.String("group", name), zap.String("actor", actor))
c.Status(http.StatusOK)
return
}
switch payloadType {
@ -88,7 +88,7 @@ func (h handler) groupActorInbox(c *gin.Context) {
}
func (h handler) groupActorInboxInvite(c *gin.Context, group storage.Group, payload storage.Payload) {
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, group.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -252,7 +252,7 @@ func (h handler) groupActorInboxFollowActor(c *gin.Context, group storage.Group,
return
}
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, group.ActorID); err != nil {
h.logger.Debug("signature verification failed", zap.Error(err))
h.unauthorizedJSON(c, err)
return
@ -400,7 +400,7 @@ func (h handler) groupActorInboxUndoFollowActor(c *gin.Context, group storage.Gr
return
}
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, group.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -465,7 +465,7 @@ func (h handler) groupActorInboxUndo(c *gin.Context, group storage.Group, payloa
func (h handler) groupActorInboxCreate(c *gin.Context, group storage.Group, payload storage.Payload) {
// Because actors must follow the group, it is safe to assume that we
// have actor and actor key records for all valid incoming activities.
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, group.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -596,7 +596,7 @@ func (h handler) groupActorInboxCreate(c *gin.Context, group storage.Group, payl
}
func (h handler) groupActorInboxAnnounce(c *gin.Context, group storage.Group, payload storage.Payload) {
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, group.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}

View File

@ -67,11 +67,11 @@ func (h handler) userActorInbox(c *gin.Context) {
payloadType, _ := storage.JSONString(payload, "type")
actor, _ := storage.JSONString(payload, "actor")
if len(actor) > 0 {
err = h.webFingerQueue.Add(actor)
if err != nil {
h.logger.Error("unable to add actor to web finger queue", zap.String("actor", actor))
}
if len(actor) > 0 && !allow(c.Request.Context(), h.logger, h.storage, h.serverActorRowID, user.ActorID, actor) {
h.logger.Debug("actor denied", zap.String("user", name), zap.String("actor", actor))
c.Status(http.StatusOK)
return
}
switch payloadType {
@ -96,7 +96,7 @@ func (h handler) userActorInbox(c *gin.Context) {
}
func (h handler) actorInboxAccept(c *gin.Context, user *storage.User, payload storage.Payload) {
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, user.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -132,7 +132,7 @@ func (h handler) actorInboxAccept(c *gin.Context, user *storage.User, payload st
}
func (h handler) actorInboxReject(c *gin.Context, user *storage.User, payload storage.Payload) {
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, user.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -214,7 +214,7 @@ func (h handler) actorInboxFollowActor(c *gin.Context, user *storage.User, paylo
return
}
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, user.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -297,7 +297,7 @@ func (h handler) actorInboxFollowNote(c *gin.Context, user *storage.User, payloa
return
}
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, user.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -412,7 +412,7 @@ func (h handler) actorInboxUndoFollowActor(c *gin.Context, user *storage.User, p
return
}
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, user.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -452,7 +452,7 @@ func (h handler) actorInboxUndoFollowNote(c *gin.Context, user *storage.User, pa
return
}
if err := h.verifySignature(c); err != nil {
if err := h.verifySignature(c, user.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -513,7 +513,7 @@ func (h handler) actorInboxCreate(c *gin.Context, user *storage.User, payload st
return
}
if err = h.verifySignature(c); err != nil {
if err = h.verifySignature(c, user.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -666,7 +666,7 @@ func (h handler) actorInboxAnnounce(c *gin.Context, user *storage.User, payload
return
}
if err = h.verifySignature(c); err != nil {
if err = h.verifySignature(c, user.ActorID); err != nil {
h.unauthorizedJSON(c, err)
return
}
@ -807,7 +807,7 @@ func (h handler) actorInboxAnnounce(c *gin.Context, user *storage.User, payload
}
func (h handler) actorInboxDelete(c *gin.Context, user *storage.User, payload storage.Payload, raw []byte) {
err := h.verifySignature(c)
err := h.verifySignature(c, user.ActorID)
if err != nil {
h.unauthorizedJSON(c, err)
return
@ -1013,17 +1013,35 @@ func (h handler) isActivityRelevant(ctx context.Context, activity storage.Payloa
return false, nil
}
func (h handler) verifySignature(c *gin.Context) error {
func (h handler) verifySignature(c *gin.Context, scope uuid.UUID) error {
// Host header isn't set for some reason.
c.Request.Header.Add("Host", c.Request.Host)
verifier, err := httpsig.NewVerifier(c.Request)
if err != nil {
return err
}
key, err := h.storage.GetKey(c.Request.Context(), verifier.KeyId())
if err != nil {
return err
ctx := c.Request.Context()
var key *storage.Key
var actorID string
txErr := storage.TransactionalStorage(ctx, h.storage, func(tx storage.Storage) error {
var err error
key, err = h.storage.GetKey(c.Request.Context(), verifier.KeyId())
if err != nil {
return err
}
actorID, err = tx.ActorIDForActorRowID(ctx, key.Actor)
if err != nil {
return err
}
return nil
})
if txErr != nil {
return txErr
}
if !allow(c.Request.Context(), h.logger, h.storage, h.serverActorRowID, scope, actorID) {
return errors.NewAccessDeniedError(nil)
}
publicKey, err := key.GetDecodedPublicKey()
if err != nil {
return err