diff --git a/.gitignore b/.gitignore index 0728bac..518e5c0 100644 --- a/.gitignore +++ b/.gitignore @@ -52,4 +52,5 @@ dist/ # local development environment files common.env -tavern-town.env \ No newline at end of file +tavern-town.env +*.pem \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 7553bf8..fcdfb46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,11 +15,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), Groups are special actors that share the same namespace as users, have their own profiles ([#68](https://gitlab.com/ngerakines/tavern/-/issues/68)), and can be browsed via a directory ([#67](https://gitlab.com/ngerakines/tavern/-/issues/67)). * Group invitations allow group members (manager level) to invite others to the group. +* Allow/Deny matching for inbox activity with server and local scope. This allows the server to implement blank allow/deny matching, as well as user or group allow/deny matching. +* Added migrations to create non-empty checks for varchar fields. +* Add init-service command to create a service actor specific to the instance. +* Added a start-up check for the service actor. + + **Important**: Without the service actor record, the server will not start up. ### Changed * Changed link in user agent. * Added metrics collector storage driver to expose query metrics. +* Renamed the init command to init-admin to make room for initializing the service actor. +* Updated documentation reflecting the addition of the service actor. ### Fixed diff --git a/README.md b/README.md index 1ca4622..961545b 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,13 @@ Milestones: https://gitlab.com/ngerakines/tavern/-/milestones The quickest way to get up and running is with docker-compose. -If you are building from source, be sure to run `docker-compose build`. +If you are building from source, be sure to run `docker-compose build`. These instructions assume all of the required configuration is set. 1. docker-compose up -d db svger 2. docker-compose run web migrate -3. docker-compose run web init --admin-email=nick.gerakines@gmail.com --admin-password=password --admin-name=nick -4. docker-compose up -d +3. docker-compose run web init-service --pem /service.pem +4. docker-compose run web init-admin --admin-email=nick.gerakines@gmail.com --admin-password=password --admin-name=nick +5. docker-compose up -d # Contributing diff --git a/common/allow.go b/common/allow.go index 4130054..f67cc30 100644 --- a/common/allow.go +++ b/common/allow.go @@ -1,6 +1,7 @@ package common import ( + "fmt" "strings" ) @@ -12,6 +13,14 @@ type AllowDenyMatcher struct { type AllowDenyMatcherType int16 +type MatcherSet struct { + Default bool + HighAllow []AllowDenyMatcher + HighDeny []AllowDenyMatcher + LowAllow []AllowDenyMatcher + LowDeny []AllowDenyMatcher +} + const ( ExactAllowDenyMatcher = iota PrefixAllowDenyMatcher @@ -70,3 +79,30 @@ func (m AllowDenyMatcher) Match(test string) bool { return false } } + +func (s MatcherSet) Allow(tests ...string) bool { + for _, test := range tests { + fmt.Println("testing", test) + for _, matcher := range s.HighAllow { + if matcher.Match(test) { + return true + } + } + for _, matcher := range s.HighDeny { + if matcher.Match(test) { + return false + } + } + for _, matcher := range s.LowAllow { + if matcher.Match(test) { + return true + } + } + for _, matcher := range s.LowDeny { + if matcher.Match(test) { + return false + } + } + } + return s.Default +} diff --git a/common/strings.go b/common/strings.go index 553c424..f54247b 100644 --- a/common/strings.go +++ b/common/strings.go @@ -58,6 +58,10 @@ func DollarForEach(max int) []string { return MapStrings(StringIntRange(1, max), Dollar) } +func DollarForEachFrom(min, max int) []string { + return MapStrings(StringIntRange(min, max), Dollar) +} + func StringsToInterfaces(input []string) []interface{} { results := make([]interface{}, len(input)) for i, value := range input { @@ -141,21 +145,43 @@ func NewUniqueUUIDs() *UniqueUUIDs { func (u *UniqueStrings) Add(values ...string) { for _, value := range values { if _, ok := u.Tracking[value]; ok { - return + continue } u.Tracking[value] = struct{}{} u.Values = append(u.Values, value) } } -func (u *UniqueUUIDs) Add(values ...uuid.UUID) { +func (u *UniqueStrings) With(values ...string) *UniqueStrings { for _, value := range values { if _, ok := u.Tracking[value]; ok { - return + continue } u.Tracking[value] = struct{}{} u.Values = append(u.Values, value) } + return u +} + +func (u *UniqueUUIDs) Add(values ...uuid.UUID) { + for _, value := range values { + if _, ok := u.Tracking[value]; ok { + continue + } + u.Tracking[value] = struct{}{} + u.Values = append(u.Values, value) + } +} + +func (u *UniqueUUIDs) With(values ...uuid.UUID) *UniqueUUIDs { + for _, value := range values { + if _, ok := u.Tracking[value]; ok { + continue + } + u.Tracking[value] = struct{}{} + u.Values = append(u.Values, value) + } + return u } func StringAppendIfMissing(slice []string, i string) []string { diff --git a/errors/errors_generated.go b/errors/errors_generated.go index 6e56593..b80433d 100644 --- a/errors/errors_generated.go +++ b/errors/errors_generated.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// This file was generated by herr at 2020-04-05 12:40:16.825087707 -0400 EDT m=+0.009599340 +// This file was generated by herr at 2020-04-06 09:40:11.129122444 -0400 EDT m=+0.009597985 package errors import ( @@ -905,6 +905,11 @@ type InvalidUserIDError struct { Stack *stack } +type AccessDeniedError struct { + Err error + Stack *stack +} + var _ CodedError = NotFoundError{} var _ CodedError = EncryptFailedError{} var _ CodedError = DecryptFailedError{} @@ -1082,6 +1087,7 @@ var _ CodedError = AuthenticationRequiredError{} var _ CodedError = TranslatorNotFoundError{} var _ CodedError = TranslationNotFoundError{} var _ CodedError = InvalidUserIDError{} +var _ CodedError = AccessDeniedError{} // ErrorFromCode returns the CodedError for a serialized coded error string. func ErrorFromCode(code string) (bool, error) { @@ -1440,6 +1446,8 @@ func ErrorFromCode(code string) (bool, error) { return true, TranslationNotFoundError{} case "TAVWEBAAAAAAAI": return true, InvalidUserIDError{} + case "TAVWEBAAAAAAAJ": + return true, AccessDeniedError{} default: return false, fmt.Errorf("unknown error code: %s", code) } @@ -3746,6 +3754,19 @@ func WrapInvalidUserIDError(err error) error { return NewInvalidUserIDError(err) } +func NewAccessDeniedError(err error) error { + + return AccessDeniedError{ Err: err, Stack: callers() } + +} + +func WrapAccessDeniedError(err error) error { + if err == nil { + return nil + } + return NewAccessDeniedError(err) +} + func (e NotFoundError) Error() string { return "TAVAAAAAAAB" } @@ -12242,6 +12263,54 @@ func (e InvalidUserIDError) Format(s fmt.State, verb rune) { } } +func (e AccessDeniedError) Error() string { + return "TAVWEBAAAAAAAJ" +} + +func (e AccessDeniedError) Unwrap() error { + return e.Err +} + +func (e AccessDeniedError) Is(target error) bool { + t, ok := target.(AccessDeniedError) + if !ok { + return false + } + return t.Prefix() == "TAVWEB" && t.Code() == 9 +} + +func (e AccessDeniedError) Code() int { + return 9 +} + +func (e AccessDeniedError) Description() string { + return "Access denied." +} + +func (e AccessDeniedError) Prefix() string { + return "TAVWEB" +} + +func (e AccessDeniedError) String() string { + return "TAVWEBAAAAAAAJ Access denied." +} + +func (e AccessDeniedError) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v", e.Unwrap()) + e.Stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, "TAVWEBAAAAAAAJ") + case 'q': + fmt.Fprintf(s, "%q", "TAVWEBAAAAAAAJ") + } +} + // Frame represents a program counter inside a stack frame. // For historical reasons if Frame is interpreted as a uintptr diff --git a/errors/errors_generated_test.go b/errors/errors_generated_test.go index d302fac..ef5e237 100644 --- a/errors/errors_generated_test.go +++ b/errors/errors_generated_test.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// This file was generated by herr at 2020-04-05 12:40:16.85622214 -0400 EDT m=+0.040733718 +// This file was generated by herr at 2020-04-06 09:40:11.159683339 -0400 EDT m=+0.040158828 package errors import ( @@ -8583,4 +8583,50 @@ func TestInvalidUserID (t *testing.T) { } } +func TestAccessDenied (t *testing.T) { + err1 := NewAccessDeniedError(nil) + { + err1, ok := err1.(AccessDeniedError) + if !ok { + t.Errorf("Assertion failed on AccessDenied: %T is not AccessDeniedError", err1) + } + if err1.Prefix() != "TAVWEB" { + t.Errorf("Assertion failed on AccessDenied: %s != TAVWEB", err1.Prefix()) + } + if err1.Code() != 9 { + t.Errorf("Assertion failed on AccessDenied: %d != 9", err1.Code()) + } + if err1.Description() != "Access denied." { + t.Errorf("Assertion failed on AccessDenied: %s != Access denied.", err1.Description()) + } + } + + errNotFound := fmt.Errorf("not found") + errThingNotFound := fmt.Errorf("thing: %w", errNotFound) + err2 := NewAccessDeniedError(errThingNotFound) + { + err2, ok := err2.(AccessDeniedError) + if !ok { + t.Errorf("Assertion failed on AccessDenied: %T is not AccessDeniedError", err2) + } + errNestErr2 := fmt.Errorf("oh snap: %w", err2) + if err2.Code() != 9 { + t.Errorf("Assertion failed on AccessDenied: %d != 9", err2.Code()) + } + if !errors.Is(err2, errNotFound) { + t.Errorf("Assertion failed on AccessDenied: errNotFound not unwrapped correctly") + } + if !errors.Is(err2, errThingNotFound) { + t.Errorf("Assertion failed on AccessDenied: errThingNotFound not unwrapped correctly") + } + if !errors.Is(err2, AccessDeniedError{}) { + t.Errorf("Assertion failed on AccessDenied: AccessDeniedError{} not identified correctly") + } + if !errors.Is(errNestErr2, AccessDeniedError{}) { + t.Errorf("Assertion failed on AccessDenied: AccessDeniedError{} not identified correctly") + } + + } +} + diff --git a/errors/tavweb.csv b/errors/tavweb.csv index 681d632..eb2c196 100644 --- a/errors/tavweb.csv +++ b/errors/tavweb.csv @@ -5,4 +5,5 @@ 5,TAVWEB,AuthenticationRequired ,Authentication required 6,TAVWEB,TranslatorNotFound ,Translator not found # Used when the underlying translation system isn't found. 7,TAVWEB,TranslationNotFound ,Translation not found -8,TAVWEB,InvalidUserID ,The user id is invalid. \ No newline at end of file +8,TAVWEB,InvalidUserID ,The user id is invalid. +9,TAVWEB,AccessDenied ,Access denied. \ No newline at end of file diff --git a/main.go b/main.go index 0c1754d..3824ce0 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,8 @@ func main() { app.Copyright = "(c) 2020 Nick Gerakines" app.Commands = []*cli.Command{ - &start.Command, + &start.InitAdminCommand, + &start.InitServiceCommand, &web.Command, &asset.Command, &migrations.Command, diff --git a/migrations/20200405151030_acls.down.sql b/migrations/20200405151030_acls.down.sql new file mode 100644 index 0000000..68125d4 --- /dev/null +++ b/migrations/20200405151030_acls.down.sql @@ -0,0 +1 @@ +DROP TABLE acls; \ No newline at end of file diff --git a/migrations/20200405151030_acls.up.sql b/migrations/20200405151030_acls.up.sql new file mode 100644 index 0000000..981604d --- /dev/null +++ b/migrations/20200405151030_acls.up.sql @@ -0,0 +1,44 @@ +create table if not exists public.acls +( + id uuid not null + constraint acls_pk primary key, + created_at timestamp with time zone not null, + updated_at timestamp with time zone not null, + scope uuid not null, + target varchar not null, + action integer not null, + wild bool default false not null, + constraint acls_scope_action_uindex + unique (scope, target) +); + +UPDATE actor_keys AS ak +SET key_id = a.payload -> 'publicKey' ->> 'id' +FROM actors AS a +WHERE ak.actor_id = a.id + AND ak.key_id = ''; + +ALTER TABLE acls + ADD CONSTRAINT acls_checks CHECK (target <> '' AND action >= 0 AND action <= 2); + +ALTER TABLE actor_keys + ADD CONSTRAINT actor_keys_checks CHECK (key_id <> '' AND pem <> ''); + +ALTER TABLE actor_aliases + ADD CONSTRAINT actor_aliases_checks CHECK (alias <> ''); + +ALTER TABLE actors + ADD CONSTRAINT actors_checks CHECK (actor_id <> ''); + +ALTER TABLE groups + ADD CONSTRAINT groups_checks CHECK (name <> '' AND public_key <> '' AND private_key <> '' AND display_name <> ''); + +ALTER TABLE objects + ADD CONSTRAINT objects_checks CHECK (object_id <> ''); + +ALTER TABLE object_events + ADD CONSTRAINT object_events_checks CHECK (activity_id <> ''); + +ALTER TABLE users + ADD CONSTRAINT users_checks CHECK (email <> '' AND password <> '' AND name <> '' AND public_key <> '' AND + private_key <> '' AND display_name <> ''); diff --git a/start/command.go b/start/admin.go similarity index 96% rename from start/command.go rename to start/admin.go index 6905ac5..c028607 100644 --- a/start/command.go +++ b/start/admin.go @@ -21,8 +21,8 @@ import ( "github.com/ngerakines/tavern/storage" ) -var Command = cli.Command{ - Name: "init", +var InitAdminCommand = cli.Command{ + Name: "init-admin", Usage: "Initialize the server", Flags: []cli.Flag{ &config.EnvironmentFlag, @@ -59,10 +59,10 @@ var Command = cli.Command{ Usage: "The 'about me' of the admin user", }, }, - Action: serverCommandAction, + Action: initAdminCommandAction, } -func serverCommandAction(cliCtx *cli.Context) error { +func initAdminCommandAction(cliCtx *cli.Context) error { logger, err := config.Logger(cliCtx) if err != nil { return err diff --git a/start/service.go b/start/service.go new file mode 100644 index 0000000..9d01c19 --- /dev/null +++ b/start/service.go @@ -0,0 +1,157 @@ +package start + +import ( + "bytes" + "context" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "runtime" + "strings" + + "github.com/getsentry/sentry-go" + "github.com/urfave/cli/v2" + "go.uber.org/zap" + + "github.com/ngerakines/tavern/config" + "github.com/ngerakines/tavern/errors" + "github.com/ngerakines/tavern/g" + "github.com/ngerakines/tavern/storage" +) + +var InitServiceCommand = cli.Command{ + Name: "init-service", + Usage: "Initialize the service", + Flags: []cli.Flag{ + &config.DomainFlag, + &config.DatabaseFlag, + &cli.StringFlag{ + Name: "pem", + Usage: "The path to the private key PEM file.", + Required: true, + }, + }, + Action: initServiceCommandAction, +} + +func initServiceCommandAction(cliCtx *cli.Context) error { + logger, err := config.Logger(cliCtx) + if err != nil { + return err + } + + domain := cliCtx.String("domain") + siteBase := fmt.Sprintf("https://%s", domain) + + logger.Info("Starting", + zap.String("command", cliCtx.Command.Name), + zap.String("GOOS", runtime.GOOS), + zap.String("site", siteBase), + zap.String("env", cliCtx.String("environment"))) + + sentryConfig, err := config.NewSentryConfig(cliCtx) + if err != nil { + return err + } + + if sentryConfig.Enabled { + err = sentry.Init(sentry.ClientOptions{ + Dsn: sentryConfig.Key, + Environment: cliCtx.String("environment"), + Release: fmt.Sprintf("%s-%s", g.Release, g.GitCommit), + }) + if err != nil { + return err + } + sentry.ConfigureScope(func(scope *sentry.Scope) { + scope.SetTags(map[string]string{"container": "server"}) + }) + defer sentry.Recover() + } + + db, dbClose, err := config.DB(cliCtx, logger) + if err != nil { + return err + } + defer dbClose() + + actorID := fmt.Sprintf("https://%s/server", domain) + actor := storage.EmptyPayload() + actor["@context"] = "https://www.w3.org/ns/activitystreams" + actor["id"] = actorID + actor["type"] = "Service" + actor["inbox"] = fmt.Sprintf("%s/inbox", actorID) + actor["outbox"] = fmt.Sprintf("%s/outbox", actorID) + actor["name"] = domain + actor["summary"] = domain + actor["preferredUsername"] = domain + actor["url"] = fmt.Sprintf("https://%s/", domain) + + ctx := context.Background() + + pemBytes, err := ioutil.ReadFile(cliCtx.String("pem")) + if err != nil { + return err + } + block, _ := pem.Decode(pemBytes) + keyID, ok := block.Headers["id"] + if !ok { + return fmt.Errorf("missing header: id") + } + if len(keyID) == 0 || !strings.HasPrefix(keyID, fmt.Sprintf("https://%s/", domain)) { + return fmt.Errorf("invalid key id: %s", keyID) + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return fmt.Errorf("unable to parse private key bytes: %w", err) + } + + publicKeyBytes, err := x509.MarshalPKIXPublicKey(privateKey.Public()) + if err != nil { + return fmt.Errorf("unable to marshal public key: %w", err) + } + + var publicKeyBuffer bytes.Buffer + if err = pem.Encode(&publicKeyBuffer, &pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyBytes, + }); err != nil { + return err + } + publicKey := string(publicKeyBuffer.Bytes()) + + txErr := storage.TransactionalStorage(ctx, storage.DefaultStorage(storage.LoggingSQLDriver{Driver: db, Logger: logger}), func(s storage.Storage) error { + keyPayload := storage.EmptyPayload() + keyPayload["id"] = keyID + keyPayload["owner"] = actorID + keyPayload["publicKeyPem"] = publicKey + actor["publicKey"] = keyPayload + + err = s.CreateActor(ctx, actorID, actor) + if err != nil { + return err + } + + actorRowID, err := s.ActorRowIDForActorID(ctx, actorID) + if err != nil { + return err + } + + err = s.RecordActorKey(ctx, actorRowID, keyID, publicKey) + if err != nil { + return err + } + + if err = s.RecordActorAlias(ctx, actorRowID, actorID, storage.ActorAliasSelf); err != nil { + return err + } + + return nil + }) + if txErr != nil { + logger.Error("error creating service", zap.Error(err), zap.Strings("error_chain", errors.ErrorChain(err))) + } + return txErr +} diff --git a/storage/acl.go b/storage/acl.go new file mode 100644 index 0000000..b38a003 --- /dev/null +++ b/storage/acl.go @@ -0,0 +1,208 @@ +package storage + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/gofrs/uuid" + + "github.com/ngerakines/tavern/common" + "github.com/ngerakines/tavern/errors" +) + +type ACLStorage interface { + RecordACL(ctx context.Context, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) + RecordACLAll(ctx context.Context, rowID uuid.UUID, createdAt, updatedAt time.Time, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) + ListACLsByScope(ctx context.Context, actorRowID uuid.UUID) (ACLs, error) + ListACLsByScopes(ctx context.Context, actorRowIDs []uuid.UUID) (ACLs, error) + ListACLsByTarget(ctx context.Context, target string) (ACLs, error) + ListACLsByTargets(ctx context.Context, targets []string) (ACLs, error) + ListACLsByScopesAndWildTargets(ctx context.Context, actorRowIDs []uuid.UUID, targets []string) (ACLs, error) +} + +type ACL struct { + ID uuid.UUID + CreatedAt time.Time + UpdatedAt time.Time + Scope uuid.UUID + Target string + Action ACLAction + Priority int +} + +type aclSorter struct { + acls []ACL + by func(a1, a2 *ACL) bool +} + +func (s *aclSorter) Len() int { + return len(s.acls) +} + +func (s *aclSorter) Swap(i, j int) { + s.acls[i], s.acls[j] = s.acls[j], s.acls[i] +} + +func (s *aclSorter) Less(i, j int) bool { + return s.by(&s.acls[i], &s.acls[j]) +} + +type ACLs []ACL + +func (l ACLs) ToMatchSet(highScopes *common.UniqueUUIDs, defaultAllow bool) common.MatcherSet { + highAllow := make([]common.AllowDenyMatcher, 0) + highDeny := make([]common.AllowDenyMatcher, 0) + lowAllow := make([]common.AllowDenyMatcher, 0) + lowDeny := make([]common.AllowDenyMatcher, 0) + fmt.Println(l) + for _, acl := range l { + matcher := common.NewAllowDenyMatcher(acl.Target) + _, isHigh := highScopes.Tracking[acl.Scope] + fmt.Println("placing matcher", isHigh, acl.Action) + if isHigh && acl.Action == AllowACLAction { + highAllow = append(highAllow, matcher) + } else if isHigh && acl.Action == DenyACLAction { + highDeny = append(highDeny, matcher) + } else if !isHigh && acl.Action == AllowACLAction { + lowAllow = append(lowAllow, matcher) + } else if !isHigh && acl.Action == DenyACLAction { + lowDeny = append(lowDeny, matcher) + } + } + return common.MatcherSet{ + Default: defaultAllow, + HighAllow: highAllow, + HighDeny: highDeny, + LowAllow: lowAllow, + LowDeny: lowDeny, + } +} + +// +// func (l ACLs) WithPriority(serverScope uuid.UUID) ACLs { +// results := make(ACLs, len(l)) +// for i, val := range l { +// priority := 1 +// if val.Scope == serverScope { +// priority++ +// } +// if val.Action == AllowACLAction { +// priority++ +// } +// results[i] = ACL{ +// ID: val.ID, +// CreatedAt: val.CreatedAt, +// UpdatedAt: val.UpdatedAt, +// Scope: val.Scope, +// Target: val.Target, +// Action: val.Action, +// Priority: priority, +// } +// } +// +// ps := &aclSorter{ +// acls: results, +// by: func(a1, a2 *ACL) bool { +// return a1.Priority < a2.Priority +// }, +// } +// sort.Sort(ps) +// +// for _, acl := range l { +// fmt.Println(acl.Scope, acl.Target, acl.Action) +// } +// +// return results +// } + +type ACLAction int8 + +const ( + AllowACLAction ACLAction = 0 + DenyACLAction ACLAction = 1 +) + +var aclsFields = []string{ + "id", + "created_at", + "updated_at", + "scope", + "target", + "action", +} + +func (s pgStorage) RecordACL(ctx context.Context, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) { + rowID := NewV4() + now := s.now() + return s.RecordACLAll(ctx, rowID, now, now, actorRowID, target, action, wild) +} + +func (s pgStorage) RecordACLAll(ctx context.Context, rowID uuid.UUID, createdAt, updatedAt time.Time, actorRowID uuid.UUID, target string, action ACLAction, wild bool) (uuid.UUID, error) { + query := `INSERT INTO acls (id, created_at, updated_at, scope, target, action, wild) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT ON CONSTRAINT acls_scope_action_uindex DO UPDATE SET updated_at = $3, action = $6 RETURNING id` + var id uuid.UUID + err := s.db.QueryRowContext(ctx, query, rowID, createdAt, updatedAt, actorRowID, target, action, wild).Scan(&id) + return id, errors.WrapGroupUpsertFailedError(err) +} + +func (s pgStorage) ListACLsByScope(ctx context.Context, actorRowID uuid.UUID) (ACLs, error) { + query := fmt.Sprintf(`SELECT %s FROM acls WHERE scope = $1`, strings.Join(aclsFields, ",")) + return s.queryACLs(ctx, query, actorRowID) +} + +func (s pgStorage) ListACLsByScopes(ctx context.Context, actorRowIDs []uuid.UUID) (ACLs, error) { + if len(actorRowIDs) == 0 { + return []ACL{}, nil + } + query := fmt.Sprintf("SELECT %s FROM acls WHERE scope IN (%s)", strings.Join(aclsFields, ","), strings.Join(common.DollarForEach(len(actorRowIDs)), ",")) + return s.queryACLs(ctx, query, common.UUIDsToInterfaces(actorRowIDs)...) +} + +func (s pgStorage) ListACLsByTarget(ctx context.Context, target string) (ACLs, error) { + query := fmt.Sprintf(`SELECT %s FROM acls WHERE target = $1`, strings.Join(aclsFields, ",")) + return s.queryACLs(ctx, query, target) +} + +func (s pgStorage) ListACLsByTargets(ctx context.Context, targets []string) (ACLs, error) { + if len(targets) == 0 { + return []ACL{}, nil + } + query := fmt.Sprintf(`SELECT %s FROM acls WHERE target IN (%s)`, strings.Join(aclsFields, ","), strings.Join(common.DollarForEach(len(targets)), ",")) + return s.queryACLs(ctx, query, common.StringsToInterfaces(targets)...) +} + +func (s pgStorage) ListACLsByScopesAndWildTargets(ctx context.Context, actorRowIDs []uuid.UUID, targets []string) (ACLs, error) { + if len(actorRowIDs) == 0 || len(targets) == 0 { + return []ACL{}, nil + } + scopePlaceholders := strings.Join(common.DollarForEach(len(actorRowIDs)), ",") + targetPlaceholders := strings.Join(common.DollarForEachFrom(len(actorRowIDs)+1, len(actorRowIDs)+len(targets)), ",") + query := fmt.Sprintf(`SELECT %s FROM acls WHERE scope IN (%s) AND (target IN (%s) OR wild = $%d)`, strings.Join(aclsFields, ","), scopePlaceholders, targetPlaceholders, len(actorRowIDs)+len(targets)+1) + args := make([]interface{}, 0) + for _, actorRowID := range actorRowIDs { + args = append(args, actorRowID) + } + for _, target := range targets { + args = append(args, target) + } + args = append(args, true) + return s.queryACLs(ctx, query, args...) +} + +func (s pgStorage) queryACLs(ctx context.Context, query string, args ...interface{}) (ACLs, error) { + results := make([]ACL, 0) + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.NewObjectSelectFailedError(err) + } + defer rows.Close() + for rows.Next() { + var acl ACL + if err := rows.Scan(&acl.ID, &acl.CreatedAt, &acl.UpdatedAt, &acl.Scope, &acl.Target, &acl.Action); err != nil { + return nil, errors.NewInvalidObjectError(err) + } + results = append(results, acl) + } + return results, nil +} diff --git a/storage/actor.go b/storage/actor.go index 518d8f6..0e43b36 100644 --- a/storage/actor.go +++ b/storage/actor.go @@ -30,6 +30,7 @@ type ActorStorage interface { ActorSubjects(ctx context.Context, actors []uuid.UUID) ([]ActorAlias, error) ActorAliasSubjectExists(ctx context.Context, alias string) (bool, error) FilterGroupsByActorID(ctx context.Context, actorIDs []string) ([]string, error) + ActorIDForActorRowID(ctx context.Context, actorRowID uuid.UUID) (string, error) UpdateActorPayload(ctx context.Context, actorRowID uuid.UUID, payload Payload) error } @@ -260,6 +261,18 @@ func (s pgStorage) ActorRowIDForActorID(ctx context.Context, actorID string) (uu return actorRowID, nil } +func (s pgStorage) ActorIDForActorRowID(ctx context.Context, actorRowID uuid.UUID) (string, error) { + var actorID string + err := s.db.QueryRowContext(ctx, `SELECT actor_id FROM actors WHERE id = $1`, actorRowID).Scan(&actorID) + if err != nil { + if err == sql.ErrNoRows { + return "", errors.NewActorNotFoundError(err) + } + return "", errors.NewActorSelectFailedError(err) + } + return actorID, nil +} + func (s pgStorage) ActorSubjects(ctx context.Context, actors []uuid.UUID) ([]ActorAlias, error) { if len(actors) == 0 { return nil, nil diff --git a/storage/storage.go b/storage/storage.go index ef0aa17..478991b 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -22,6 +22,7 @@ type Storage interface { AssetStorage InstanceStatsStorage GroupStorage + ACLStorage GetExecutor() QueryExecute } @@ -259,4 +260,4 @@ func CountMap(counts []Count) map[string]int { results[c.Key] = c.Count } return results -} \ No newline at end of file +} diff --git a/templates/configure.html b/templates/configure.html index 3c58525..0d45a00 100644 --- a/templates/configure.html +++ b/templates/configure.html @@ -9,8 +9,8 @@
+
+
+

ACLs

+
+
+ + + + Enter an actor URL or domain. + +
+
+ + +
+ +
+
+
+
+
+ + + + + + + + + + {{ range $acl := .acls }} + + + + + + {{ end }} + +
ScopeTargetAction
{{ $acl.Scope }}{{ $acl.Target }}{{ if eq $acl.Action 0 }}Allow{{ else }}Deny{{ end }}
+
+
+
+
+

ACL Test

+
+ +
+ + +
+ +
+
+
{{end}} \ No newline at end of file diff --git a/translations/en/errors_generated.json b/translations/en/errors_generated.json index b2e8286..58585f3 100644 --- a/translations/en/errors_generated.json +++ b/translations/en/errors_generated.json @@ -883,5 +883,10 @@ "locale": "en", "key": "TAVWEBAAAAAAAI", "trans": "The user id is invalid." + }, + { + "locale": "en", + "key": "TAVWEBAAAAAAAJ", + "trans": "Access denied." } ] \ No newline at end of file diff --git a/web/command.go b/web/command.go index 7ea198d..e7b5c23 100644 --- a/web/command.go +++ b/web/command.go @@ -151,6 +151,11 @@ func serverCommandAction(cliCtx *cli.Context) error { s := storage.DefaultStorage(storage.WrapMetricDriver(fact, "tavern", "web", storage.LoggingSQLDriver{Driver: db, Logger: logger})) + serverActorRowID, err := s.ActorRowIDForActorID(context.Background(), fmt.Sprintf("https://%s/server", domain)) + if err != nil { + return err + } + utrans, err := config.Trans(cliCtx) if err != nil { return err @@ -269,22 +274,23 @@ func serverCommandAction(cliCtx *cli.Context) error { } h := handler{ - storage: s, - logger: logger, - domain: domain, - sentryConfig: sentryConfig, - fedConfig: fedConfig, - groupConfig: groupConfig, - webFingerQueue: webFingerQueue, - crawlQueue: crawlQueue, - assetQueue: assetQueue, - adminUser: cliCtx.String("admin-name"), - url: tmplUrlGen(siteBase), - svgConverter: svgConv, - assetStorage: assetStorage, - httpClient: httpClient, - metricFactory: promauto.With(registry), - publisherClient: newPublisherClient(logger, httpClient, publisherConfig), + storage: s, + logger: logger, + domain: domain, + sentryConfig: sentryConfig, + fedConfig: fedConfig, + groupConfig: groupConfig, + webFingerQueue: webFingerQueue, + crawlQueue: crawlQueue, + assetQueue: assetQueue, + adminUser: cliCtx.String("admin-name"), + url: tmplUrlGen(siteBase), + svgConverter: svgConv, + assetStorage: assetStorage, + httpClient: httpClient, + metricFactory: promauto.With(registry), + publisherClient: newPublisherClient(logger, httpClient, publisherConfig), + serverActorRowID: serverActorRowID, } configI18nMiddleware(sentryConfig, logger, utrans, domain, r) @@ -395,6 +401,8 @@ func serverCommandAction(cliCtx *cli.Context) error { authenticated.GET("/configure", h.configure) authenticated.POST("/configure/user", h.saveUserSettings) + authenticated.POST("/configure/acl", h.createACL) + authenticated.POST("/configure/acl-validate", h.validateACL) authenticated.GET("/notifications", h.notifications) diff --git a/web/handler.go b/web/handler.go index 84147de..ead5683 100644 --- a/web/handler.go +++ b/web/handler.go @@ -8,6 +8,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" "github.com/prometheus/client_golang/prometheus/promauto" "go.uber.org/zap" @@ -37,7 +38,8 @@ type handler struct { httpClient common.HTTPClient metricFactory promauto.Factory - publisherClient *publisherClient + publisherClient *publisherClient + serverActorRowID uuid.UUID } func (h handler) hardFail(ctx *gin.Context, err error, fields ...zap.Field) { diff --git a/web/handler_auth.go b/web/handler_auth.go index 4dd5bd9..5684dd8 100644 --- a/web/handler_auth.go +++ b/web/handler_auth.go @@ -1,9 +1,15 @@ package web import ( + "context" + "strings" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" + "go.uber.org/zap" + "github.com/ngerakines/tavern/common" "github.com/ngerakines/tavern/errors" "github.com/ngerakines/tavern/storage" ) @@ -51,3 +57,28 @@ func (h handler) loggedInAPI(c *gin.Context, requireUser bool) (*storage.User, s } return user, session, true } + +func allow(ctx context.Context, logger *zap.Logger, s storage.Storage, serverScope, localScope uuid.UUID, target string) bool { + var matchSet common.MatcherSet + + domain := strings.TrimPrefix(target, "https://") + if i := strings.IndexRune(domain, '/'); i > -1 { + domain = domain[:i] + } + targets := common.NewUniqueStrings().With(target, domain) + + txErr := storage.TransactionalStorage(ctx, s, func(tx storage.Storage) error { + acls, err := tx.ListACLsByScopesAndWildTargets(ctx, []uuid.UUID{serverScope, localScope}, targets.Values) + if err != nil { + return err + } + matchSet = acls.ToMatchSet(common.NewUniqueUUIDs().With(serverScope), true) + return nil + }) + if txErr != nil { + logger.Warn("unable to verify allow deny", zap.Error(txErr), zap.String("target", target)) + return false + } + + return matchSet.Allow(domain, target) +} diff --git a/web/handler_configure.go b/web/handler_configure.go index 95e94fe..656039c 100644 --- a/web/handler_configure.go +++ b/web/handler_configure.go @@ -1,12 +1,17 @@ package web import ( + "fmt" "net/http" + "strconv" + "strings" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" "go.uber.org/zap" + "github.com/ngerakines/tavern/common" "github.com/ngerakines/tavern/errors" "github.com/ngerakines/tavern/storage" ) @@ -52,6 +57,19 @@ func (h handler) configure(c *gin.Context) { "auto_accept_followers": user.AcceptFollowers, } + serverActorRowID, err := h.storage.ActorRowIDForActorID(ctx, fmt.Sprintf("https://%s/server", h.domain)) + if err != nil { + h.hardFail(c, err) + return + } + + acls, err := h.storage.ListACLsByScopes(ctx, []uuid.UUID{user.ActorID, serverActorRowID}) + if err != nil { + h.hardFail(c, err) + return + } + data["acls"] = acls + if err = session.Save(); err != nil { h.hardFail(c, err) return @@ -74,12 +92,12 @@ func (h handler) saveUserSettings(c *gin.Context) { zap.String("reply_collection_updates", c.PostForm("reply_collection_updates")), ) - err = storage.TransactionalStorage(ctx, h.storage, func(storage storage.Storage) error { - txErr := storage.UpdateUserAutoAcceptFollowers(ctx, user.ID, c.PostForm("auto_accept_followers") == "checked") + err = storage.TransactionalStorage(ctx, h.storage, func(tx storage.Storage) error { + txErr := tx.UpdateUserAutoAcceptFollowers(ctx, user.ID, c.PostForm("auto_accept_followers") == "checked") if txErr != nil { return txErr } - return storage.UpdateUserReplyCollectionUpdates(ctx, user.ID, c.PostForm("reply_collection_updates") == "checked") + return tx.UpdateUserReplyCollectionUpdates(ctx, user.ID, c.PostForm("reply_collection_updates") == "checked") }) if err != nil { h.hardFail(c, err) @@ -88,3 +106,73 @@ func (h handler) saveUserSettings(c *gin.Context) { c.Redirect(http.StatusFound, h.url("configure")) } + +func (h handler) createACL(c *gin.Context) { + if c.PostForm("validate") == "validate" { + h.validateACL(c) + return + } + + session := sessions.Default(c) + ctx := c.Request.Context() + user, err := h.storage.GetUserBySession(ctx, session) + if err != nil { + h.hardFail(c, err) + return + } + + wild := strings.ContainsRune(c.PostForm("target"), '*') + + txErr := storage.TransactionalStorage(ctx, h.storage, func(tx storage.Storage) error { + action, err := strconv.Atoi(c.PostForm("action")) + if err != nil { + return err + } + _, err = tx.RecordACL(ctx, user.ActorID, c.PostForm("target"), storage.ACLAction(action), wild) + return err + }) + if txErr != nil { + h.hardFail(c, txErr) + return + } + + h.flashSuccessOrFail(c, h.url("configure"), "ACL created") +} + +func (h handler) validateACL(c *gin.Context) { + session := sessions.Default(c) + ctx := c.Request.Context() + user, err := h.storage.GetUserBySession(ctx, session) + if err != nil { + h.hardFail(c, err) + return + } + + target := c.PostForm("target") + domain := strings.TrimPrefix(target, "https://") + if i := strings.IndexRune(domain, '/'); i > -1 { + domain = domain[:i] + } + + var matchSet common.MatcherSet + + txErr := storage.TransactionalStorage(ctx, h.storage, func(tx storage.Storage) error { + acls, err := tx.ListACLsByScopesAndWildTargets(ctx, []uuid.UUID{user.ActorID, h.serverActorRowID}, []string{domain, target}) + if err != nil { + return err + } + matchSet = acls.ToMatchSet(common.NewUniqueUUIDs().With(h.serverActorRowID), true) + return nil + }) + if txErr != nil { + h.hardFail(c, txErr) + return + } + + if matchSet.Allow(domain, target) { + h.flashSuccessOrFail(c, h.url("configure"), "ACL allow") + return + } + + h.flashErrorOrFail(c, h.url("configure"), fmt.Errorf("ACL deny")) +} diff --git a/web/handler_groups_inbox.go b/web/handler_groups_inbox.go index ca4aa58..3eb893b 100644 --- a/web/handler_groups_inbox.go +++ b/web/handler_groups_inbox.go @@ -63,11 +63,11 @@ func (h handler) groupActorInbox(c *gin.Context) { payloadType, _ := storage.JSONString(payload, "type") actor, _ := storage.JSONString(payload, "actor") - if len(actor) > 0 { - err = h.webFingerQueue.Add(actor) - if err != nil { - h.logger.Error("unable to add actor to web finger queue", zap.String("actor", actor)) - } + + if len(actor) > 0 && !allow(c.Request.Context(), h.logger, h.storage, h.serverActorRowID, group.ActorID, actor) { + h.logger.Debug("actor denied", zap.String("group", name), zap.String("actor", actor)) + c.Status(http.StatusOK) + return } switch payloadType { @@ -88,7 +88,7 @@ func (h handler) groupActorInbox(c *gin.Context) { } func (h handler) groupActorInboxInvite(c *gin.Context, group storage.Group, payload storage.Payload) { - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, group.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -252,7 +252,7 @@ func (h handler) groupActorInboxFollowActor(c *gin.Context, group storage.Group, return } - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, group.ActorID); err != nil { h.logger.Debug("signature verification failed", zap.Error(err)) h.unauthorizedJSON(c, err) return @@ -400,7 +400,7 @@ func (h handler) groupActorInboxUndoFollowActor(c *gin.Context, group storage.Gr return } - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, group.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -465,7 +465,7 @@ func (h handler) groupActorInboxUndo(c *gin.Context, group storage.Group, payloa func (h handler) groupActorInboxCreate(c *gin.Context, group storage.Group, payload storage.Payload) { // Because actors must follow the group, it is safe to assume that we // have actor and actor key records for all valid incoming activities. - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, group.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -596,7 +596,7 @@ func (h handler) groupActorInboxCreate(c *gin.Context, group storage.Group, payl } func (h handler) groupActorInboxAnnounce(c *gin.Context, group storage.Group, payload storage.Payload) { - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, group.ActorID); err != nil { h.unauthorizedJSON(c, err) return } diff --git a/web/handler_user_inbox.go b/web/handler_user_inbox.go index 4a312de..7b4811f 100644 --- a/web/handler_user_inbox.go +++ b/web/handler_user_inbox.go @@ -67,11 +67,11 @@ func (h handler) userActorInbox(c *gin.Context) { payloadType, _ := storage.JSONString(payload, "type") actor, _ := storage.JSONString(payload, "actor") - if len(actor) > 0 { - err = h.webFingerQueue.Add(actor) - if err != nil { - h.logger.Error("unable to add actor to web finger queue", zap.String("actor", actor)) - } + + if len(actor) > 0 && !allow(c.Request.Context(), h.logger, h.storage, h.serverActorRowID, user.ActorID, actor) { + h.logger.Debug("actor denied", zap.String("user", name), zap.String("actor", actor)) + c.Status(http.StatusOK) + return } switch payloadType { @@ -96,7 +96,7 @@ func (h handler) userActorInbox(c *gin.Context) { } func (h handler) actorInboxAccept(c *gin.Context, user *storage.User, payload storage.Payload) { - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, user.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -132,7 +132,7 @@ func (h handler) actorInboxAccept(c *gin.Context, user *storage.User, payload st } func (h handler) actorInboxReject(c *gin.Context, user *storage.User, payload storage.Payload) { - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, user.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -214,7 +214,7 @@ func (h handler) actorInboxFollowActor(c *gin.Context, user *storage.User, paylo return } - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, user.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -297,7 +297,7 @@ func (h handler) actorInboxFollowNote(c *gin.Context, user *storage.User, payloa return } - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, user.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -412,7 +412,7 @@ func (h handler) actorInboxUndoFollowActor(c *gin.Context, user *storage.User, p return } - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, user.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -452,7 +452,7 @@ func (h handler) actorInboxUndoFollowNote(c *gin.Context, user *storage.User, pa return } - if err := h.verifySignature(c); err != nil { + if err := h.verifySignature(c, user.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -513,7 +513,7 @@ func (h handler) actorInboxCreate(c *gin.Context, user *storage.User, payload st return } - if err = h.verifySignature(c); err != nil { + if err = h.verifySignature(c, user.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -666,7 +666,7 @@ func (h handler) actorInboxAnnounce(c *gin.Context, user *storage.User, payload return } - if err = h.verifySignature(c); err != nil { + if err = h.verifySignature(c, user.ActorID); err != nil { h.unauthorizedJSON(c, err) return } @@ -807,7 +807,7 @@ func (h handler) actorInboxAnnounce(c *gin.Context, user *storage.User, payload } func (h handler) actorInboxDelete(c *gin.Context, user *storage.User, payload storage.Payload, raw []byte) { - err := h.verifySignature(c) + err := h.verifySignature(c, user.ActorID) if err != nil { h.unauthorizedJSON(c, err) return @@ -1013,17 +1013,35 @@ func (h handler) isActivityRelevant(ctx context.Context, activity storage.Payloa return false, nil } -func (h handler) verifySignature(c *gin.Context) error { +func (h handler) verifySignature(c *gin.Context, scope uuid.UUID) error { // Host header isn't set for some reason. c.Request.Header.Add("Host", c.Request.Host) verifier, err := httpsig.NewVerifier(c.Request) if err != nil { return err } - key, err := h.storage.GetKey(c.Request.Context(), verifier.KeyId()) - if err != nil { - return err + ctx := c.Request.Context() + var key *storage.Key + var actorID string + txErr := storage.TransactionalStorage(ctx, h.storage, func(tx storage.Storage) error { + var err error + key, err = h.storage.GetKey(c.Request.Context(), verifier.KeyId()) + if err != nil { + return err + } + actorID, err = tx.ActorIDForActorRowID(ctx, key.Actor) + if err != nil { + return err + } + return nil + }) + if txErr != nil { + return txErr } + if !allow(c.Request.Context(), h.logger, h.storage, h.serverActorRowID, scope, actorID) { + return errors.NewAccessDeniedError(nil) + } + publicKey, err := key.GetDecodedPublicKey() if err != nil { return err