tavern/storage/actor.go

390 lines
12 KiB
Go

package storage
import (
"context"
"crypto/md5"
"crypto/rsa"
"database/sql"
"encoding/hex"
"fmt"
"math/big"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/errors"
)
// ActorStorage is the persistence contract for actors, the aliases that
// point at them, and their public keys.
//
// Two identifiers are in play: the "row ID" (a uuid.UUID, the actors.id
// column) and the "actor ID" (a string, the actors.actor_id column).
type ActorStorage interface {
	// Actor lookups.
	GetActor(ctx context.Context, id uuid.UUID) (*Actor, error)
	GetActorByActorID(ctx context.Context, actorID string) (*Actor, error)
	GetActorByAlias(ctx context.Context, subject string) (*Actor, error)
	ActorsByActorID(ctx context.Context, actorIDs []string) ([]*Actor, error)

	// Translation between row IDs and actor IDs.
	ActorRowIDForActorID(ctx context.Context, actorID string) (uuid.UUID, error)
	ActorIDForActorRowID(ctx context.Context, actorRowID uuid.UUID) (string, error)

	// Actor writes.
	CreateActor(ctx context.Context, actorID string, payload Payload) error
	UpdateActorPayload(ctx context.Context, actorRowID uuid.UUID, payload Payload) error

	// Aliases.
	RecordActorAlias(ctx context.Context, actorID uuid.UUID, alias string, aliasType ActorAliasType) error
	ActorSubjects(ctx context.Context, actors []uuid.UUID) ([]ActorAlias, error)
	ActorAliasSubjectExists(ctx context.Context, alias string) (bool, error)

	// Keys.
	GetKey(ctx context.Context, keyID string) (*Key, error)
	RecordActorKey(ctx context.Context, actorID uuid.UUID, keyID, pem string) error

	// Filtering.
	FilterGroupsByActorID(ctx context.Context, actorIDs []string) ([]string, error)
}
// Actor is one row of the actors table plus convenience fields that
// getActors derives from the stored JSON payload.
type Actor struct {
	ID      uuid.UUID
	ActorID string
	Payload Payload

	CreatedAt, UpdatedAt time.Time

	// Populated by getActors from the payload members "type",
	// "preferredUsername", "name" and "inbox".
	ActorType, PreferredUsername, Name, Inbox string

	// CurrentKey stays nil unless the payload carries publicKey.id,
	// publicKey.owner and publicKey.publicKeyPem.
	CurrentKey *KeyData
}
// ActorAlias is one row of the actor_aliases table: an alternative
// identifier (Alias) for the actors row whose ID is Actor.
type ActorAlias struct {
	ID, Actor uuid.UUID
	Alias     string
	AliasType ActorAliasType

	CreatedAt, UpdatedAt time.Time
}
// KeyData mirrors the "publicKey" object of an actor payload: its "id",
// "owner" and "publicKeyPem" members, in that order.
type KeyData struct {
	KeyID, Owner, PEM string
}
// Key is one row of the actor_keys table: a PEM-encoded public key,
// identified by KeyID, belonging to the actors row whose ID is Actor.
type Key struct {
	ID, Actor  uuid.UUID
	KeyID, PEM string

	CreatedAt, UpdatedAt time.Time
}
// ActorAliasType classifies an actor alias. The numeric value is what gets
// stored in actor_aliases.alias_type, so renumbering a constant would change
// the meaning of rows already in the database.
type ActorAliasType int16

const (
	// ActorAliasSubject is the alias kind that
	// CollectActorSubjectsActorToSubject treats as an actor's subject.
	ActorAliasSubject ActorAliasType = 0
	// ActorAliasSelf — NOTE(review): presumably the actor's own ID; confirm
	// against callers of RecordActorAlias.
	ActorAliasSelf ActorAliasType = 1
	// ActorAliasProfilePage — NOTE(review): presumably a profile page URL;
	// confirm against callers of RecordActorAlias.
	ActorAliasProfilePage ActorAliasType = 2
)
var (
	// ActorsFields lists the actors columns selected by actorsSelectQuery,
	// in the order getActors scans them.
	ActorsFields = []string{"id", "actor_id", "payload", "created_at", "updated_at"}

	// ActorSubjectsFields lists the actor_aliases columns selected by
	// actorAliasesSelectQuery, in the order ActorSubjects scans them.
	ActorSubjectsFields = []string{"id", "actor_id", "alias", "alias_type", "created_at", "updated_at"}
)
// actorsSelectQuery builds a SELECT of every ActorsFields column from the
// actors table (aliased "a"). See aliasedSelectQuery for how join and where
// are applied.
func actorsSelectQuery(join string, where []string) string {
	return aliasedSelectQuery(ActorsFields, "actors", "a", join, where)
}

// actorAliasesSelectQuery builds a SELECT of every ActorSubjectsFields column
// from the actor_aliases table (aliased "aa"). See aliasedSelectQuery for how
// join and where are applied.
func actorAliasesSelectQuery(join string, where []string) string {
	return aliasedSelectQuery(ActorSubjectsFields, "actor_aliases", "aa", join, where)
}

// aliasedSelectQuery renders
//
//	SELECT <alias>.<field>, ... FROM <table> <alias> [INNER JOIN <join>] [WHERE w1 AND w2 ...]
//
// join and every where clause are inserted into the SQL verbatim, so they must
// be trusted fragments that use placeholders ($1, $2, ...) for values.
func aliasedSelectQuery(fields []string, table, alias, join string, where []string) string {
	var query strings.Builder
	query.WriteString("SELECT ")
	query.WriteString(strings.Join(common.MapStrings(fields, common.AddPrefix(alias+".")), ", "))
	query.WriteString(" FROM ")
	query.WriteString(table)
	query.WriteString(" ")
	query.WriteString(alias)
	if len(join) > 0 {
		query.WriteString(" INNER JOIN ")
		query.WriteString(join)
	}
	if len(where) > 0 {
		// Joining with " AND " puts the separator between every pair of
		// clauses. The previous loop only emitted it for i > 1, which glued
		// the second clause straight onto the first ("WHERE a b").
		query.WriteString(" WHERE ")
		query.WriteString(strings.Join(where, " AND "))
	}
	return query.String()
}
// GetActor loads the actor whose row ID (actors.id) is id. When no row
// matches, getFirstActor reports a not-found error.
func (s pgStorage) GetActor(ctx context.Context, id uuid.UUID) (*Actor, error) {
	where := []string{"a.id = $1"}
	query := actorsSelectQuery("", where)
	return s.getFirstActor(ctx, query, id)
}
// GetActorByActorID loads the actor whose actors.actor_id equals actorID.
// When no row matches, getFirstActor reports a not-found error.
func (s pgStorage) GetActorByActorID(ctx context.Context, actorID string) (*Actor, error) {
	where := []string{"a.actor_id = $1"}
	query := actorsSelectQuery("", where)
	return s.getFirstActor(ctx, query, actorID)
}
func (s pgStorage) GetActorByAlias(ctx context.Context, alias string) (*Actor, error) {
query := actorsSelectQuery("actor_aliases aa ON a.id = aa.actor_id", []string{"aa.alias = $1"})
return s.getFirstActor(ctx, query, alias)
}
// getFirstActor runs query through getActors and returns the first actor.
// Query errors are passed through unchanged; an empty result becomes
// errors.NewNotFoundError(nil).
func (s pgStorage) getFirstActor(ctx context.Context, query string, args ...interface{}) (*Actor, error) {
	actors, err := s.getActors(ctx, query, args...)
	switch {
	case err != nil:
		return nil, err
	case len(actors) == 0:
		return nil, errors.NewNotFoundError(nil)
	default:
		return actors[0], nil
	}
}
// getActors runs query and returns one Actor per row, with the
// payload-derived fields filled in by deriveActorFields. The query must
// select the ActorsFields columns in order (as actorsSelectQuery does).
// The returned slice is non-nil even when no rows match. Database errors are
// returned unwrapped.
func (s pgStorage) getActors(ctx context.Context, query string, args ...interface{}) ([]*Actor, error) {
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	results := make([]*Actor, 0)
	for rows.Next() {
		a := &Actor{}
		if err := rows.Scan(&a.ID, &a.ActorID, &a.Payload, &a.CreatedAt, &a.UpdatedAt); err != nil {
			return nil, err
		}
		deriveActorFields(a)
		results = append(results, a)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails; only rows.Err distinguishes the two. Without this
	// check a mid-stream failure returned a truncated list as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}

// deriveActorFields fills the convenience fields of a from a.Payload.
// The found/not-found results for the scalar members are deliberately
// ignored, so a missing member leaves whatever value JSONString returns.
// CurrentKey is set only when all three publicKey members are present.
func deriveActorFields(a *Actor) {
	a.ActorType, _ = JSONString(a.Payload, "type")
	a.Inbox, _ = JSONString(a.Payload, "inbox")
	a.PreferredUsername, _ = JSONString(a.Payload, "preferredUsername")
	a.Name, _ = JSONString(a.Payload, "name")

	keyID, hasKeyID := JSONDeepString(a.Payload, "publicKey", "id")
	keyOwner, hasKeyOwner := JSONDeepString(a.Payload, "publicKey", "owner")
	keyPEM, hasKeyPEM := JSONDeepString(a.Payload, "publicKey", "publicKeyPem")
	if hasKeyID && hasKeyOwner && hasKeyPEM {
		a.CurrentKey = &KeyData{
			KeyID: keyID,
			Owner: keyOwner,
			PEM:   keyPEM,
		}
	}
}
func (s pgStorage) GetKey(ctx context.Context, keyID string) (*Key, error) {
key := &Key{}
err := s.db.
QueryRowContext(ctx, `SELECT id, actor_id, key_id, pem, created_at, updated_at from actor_keys WHERE key_id = $1`, keyID).
Scan(&key.ID,
&key.Actor,
&key.KeyID,
&key.PEM,
&key.CreatedAt,
&key.UpdatedAt)
if err != nil {
return nil, err
}
return key, nil
}
func (s pgStorage) CreateActor(ctx context.Context, actorID string, payload Payload) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "INSERT INTO actors (id, actor_id, payload, created_at, updated_at) VALUES ($1, $2, $3, $4, $4) ON CONFLICT ON CONSTRAINT actors_actor_id DO UPDATE SET payload = $3, updated_at = $4", NewV4(), actorID, payload, now)
return errors.WrapInsertQueryFailedError(err)
}
func (s pgStorage) RecordActorAlias(ctx context.Context, actorID uuid.UUID, alias string, aliasType ActorAliasType) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "INSERT INTO actor_aliases (id, actor_id, alias, alias_type, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $5) ON CONFLICT DO NOTHING", NewV4(), actorID, alias, aliasType, now)
return errors.WrapInsertQueryFailedError(err)
}
func (s pgStorage) RecordActorKey(ctx context.Context, actorID uuid.UUID, keyID, pem string) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "INSERT INTO actor_keys (id, actor_id, key_id, pem, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $5) ON CONFLICT ON CONSTRAINT actor_keys_lookup DO UPDATE SET pem = $4, updated_at = $5", NewV4(), actorID, keyID, pem, now)
return errors.WrapActorKeyInsertFailedError(err)
}
// ActorsByActorID loads every actor whose actor_id is in actorIDs. Unknown IDs
// are simply absent from the result. An empty input short-circuits to
// (nil, nil) without touching the database.
func (s pgStorage) ActorsByActorID(ctx context.Context, actorIDs []string) ([]*Actor, error) {
	if len(actorIDs) == 0 {
		return nil, nil
	}
	placeholders := strings.Join(common.DollarForEach(len(actorIDs)), ",")
	inClause := "a.actor_id IN (" + placeholders + ")"
	query := actorsSelectQuery("", []string{inClause})
	return s.getActors(ctx, query, common.StringsToInterfaces(actorIDs)...)
}
// ActorRowIDForActorID returns the actors.id of the row whose actor_id is
// actorID. A missing row becomes errors.NewNotFoundError. Any other failure
// becomes errors.NewQueryFailedError. uuid.Nil is returned alongside every
// error.
func (s pgStorage) ActorRowIDForActorID(ctx context.Context, actorID string) (uuid.UUID, error) {
	var rowID uuid.UUID
	err := s.db.QueryRowContext(ctx, `SELECT id FROM actors WHERE actor_id = $1`, actorID).Scan(&rowID)
	switch {
	case err == nil:
		return rowID, nil
	case err == sql.ErrNoRows:
		return uuid.Nil, errors.NewNotFoundError(err)
	default:
		return uuid.Nil, errors.NewQueryFailedError(err)
	}
}
// ActorIDForActorRowID returns the actor_id of the actors row whose id is
// actorRowID. A missing row becomes errors.NewActorNotFoundError. Any other
// failure becomes errors.NewActorSelectFailedError. "" is returned alongside
// every error.
func (s pgStorage) ActorIDForActorRowID(ctx context.Context, actorRowID uuid.UUID) (string, error) {
	var id string
	err := s.db.QueryRowContext(ctx, `SELECT actor_id FROM actors WHERE id = $1`, actorRowID).Scan(&id)
	switch {
	case err == nil:
		return id, nil
	case err == sql.ErrNoRows:
		return "", errors.NewActorNotFoundError(err)
	default:
		return "", errors.NewActorSelectFailedError(err)
	}
}
// ActorSubjects loads every actor_aliases row belonging to the actors rows
// listed in actors.
//
// Despite the name, aliases of every type are returned, not just
// ActorAliasSubject; CollectActorSubjectsActorToSubject does that
// filtering. An empty input, or a query with no matching rows, yields a nil
// slice. All database failures are wrapped with
// errors.NewSelectQueryFailedError.
func (s pgStorage) ActorSubjects(ctx context.Context, actors []uuid.UUID) ([]ActorAlias, error) {
	if len(actors) == 0 {
		return nil, nil
	}
	inClause := fmt.Sprintf("aa.actor_id IN (%s)", strings.Join(common.DollarForEach(len(actors)), ","))
	query := actorAliasesSelectQuery("", []string{inClause})
	rows, err := s.db.QueryContext(ctx, query, common.UUIDsToInterfaces(actors)...)
	if err != nil {
		return nil, errors.NewSelectQueryFailedError(err)
	}
	defer rows.Close()

	var results []ActorAlias
	for rows.Next() {
		var alias ActorAlias
		if err := rows.Scan(&alias.ID, &alias.Actor, &alias.Alias, &alias.AliasType, &alias.CreatedAt, &alias.UpdatedAt); err != nil {
			return nil, errors.NewSelectQueryFailedError(err)
		}
		results = append(results, alias)
	}
	// rows.Next also returns false on an iteration error; without this check
	// a failure part-way through was reported as a (truncated) success.
	if err := rows.Err(); err != nil {
		return nil, errors.NewSelectQueryFailedError(err)
	}
	return results, nil
}
// CollectActorSubjectsActorToSubject maps each actor row ID to its
// ActorAliasSubject alias, skipping aliases of any other type. If an actor
// has several subject aliases, the last one in actorSubjects wins. The map is
// always non-nil.
func CollectActorSubjectsActorToSubject(actorSubjects []ActorAlias) map[uuid.UUID]string {
	subjectByActor := make(map[uuid.UUID]string)
	for i := range actorSubjects {
		entry := &actorSubjects[i]
		if entry.AliasType != ActorAliasSubject {
			continue
		}
		subjectByActor[entry.Actor] = entry.Alias
	}
	return subjectByActor
}
// ActorFromUserInfo builds the ActivityStreams "Person" payload for the
// local user name on domain. Its id and url are common.ActorURL(domain, name).
//
// publicKey is embedded verbatim as publicKeyPem. privateKey only
// contributes its public half, which is fingerprinted into the key ID, and
// must be non-nil.
func ActorFromUserInfo(name, displayName, domain, publicKey string, privateKey *rsa.PrivateKey) Payload {
	return localActorPayload(common.ActorURL(domain, name), "Person", name, displayName, publicKey, privateKey)
}

// ActorFromGroupInfo builds the ActivityStreams "Group" payload for the
// local group name on domain. Its id and url are
// common.GroupActorURL(domain, name).
//
// publicKey is embedded verbatim as publicKeyPem. privateKey only
// contributes its public half, which is fingerprinted into the key ID, and
// must be non-nil.
func ActorFromGroupInfo(name, displayName, domain, publicKey string, privateKey *rsa.PrivateKey) Payload {
	return localActorPayload(common.GroupActorURL(domain, name), "Group", name, displayName, publicKey, privateKey)
}

// localActorPayload holds the payload construction shared by the user and
// group builders (previously duplicated line for line). The inbox, outbox,
// followers and following URLs are sub-paths of actorURL. The key ID is
// actorURL + "#" + the fingerprint of privateKey's public half.
func localActorPayload(actorURL, actorType, name, displayName, publicKey string, privateKey *rsa.PrivateKey) Payload {
	actor := EmptyPayload()
	actor["@context"] = "https://www.w3.org/ns/activitystreams"
	actor["id"] = actorURL
	actor["inbox"] = fmt.Sprintf("%s/inbox", actorURL)
	actor["outbox"] = fmt.Sprintf("%s/outbox", actorURL)
	actor["name"] = displayName
	actor["preferredUsername"] = name
	actor["summary"] = ""
	actor["type"] = actorType
	actor["url"] = actorURL
	actor["followers"] = fmt.Sprintf("%s/followers", actorURL)
	actor["following"] = fmt.Sprintf("%s/following", actorURL)

	key := EmptyPayload()
	key["id"] = fmt.Sprintf("%s#%s", actorURL, publicKeyFingerprint(&privateKey.PublicKey))
	key["owner"] = actorURL
	key["publicKeyPem"] = publicKey
	actor["publicKey"] = key
	return actor
}

// publicKeyFingerprint returns the hex-encoded MD5 of the modulus bytes
// followed by the exponent bytes (big-endian, as math/big's Bytes gives
// them). It serves only as a stable key identifier. Changing the algorithm
// would change every published key ID.
func publicKeyFingerprint(pub *rsa.PublicKey) string {
	digest := md5.New()
	// hash.Hash.Write never returns an error, so the results are discarded.
	digest.Write(pub.N.Bytes())
	digest.Write(big.NewInt(int64(pub.E)).Bytes())
	return hex.EncodeToString(digest.Sum(nil))
}
// ActorAliasSubjectExists reports whether an actor_aliases row of type
// ActorAliasSubject has the given alias. Query failures are wrapped with
// errors.WrapActorAliasQueryFailedError.
//
// The alias_type value is derived from ActorAliasSubject rather than
// hard-coded as 0, so the query cannot drift from the enum if the const block
// is ever reordered. The rendered SQL is unchanged ("alias_type = 0"). Only a
// compile-time constant is formatted in, never caller input.
func (s pgStorage) ActorAliasSubjectExists(ctx context.Context, alias string) (bool, error) {
	query := fmt.Sprintf(`SELECT COUNT(*) FROM actor_aliases WHERE alias = $1 AND alias_type = %d`, int16(ActorAliasSubject))
	return s.wrappedExists(errors.WrapActorAliasQueryFailedError, ctx, query, alias)
}
// FilterGroupsByActorID returns the members of actorIDs whose stored payload
// has "type" equal to "Group". An empty input yields an empty, non-nil slice
// without querying. Query failures are wrapped with
// errors.WrapActorQueryFailedError.
func (s pgStorage) FilterGroupsByActorID(ctx context.Context, actorIDs []string) ([]string, error) {
	count := len(actorIDs)
	if count == 0 {
		return []string{}, nil
	}
	placeholders := strings.Join(common.DollarForEach(count), ",")
	query := `SELECT actor_id FROM actors WHERE payload->>'type' = 'Group' AND actor_id in (` + placeholders + `)`
	args := common.StringsToInterfaces(actorIDs)
	return s.selectStrings(errors.WrapActorQueryFailedError, ctx, query, args...)
}
func (s pgStorage) UpdateActorPayload(ctx context.Context, actorRowID uuid.UUID, payload Payload) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "UPDATE actors SET payload = $3, updated_at = $2 WHERE id = $1", actorRowID, now, payload)
return errors.WrapActorUpdateFailedError(err)
}