tavern/storage/actor.go

308 lines
7.5 KiB
Go

package storage
import (
"context"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/pem"
"fmt"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/ngerakines/tavern/errors"
)
// ActorStorage is the persistence API for actors and their public keys.
type ActorStorage interface {
	// GetActor returns the stored actor whose actor ID equals id.
	GetActor(ctx context.Context, id string) (Actor, error)
	// GetActorByAlias returns the stored actor that has alias among its recorded aliases.
	GetActorByAlias(ctx context.Context, alias string) (Actor, error)
	// CreateActor stores an actor row together with the row for its public key.
	CreateActor(ctx context.Context, actorRowID, keyRowID uuid.UUID, actorID, payload, keyID, publicKeyPEM, name, domain string) error
	// GetKey returns the stored key whose key ID equals keyID.
	GetKey(ctx context.Context, keyID string) (*Key, error)
	// RecordActorAlias adds alias to the aliases of the actor identified by actorID.
	RecordActorAlias(ctx context.Context, actorID, alias string) error
	// ActorPayloadsByActorID returns the stored payloads for the given actor IDs.
	ActorPayloadsByActorID(ctx context.Context, actorIDs []string) ([]Payload, error)
}
// Actor is the common view over an actor. In this file it is implemented by
// the database-backed actor type and by LocalActor.
type Actor interface {
	// GetID returns the actor's identifier.
	GetID() string
	// GetKeyID returns the identifier of the actor's public key.
	GetKeyID() string
	// GetPublicKey returns the actor's public key in PEM form.
	GetPublicKey() string
	// GetDecodedPublicKey parses the PEM public key into an RSA public key.
	GetDecodedPublicKey() (*rsa.PublicKey, error)
	// GetInbox returns the actor's inbox address.
	GetInbox() string
}
type (
	// ActorID identifies an actor. NewActorID builds it as an https URL, and
	// its helper methods derive collection and key addresses by appending
	// suffixes to it.
	ActorID string

	// actor is an actor loaded from the actors table. getActor scans only ID,
	// ActorID, payload, CreatedAt and UpdatedAt; Inbox, PublicKey and KeyID are
	// filled in afterwards by init, which parses payload.
	actor struct {
		ID        uuid.UUID
		ActorID   string
		CreatedAt time.Time
		UpdatedAt time.Time
		Inbox     string
		PublicKey string
		KeyID     string
		payload   string // raw actor document from the payload column
	}

	// LocalActor exposes a User through the Actor interface, deriving the
	// inbox and key ID from ActorID.
	LocalActor struct {
		User    *User
		ActorID ActorID
	}

	// Key is a row of the keys table.
	Key struct {
		ID        uuid.UUID
		KeyID     string
		PublicKey string // PEM-encoded; decoded by GetDecodedPublicKey
		CreatedAt time.Time
	}
)
// Compile-time checks that both actor implementations satisfy Actor.
var (
	_ Actor = (*actor)(nil)
	_ Actor = (*LocalActor)(nil)
)
// GetActor loads the actor whose actor_id column equals id.
func (s pgStorage) GetActor(ctx context.Context, id string) (Actor, error) {
	const byActorID = `SELECT id, actor_id, payload, created_at, updated_at from actors WHERE actor_id = $1`
	return s.getActor(ctx, byActorID, id)
}
// GetActorByAlias loads the actor whose aliases array contains alias.
func (s pgStorage) GetActorByAlias(ctx context.Context, alias string) (Actor, error) {
	const byAlias = `SELECT id, actor_id, payload, created_at, updated_at from actors WHERE aliases @> ARRAY[$1]::varchar[]`
	return s.getActor(ctx, byAlias, alias)
}
// getActor runs a single-row query that selects (id, actor_id, payload,
// created_at, updated_at). It scans the row into an actor and derives the
// inbox and key fields from the payload via init.
//
// Scan errors are returned unwrapped; database/sql reports a missing row as
// sql.ErrNoRows.
func (s pgStorage) getActor(ctx context.Context, query string, args ...interface{}) (Actor, error) {
	var row actor
	dest := []interface{}{&row.ID, &row.ActorID, &row.payload, &row.CreatedAt, &row.UpdatedAt}
	if err := s.db.QueryRowContext(ctx, query, args...).Scan(dest...); err != nil {
		return nil, err
	}
	if err := row.init(); err != nil {
		return nil, err
	}
	return &row, nil
}
// GetKey loads the key row whose key_id equals keyID. The scan error is
// returned unwrapped (sql.ErrNoRows when no row matches).
func (s pgStorage) GetKey(ctx context.Context, keyID string) (*Key, error) {
	const byKeyID = `SELECT id, key_id, public_key, created_at from keys WHERE key_id = $1`
	var k Key
	row := s.db.QueryRowContext(ctx, byKeyID, keyID)
	if err := row.Scan(&k.ID, &k.KeyID, &k.PublicKey, &k.CreatedAt); err != nil {
		return nil, err
	}
	return &k, nil
}
// CreateActor inserts an actor row and its public-key row inside
// runTransactionWithOptions, stamping both with the same creation time.
//
// Both inserts use ON CONFLICT DO NOTHING, so inserting an actor or key that
// already exists is not an error. A failing actor insert is returned as an
// insert-query error and the key insert is not attempted.
//
// publicKeyPEM is the PEM-encoded public key stored in keys.public_key. It was
// previously named "pem", which shadowed the imported encoding/pem package.
func (s pgStorage) CreateActor(ctx context.Context, actorRowID, keyRowID uuid.UUID, actorID, payload, keyID, publicKeyPEM, name, domain string) error {
	now := s.now()
	return runTransactionWithOptions(s.db, func(tx *sql.Tx) error {
		if _, err := tx.ExecContext(ctx, "INSERT INTO actors (id, actor_id, payload, created_at, name, domain) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO NOTHING", actorRowID, actorID, payload, now, name, domain); err != nil {
			return errors.NewInsertQueryFailedError(err)
		}
		_, err := tx.ExecContext(ctx, "INSERT INTO keys (id, key_id, public_key, created_at) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", keyRowID, keyID, publicKeyPEM, now)
		return errors.WrapInsertQueryFailedError(err)
	})
}
// RecordActorAlias appends alias to the aliases array of the actor whose
// actor_id equals actorID and bumps updated_at. If the alias is already
// present, or no actor matches, nothing is updated and no error is returned.
func (s pgStorage) RecordActorAlias(ctx context.Context, actorID, alias string) error {
	now := s.now()
	// COALESCE guards a NULL aliases column. Without it, `aliases @> ...` is NULL,
	// NOT(NULL) is NULL, and the WHERE clause would never match, so the first
	// alias could never be recorded. array_append handles a NULL array itself.
	_, err := s.db.ExecContext(ctx, `UPDATE actors SET aliases = array_append(aliases, $2), updated_at = $3 WHERE actor_id = $1 AND NOT(COALESCE(aliases, '{}') @> ARRAY[$2]::varchar[])`, actorID, alias, now)
	return errors.WrapUpdateQueryFailedError(err)
}
// ActorPayloadsByActorID returns the parsed payload of every stored actor
// whose actor_id is in actorIDs.
//
// IDs with no matching row are skipped. The query has no ORDER BY, so the
// result order is whatever the database returns, not the order of actorIDs.
// An empty actorIDs returns a nil slice without querying. Query, scan, payload
// parse, and row-iteration failures are all returned as select-query errors.
func (s pgStorage) ActorPayloadsByActorID(ctx context.Context, actorIDs []string) ([]Payload, error) {
	var results []Payload
	if len(actorIDs) == 0 {
		return results, nil
	}
	// Build "$1, $2, ..." so every ID is bound as a query parameter rather
	// than interpolated into the SQL text.
	params := make([]string, len(actorIDs))
	args := make([]interface{}, len(actorIDs))
	for i, id := range actorIDs {
		params[i] = fmt.Sprintf("$%d", i+1)
		args[i] = id
	}
	query := fmt.Sprintf("SELECT payload FROM actors WHERE actor_id IN (%s)", strings.Join(params, ", "))
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, errors.NewSelectQueryFailedError(err)
	}
	defer rows.Close()
	for rows.Next() {
		var data string
		if err := rows.Scan(&data); err != nil {
			return nil, errors.NewSelectQueryFailedError(err)
		}
		payload, err := PayloadFromString(data)
		if err != nil {
			return nil, errors.NewSelectQueryFailedError(err)
		}
		results = append(results, payload)
	}
	// rows.Next returns false both when the rows are exhausted and when
	// iteration fails; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, errors.NewSelectQueryFailedError(err)
	}
	return results, nil
}
// NewActorID builds the actor ID for user name on domain, in the form
// https://<domain>/users/<name>.
func NewActorID(name, domain string) ActorID {
	id := fmt.Sprintf("https://%s/users/%s", domain, name)
	return ActorID(id)
}
// KeyFromActor extracts the public key ID and PEM from an actor document.
// It reads the "publicKey" object and, within it, the "id" and "publicKeyPem"
// string fields.
//
// It returns an error when the document is nil or when JSONMap/JSONString
// report any of those fields as unavailable. Each error names the field that
// failed.
func KeyFromActor(actorPayload Payload) (string, string, error) {
	if actorPayload == nil {
		return "", "", fmt.Errorf("unable to get key from actor: actor nil")
	}
	publicKey, ok := JSONMap(actorPayload, "publicKey")
	if !ok {
		return "", "", fmt.Errorf("unable to get key from actor: public key object missing")
	}
	keyID, ok := JSONString(publicKey, "id")
	if !ok {
		return "", "", fmt.Errorf("unable to get key from actor: public key object id invalid")
	}
	// Named publicKeyPEM (not pem) so the encoding/pem package is not shadowed.
	publicKeyPEM, ok := JSONString(publicKey, "publicKeyPem")
	if !ok {
		return "", "", fmt.Errorf("unable to get key from actor: public key object publicKeyPem invalid")
	}
	return keyID, publicKeyPEM, nil
}
// InboxFromActor returns the "inbox" string field of an actor document.
// It returns an error when the document is nil or when JSONString reports
// the field as unavailable.
func InboxFromActor(actorPayload Payload) (string, error) {
	// Same nil guard as KeyFromActor, so both helpers fail the same way.
	if actorPayload == nil {
		return "", fmt.Errorf("unable to get inbox from actor: actor nil")
	}
	inbox, ok := JSONString(actorPayload, "inbox")
	if !ok {
		return "", fmt.Errorf("unable to get inbox from actor: inbox field missing or invalid")
	}
	return inbox, nil
}
// GetDecodedPublicKey parses the actor's PEM public key (a PKIX
// "PUBLIC KEY" block) into an *rsa.PublicKey.
//
// Only the first PEM block is used; any trailing data is ignored. Parse
// failures wrap the underlying x509 error (usable with errors.Is/As), and a
// non-RSA key reports its actual type.
func (a *actor) GetDecodedPublicKey() (*rsa.PublicKey, error) {
	block, _ := pem.Decode([]byte(a.PublicKey)) // rest intentionally ignored
	if block == nil {
		return nil, fmt.Errorf("invalid RSA PEM")
	}
	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("invalid RSA PEM: %w", err)
	}
	rsaPublicKey, ok := pub.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("invalid RSA PEM: expected *rsa.PublicKey, got %T", pub)
	}
	return rsaPublicKey, nil
}
// GetID returns the actor's actor_id.
func (a *actor) GetID() string {
	return a.ActorID
}

// GetKeyID returns the key ID that init extracted from the payload.
func (a *actor) GetKeyID() string {
	return a.KeyID
}

// GetPublicKey returns the PEM public key that init extracted from the payload.
func (a *actor) GetPublicKey() string {
	return a.PublicKey
}

// GetInbox returns the inbox that init extracted from the payload.
func (a *actor) GetInbox() string {
	return a.Inbox
}
// init parses the raw payload column and fills Inbox, KeyID and PublicKey
// from it. It returns the first error from parsing or extraction; the
// receiver may be partially populated when an error is returned.
func (a *actor) init() error {
	parsed, err := PayloadFromString(a.payload)
	if err != nil {
		return err
	}
	if a.Inbox, err = InboxFromActor(parsed); err != nil {
		return err
	}
	a.KeyID, a.PublicKey, err = KeyFromActor(parsed)
	return err
}
// GetID returns the local actor's ID as a plain string.
func (l LocalActor) GetID() string {
	return string(l.ActorID)
}

// GetKeyID returns the key ID derived from the actor ID via MainKey.
func (l LocalActor) GetKeyID() string {
	return l.ActorID.MainKey()
}

// GetInbox returns the inbox address derived from the actor ID.
func (l LocalActor) GetInbox() string {
	return l.ActorID.Inbox()
}

// GetPublicKey returns the user's stored public key. It dereferences User,
// so User must be non-nil.
func (l LocalActor) GetPublicKey() string {
	return l.User.PublicKey
}
// GetDecodedPublicKey parses the local user's PEM public key (a PKIX
// "PUBLIC KEY" block) into an *rsa.PublicKey.
//
// It returns an error instead of panicking when User is nil. Only the first
// PEM block is used. Parse failures wrap the underlying x509 error, and a
// non-RSA key reports its actual type.
func (l LocalActor) GetDecodedPublicKey() (*rsa.PublicKey, error) {
	if l.User == nil {
		return nil, fmt.Errorf("invalid RSA PEM: local actor has no user")
	}
	block, _ := pem.Decode([]byte(l.User.PublicKey)) // rest intentionally ignored
	if block == nil {
		return nil, fmt.Errorf("invalid RSA PEM")
	}
	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("invalid RSA PEM: %w", err)
	}
	rsaPublicKey, ok := pub.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("invalid RSA PEM: expected *rsa.PublicKey, got %T", pub)
	}
	return rsaPublicKey, nil
}
// GetDecodedPublicKey parses the key's PEM public key (a PKIX "PUBLIC KEY"
// block) into an *rsa.PublicKey.
//
// Only the first PEM block is used; any trailing data is ignored. Parse
// failures wrap the underlying x509 error (usable with errors.Is/As), and a
// non-RSA key reports its actual type.
func (k *Key) GetDecodedPublicKey() (*rsa.PublicKey, error) {
	block, _ := pem.Decode([]byte(k.PublicKey)) // rest intentionally ignored
	if block == nil {
		return nil, fmt.Errorf("invalid RSA PEM")
	}
	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("invalid RSA PEM: %w", err)
	}
	rsaPublicKey, ok := pub.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("invalid RSA PEM: expected *rsa.PublicKey, got %T", pub)
	}
	return rsaPublicKey, nil
}
// Followers returns the followers collection address: the actor ID followed by "/followers".
func (id ActorID) Followers() string {
	return fmt.Sprintf("%s/followers", id)
}

// FollowersPage returns Followers with a "?page=<page>" query appended.
func (id ActorID) FollowersPage(page int) string {
	return id.Followers() + fmt.Sprintf("?page=%d", page)
}

// Following returns the following collection address: the actor ID followed by "/following".
func (id ActorID) Following() string {
	return fmt.Sprintf("%s/following", id)
}

// FollowingPage returns Following with a "?page=<page>" query appended.
func (id ActorID) FollowingPage(page int) string {
	return id.Following() + fmt.Sprintf("?page=%d", page)
}

// Outbox returns the outbox address: the actor ID followed by "/outbox".
func (id ActorID) Outbox() string {
	return fmt.Sprintf("%s/outbox", id)
}

// OutboxPage returns Outbox with a "?page=<page>" query appended.
func (id ActorID) OutboxPage(page int) string {
	return id.Outbox() + fmt.Sprintf("?page=%d", page)
}

// Inbox returns the inbox address: the actor ID followed by "/inbox".
func (id ActorID) Inbox() string {
	return fmt.Sprintf("%s/inbox", id)
}

// MainKey returns the key ID: the actor ID with the "#main-key" fragment.
func (id ActorID) MainKey() string {
	return fmt.Sprintf("%s#main-key", id)
}