Refactoring actor, actor alias, and actor key storage. Closes #23, #25, and #27. Partially closes #30.

This commit is contained in:
Nick Gerakines 2020-03-11 13:56:32 -04:00
parent 4b04329960
commit 8a5b6af137
21 changed files with 822 additions and 538 deletions

2
common/query.go Normal file
View File

@ -0,0 +1,2 @@
package common

62
common/strings.go Normal file
View File

@ -0,0 +1,62 @@
package common
import (
"fmt"
"strconv"
"github.com/gofrs/uuid"
)
func IntRange(min, max int) []int {
a := make([]int, max-min+1)
for i := range a {
a[i] = min + i
}
return a
}
func StringIntRange(min, max int) []string {
results := make([]string, max-min+1)
for i, v := range IntRange(min, max) {
results[i] = strconv.Itoa(v)
}
return results
}
func MapStrings(in []string, transform func(string) string) []string {
results := make([]string, len(in))
for i, v := range in {
results[i] = transform(v)
}
return results
}
func AddPrefix(prefix string) func(string) string {
return func(in string) string {
return fmt.Sprintf("%s%s", prefix, in)
}
}
func Dollar(in string) string {
return fmt.Sprintf("$%s", in)
}
func DollarForEach(max int) []string {
return MapStrings(StringIntRange(1, max), Dollar)
}
func StringsToInterfaces(input []string) []interface{} {
results := make([]interface{}, len(input))
for i, value := range input {
results[i] = value
}
return results
}
func UUIDsToInterfaces(input []uuid.UUID) []interface{} {
results := make([]interface{}, len(input))
for i, value := range input {
results[i] = value.String()
}
return results
}

View File

@ -41,7 +41,7 @@ func (client ActivityClient) GetSigned(location string, localActor storage.Local
return "", nil, err
}
return ldJsonGetSigned(client.HTTPClient, location, signer, localActor.GetKeyID(), privateKey)
return ldJsonGetSigned(client.HTTPClient, location, signer, localActor.Actor.GetKeyID(), privateKey)
}
func ldJsonGet(client common.HTTPClient, location string) (string, storage.Payload, error) {

View File

@ -6,7 +6,6 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"go.uber.org/zap"
@ -18,6 +17,7 @@ import (
type ActorClient struct {
HTTPClient common.HTTPClient
Logger *zap.Logger
Storage storage.Storage
}
func (client ActorClient) Get(location string) (string, storage.Payload, error) {
@ -51,93 +51,134 @@ func (client ActorClient) Get(location string) (string, storage.Payload, error)
return string(body), p, nil
}
func GetOrFetchActor(ctx context.Context, store storage.Storage, logger *zap.Logger, httpClient common.HTTPClient, hint string) (storage.Actor, error) {
var actorURL string
if strings.HasPrefix(hint, "https://") {
actorURL = hint
func GetOrFetchActor(ctx context.Context, store storage.Storage, logger *zap.Logger, httpClient common.HTTPClient, hint string) (*storage.Actor, error) {
wfc := WebFingerClient{
HTTPClient: httpClient,
Logger: logger,
}
count, err := store.RowCount(ctx, `SELECT COUNT(*) FROM actors WHERE aliases @> ARRAY[$1]::varchar[]`, hint)
if strings.HasPrefix(hint, "https://") {
count, err := store.RowCount(ctx, `SELECT COUNT(*) FROM actors WHERE actor_id = $1`, hint)
if err != nil {
return nil, err
}
if count > 0 {
return store.GetActorByActorID(ctx, hint)
}
} else {
count, err := store.RowCount(ctx, `SELECT COUNT(*) FROM actor_aliases WHERE alias = $1`, strings.TrimPrefix(hint, "@"))
if err != nil {
return nil, err
}
if count > 0 {
return store.GetActorByAlias(ctx, hint)
}
}
wfp, err := wfc.Fetch(hint)
if err != nil {
return nil, err
}
if count > 0 {
return store.GetActorByAlias(ctx, hint)
var actorURL string
subject, hasSubject := storage.JSONString(wfp, "subject")
if !hasSubject {
return nil, fmt.Errorf("webfinger request did not contain subject")
}
links, hasLinks := storage.JSONMapList(wfp, "links")
if hasLinks {
for _, link := range links {
rel, hasRel := storage.JSONString(link, "rel")
href, hasHref := storage.JSONString(link, "href")
if hasRel && hasHref && rel == "self" {
actorURL = href
}
}
}
if len(actorURL) == 0 {
wfc := WebFingerClient{
HTTPClient: httpClient,
Logger: logger,
}
wfp, err := wfc.Fetch(hint)
if err != nil {
return nil, err
}
actorURL, err = ActorIDFromWebFingerPayload(wfp)
if err != nil {
return nil, err
}
logger.Debug("parsed actor id from webfinger payload", zap.String("actor", actorURL))
return nil, fmt.Errorf("webfinger response did not contain self link for actor")
}
logger.Debug("parsed actor id from webfinger payload",
zap.String("actor", actorURL),
zap.String("subject", subject))
ac := ActorClient{
HTTPClient: httpClient,
Logger: logger,
}
actorBody, actorPayload, err := ac.Get(actorURL)
_, actorPayload, err := ac.Get(actorURL)
if err != nil {
return nil, err
}
keyID, keyPEM, err := storage.KeyFromActor(actorPayload)
actorRowID := storage.NewV4()
keyRowID := storage.NewV4()
actorID, ok := storage.JSONString(actorPayload, "id")
if !ok {
return nil, fmt.Errorf("no id found for actor")
}
name, ok := storage.JSONString(actorPayload, "preferredUsername")
if !ok {
return nil, fmt.Errorf("no preferredUsername found for actor")
}
u, err := url.Parse(actorID)
if err != nil {
return nil, err
}
domain := u.Hostname()
err = store.CreateActor(ctx, actorRowID, keyRowID, actorID, actorBody, keyID, keyPEM, name, domain)
keyID, keyPEM, err := storage.KeyFromActor(actorPayload)
if err != nil {
return nil, err
}
for _, a := range []string{actorID, actorURL, hint} {
if err = store.RecordActorAlias(ctx, actorID, a); err != nil {
return nil, err
}
err = store.CreateActor(ctx, actorID, actorPayload)
if err != nil {
return nil, err
}
if endpoints, ok := storage.JSONMap(actorPayload, "endpoints"); ok {
sharedInbox, ok := storage.JSONString(endpoints, "sharedInbox")
if ok {
peerID := storage.NewV4()
err = store.CreatePeer(ctx, peerID, sharedInbox)
if err != nil {
return nil, err
actorRowID, err := store.ActorRowIDForActorID(ctx, actorID)
if err != nil {
return nil, err
}
if err = store.RecordActorAlias(ctx, actorRowID, subject, storage.ActorAliasSubject); err != nil {
return nil, err
}
if hasLinks {
for _, link := range links {
rel, hasRel := storage.JSONString(link, "rel")
href, hasHref := storage.JSONString(link, "href")
if !hasRel || !hasHref {
continue
}
switch rel {
case "self":
if err = store.RecordActorAlias(ctx, actorRowID, href, storage.ActorAliasSelf); err != nil {
return nil, err
}
case "http://webfinger.net/rel/profile-page":
if err = store.RecordActorAlias(ctx, actorRowID, href, storage.ActorAliasProfilePage); err != nil {
return nil, err
}
}
}
}
return store.GetActor(ctx, actorID)
err = store.RecordActorKey(ctx, actorRowID, keyID, keyPEM)
if err != nil {
return nil, err
}
// if endpoints, ok := storage.JSONMap(actorPayload, "endpoints"); ok {
// sharedInbox, ok := storage.JSONString(endpoints, "sharedInbox")
// if ok {
// peerID := storage.NewV4()
// err = store.CreatePeer(ctx, peerID, sharedInbox)
// if err != nil {
// return nil, err
// }
// }
// }
return store.GetActor(ctx, actorRowID)
}
func ActorsFromActivity(activity storage.Payload) []string {

View File

@ -31,10 +31,15 @@ type Signer interface {
const CrawlerDefaultMaxCount = 30
func (c Crawler) Start(sa Signer, seed string) ([]string, []string, error) {
func (c Crawler) Start(user *storage.User, seed string) ([]string, []string, error) {
c.Logger.Info("crawler starting", zap.String("seed", seed))
ctx := context.Background()
userActor, err := c.Storage.GetActor(ctx, user.ActorID)
if err != nil {
return nil, nil, err
}
sigConfig := []httpsig.Algorithm{httpsig.RSA_SHA256}
headersToSign := []string{httpsig.RequestTarget, "date"}
signer, _, err := httpsig.NewSigner(sigConfig, headersToSign, httpsig.Signature)
@ -42,7 +47,7 @@ func (c Crawler) Start(sa Signer, seed string) ([]string, []string, error) {
return nil, nil, err
}
privateKey, err := sa.GetPrivateKey()
privateKey, err := user.GetPrivateKey()
if err != nil {
return nil, nil, err
}
@ -75,7 +80,7 @@ func (c Crawler) Start(sa Signer, seed string) ([]string, []string, error) {
}
if existingCount == 0 {
var body string
body, payload, err = ldJsonGetSigned(c.HTTPClient, location, signer, sa.GetKeyID(), privateKey)
body, payload, err = ldJsonGetSigned(c.HTTPClient, location, signer, userActor.GetKeyID(), privateKey)
if err != nil {
return nil, nil, err
}

View File

@ -26,25 +26,20 @@ func (f fakeInbox) GetInbox() string {
return string(f)
}
func (client ActorClient) Broadcast(ctx context.Context, store storage.Storage, localActor storage.LocalActor, payload []byte) error {
// TODO: Remove this hard-coded limit of 100 followers to broadcast to.
followers, err := store.ListAcceptedFollowers(ctx, localActor.User.ID, 100, 0)
if err != nil {
return err
}
actors, err := store.ActorPayloadsByActorID(ctx, followers)
destinations := make([]hasInbox, 0)
for _, actor := range actors {
inbox, hasInbox := storage.JSONString(actor, "inbox")
if !hasInbox {
continue
}
destinations = append(destinations, fakeInbox(inbox))
}
return client.SendToInboxes(ctx, localActor, destinations, payload)
}
// func (client ActorClient) Broadcast(ctx context.Context, store storage.Storage, localActor storage.LocalActor, payload []byte) error {
// // TODO: Remove this hard-coded limit of 100 followers to broadcast to.
// followers, err := store.ListAcceptedFollowers(ctx, localActor.User.ID, 100, 0)
// if err != nil {
// return err
// }
// actors, err := store.ActorsByActorID(ctx, followers)
//
// destinations := make([]hasInbox, 0)
// for _, actor := range actors {
// destinations = append(destinations, actor)
// }
// return client.SendToInboxes(ctx, localActor, destinations, payload)
// }
func (client ActorClient) SendToInboxes(ctx context.Context, localActor storage.LocalActor, actors []hasInbox, payload []byte) error {
for _, actor := range actors {
@ -78,7 +73,7 @@ func (client ActorClient) SendToInbox(ctx context.Context, localActor storage.Lo
return err
}
if err = signer.SignRequest(privateKey, localActor.GetKeyID(), request); err != nil {
if err = signer.SignRequest(privateKey, localActor.Actor.GetKeyID(), request); err != nil {
return err
}
resp, err := client.HTTPClient.Do(request)

View File

@ -8,6 +8,7 @@ import (
"go.uber.org/zap"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/errors"
"github.com/ngerakines/tavern/storage"
)
@ -16,6 +17,8 @@ type WebFingerClient struct {
Logger *zap.Logger
}
var invalidWebFingerResource = fmt.Errorf("invalid webfinger resource")
func (client WebFingerClient) Fetch(location string) (storage.Payload, error) {
destination, err := BuildWebFingerURL(location)
if err != nil {
@ -28,61 +31,42 @@ func (client WebFingerClient) Fetch(location string) (storage.Payload, error) {
}
func BuildWebFingerURL(location string) (string, error) {
if strings.HasPrefix(location, "@") {
location = strings.TrimPrefix(location, "@")
parts := strings.Split(location, "@")
if len(parts) != 2 {
return "", fmt.Errorf("invalid actor location: %s", location)
// example: https://webfinger.net/lookup/?resource=https%3A%2F%2Fmastodon.social%2F%40ngerakines
if strings.HasPrefix(location, "https://") {
u, err := url.Parse(location)
if err != nil {
return "", fmt.Errorf("%w: %s", invalidWebFingerResource, location)
}
q := url.QueryEscape(location)
return fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s", parts[1], q), nil
domain := u.Hostname()
return fmt.Sprintf("https://%s/.well-known/webfinger?resource=%s", domain, url.QueryEscape(location)), nil
}
// Some acct lookups may have a prefix from user input.
location = strings.TrimPrefix(location, "@")
subjectParts := strings.FieldsFunc(location, func(r rune) bool { return r == '@' })
if len(subjectParts) != 2 {
return "", fmt.Errorf("%w: %s", invalidWebFingerResource, location)
}
if strings.Index(location, "@") != -1 {
parts := strings.Split(strings.TrimPrefix(location, "@"), "@")
if len(parts) != 2 {
return "", fmt.Errorf("invalid actor location: %s", location)
}
q := url.QueryEscape(location)
return fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s", parts[1], q), nil
}
u, err := url.Parse(location)
if err != nil {
return "", err
}
q := url.QueryEscape(location)
return fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s", u.Host, q), nil
return fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s", subjectParts[1], url.QueryEscape(location)), nil
}
func ActorIDFromWebFingerPayload(wfp storage.Payload) (string, error) {
if wfp == nil {
return "", fmt.Errorf("unable to get actor href from webfinger content: wfp nil")
func ActorIDFromWebFingerPayload(wfp storage.Payload) (string, string, error) {
subject, _ := storage.JSONString(wfp, "subject")
links, ok := storage.JSONMapList(wfp, "links")
if ok {
for _, link := range links {
rel, hasRel := storage.JSONString(link, "rel")
linkType, hasType := storage.JSONString(link, "type")
href, hasHref := storage.JSONString(link, "href")
if !hasRel || !hasType || !hasHref {
continue
}
if rel == "self" && linkType == "application/activity+json" {
return href, subject, nil
}
}
}
links, ok := wfp["links"]
if !ok {
return "", fmt.Errorf("unable to get actor href from webfinger content")
}
objLinks, ok := links.([]interface{})
if !ok {
return "", fmt.Errorf("unable to get actor href from webfinger content: links not array")
}
if len(objLinks) < 1 {
return "", fmt.Errorf("unable to get actor href from webfinger content: links empty")
}
first := objLinks[0]
firstMap, ok := first.(map[string]interface{})
if !ok {
return "", fmt.Errorf("unable to get actor href from webfinger content: first link not map")
}
href, ok := firstMap["href"]
if !ok {
return "", fmt.Errorf("unable to get actor href from webfinger content: href present")
}
hrefStr, ok := href.(string)
if !ok {
return "", fmt.Errorf("unable to get actor href from webfinger content: href not string")
}
return hrefStr, nil
return "", "", errors.NewNotFoundError(nil)
}

2
go.mod
View File

@ -16,7 +16,7 @@ require (
github.com/golang-migrate/migrate/v4 v4.9.1
github.com/kr/pretty v0.1.0
github.com/lib/pq v1.3.0
github.com/lucasb-eyer/go-colorful v1.0.3
github.com/lucasb-eyer/go-colorful v1.0.3 // indirect
github.com/microcosm-cc/bluemonday v1.0.2
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/oklog/run v1.0.0

1
go.sum
View File

@ -439,6 +439,7 @@ golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHl
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs=
golang.org/x/lint v0.0.0-20200130185559-910be7a94367 h1:0IiAsCRByjO2QjX7ZPkw5oU9x+n1YqRL802rjC0c3Aw=
golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=

View File

@ -2,99 +2,217 @@ package storage
import (
"context"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/pem"
"fmt"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/errors"
)
type ActorStorage interface {
GetActor(ctx context.Context, id string) (Actor, error)
GetActorByAlias(ctx context.Context, alias string) (Actor, error)
CreateActor(context.Context, uuid.UUID, uuid.UUID, string, string, string, string, string, string) error
GetActor(ctx context.Context, id uuid.UUID) (*Actor, error)
GetActorByActorID(ctx context.Context, actorID string) (*Actor, error)
CreateActor(ctx context.Context, actorID string, payload Payload) error
GetKey(ctx context.Context, keyID string) (*Key, error)
RecordActorAlias(ctx context.Context, actorID, alias string) error
ActorPayloadsByActorID(ctx context.Context, actorIDs []string) ([]Payload, error)
RecordActorAlias(ctx context.Context, actorID uuid.UUID, alias string, aliasType ActorAliasType) error
RecordActorKey(ctx context.Context, actorID uuid.UUID, keyID, pem string) error
ActorsByActorID(ctx context.Context, actorIDs []string) ([]*Actor, error)
ActorRowIDForActorID(ctx context.Context, actorID string) (uuid.UUID, error)
GetActorByAlias(ctx context.Context, subject string) (*Actor, error)
ActorSubjects(ctx context.Context, actors []uuid.UUID) ([]ActorAlias, error)
}
type Actor interface {
GetID() string
GetInbox() string
GetPublicKey() string
GetKeyID() string
GetDecodedPublicKey() (*rsa.PublicKey, error)
}
type ActorID string
type actor struct {
type Actor struct {
ID uuid.UUID
ActorID string
Payload Payload
CreatedAt time.Time
UpdatedAt time.Time
Inbox string
PublicKey string
KeyID string
ActorType string
PreferredUsername string
Name string
Inbox string
payload string
CurrentKey *KeyData
}
type LocalActor struct {
User *User
ActorID ActorID
type ActorAlias struct {
ID uuid.UUID
Actor uuid.UUID
Alias string
AliasType ActorAliasType
CreatedAt time.Time
UpdatedAt time.Time
}
type KeyData struct {
KeyID string
Owner string
PEM string
}
type Key struct {
ID uuid.UUID
Actor uuid.UUID
KeyID string
PublicKey string
PEM string
CreatedAt time.Time
UpdatedAt time.Time
}
var _ Actor = &actor{}
var _ Actor = &LocalActor{}
type ActorAliasType int16
func (s pgStorage) GetActor(ctx context.Context, id string) (Actor, error) {
return s.getActor(ctx, `SELECT id, actor_id, payload, created_at, updated_at from actors WHERE actor_id = $1`, id)
const (
ActorAliasSubject ActorAliasType = iota
ActorAliasSelf
ActorAliasProfilePage
)
var ActorsFields = []string{
"id",
"actor_id",
"payload",
"created_at",
"updated_at",
}
func (s pgStorage) GetActorByAlias(ctx context.Context, alias string) (Actor, error) {
return s.getActor(ctx, `SELECT id, actor_id, payload, created_at, updated_at from actors WHERE aliases @> ARRAY[$1]::varchar[]`, alias)
var ActorSubjectsFields = []string{
"id",
"actor_id",
"alias",
"alias_type",
"created_at",
"updated_at",
}
func (s pgStorage) getActor(ctx context.Context, query string, args ...interface{}) (Actor, error) {
a := &actor{}
err := s.db.
QueryRowContext(ctx, query, args...).
Scan(&a.ID,
&a.ActorID,
&a.payload,
&a.CreatedAt,
&a.UpdatedAt)
var ActorKeysFields = []string{
"id",
"actor_id",
"key_id",
"pem",
"created_at",
"updated_at",
}
func actorsSelectQuery(join string, where []string) string {
var query strings.Builder
query.WriteString("SELECT ")
query.WriteString(strings.Join(common.MapStrings(ActorsFields, common.AddPrefix("a.")), ", "))
query.WriteString(" FROM actors a")
if len(join) > 0 {
query.WriteString(" INNER JOIN ")
query.WriteString(join)
}
if len(where) > 0 {
query.WriteString(" WHERE")
for i, w := range where {
if i > 1 {
query.WriteString(" AND ")
}
query.WriteString(" ")
query.WriteString(w)
}
}
return query.String()
}
func actorAliasesSelectQuery(join string, where []string) string {
var query strings.Builder
query.WriteString("SELECT ")
query.WriteString(strings.Join(common.MapStrings(ActorSubjectsFields, common.AddPrefix("aa.")), ", "))
query.WriteString(" FROM actor_aliases aa")
if len(join) > 0 {
query.WriteString(" INNER JOIN ")
query.WriteString(join)
}
if len(where) > 0 {
query.WriteString(" WHERE")
for i, w := range where {
if i > 1 {
query.WriteString(" AND ")
}
query.WriteString(" ")
query.WriteString(w)
}
}
return query.String()
}
func (s pgStorage) GetActor(ctx context.Context, id uuid.UUID) (*Actor, error) {
return s.getFirstActor(s.db, ctx, actorsSelectQuery("", []string{"a.id = $1"}), id)
}
func (s pgStorage) GetActorByActorID(ctx context.Context, actorID string) (*Actor, error) {
return s.getFirstActor(s.db, ctx, actorsSelectQuery("", []string{"a.actor_id = $1"}), actorID)
}
func (s pgStorage) GetActorByAlias(ctx context.Context, alias string) (*Actor, error) {
query := actorsSelectQuery("actor_aliases aa ON a.id = aa.actor_id", []string{"aa.alias = $1"})
return s.getFirstActor(s.db, ctx, query, alias)
}
func (s pgStorage) getFirstActor(qc QueryContext, ctx context.Context, query string, args ...interface{}) (*Actor, error) {
results, err := s.getActors(qc, ctx, query, args...)
if err != nil {
return nil, err
}
if err = a.init(); err != nil {
if len(results) == 0 {
return nil, errors.NewNotFoundError(nil)
}
return results[0], nil
}
func (s pgStorage) getActors(qc QueryContext, ctx context.Context, query string, args ...interface{}) ([]*Actor, error) {
results := make([]*Actor, 0)
rows, err := qc.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return a, nil
defer rows.Close()
for rows.Next() {
a := &Actor{}
err = rows.Scan(&a.ID, &a.ActorID, &a.Payload, &a.CreatedAt, &a.UpdatedAt)
if err != nil {
return nil, err
}
a.ActorType, _ = JSONString(a.Payload, "type")
a.Inbox, _ = JSONString(a.Payload, "inbox")
a.PreferredUsername, _ = JSONString(a.Payload, "preferredUsername")
a.Name, _ = JSONString(a.Payload, "name")
keyID, hasKeyID := JSONDeepString(a.Payload, "publicKey", "id")
keyOwner, hasKeyOwner := JSONDeepString(a.Payload, "publicKey", "owner")
keyPEM, hasKeyPEM := JSONDeepString(a.Payload, "publicKey", "publicKeyPem")
if hasKeyID && hasKeyOwner && hasKeyPEM {
a.CurrentKey = &KeyData{
KeyID: keyID,
Owner: keyOwner,
PEM: keyPEM,
}
}
results = append(results, a)
}
return results, nil
}
func (s pgStorage) GetKey(ctx context.Context, keyID string) (*Key, error) {
key := &Key{}
err := s.db.
QueryRowContext(ctx, `SELECT id, key_id, public_key, created_at from keys WHERE key_id = $1`, keyID).
QueryRowContext(ctx, `SELECT id, actor_id, key_id, pem, created_at, updated_at from actor_keys WHERE key_id = $1`, keyID).
Scan(&key.ID,
&key.Actor,
&key.KeyID,
&key.PublicKey,
&key.PEM,
&key.CreatedAt)
if err != nil {
return nil, err
@ -102,206 +220,78 @@ func (s pgStorage) GetKey(ctx context.Context, keyID string) (*Key, error) {
return key, nil
}
func (s pgStorage) CreateActor(ctx context.Context, actorRowID, keyRowID uuid.UUID, actorID, payload, keyID, pem, name, domain string) error {
func (s pgStorage) CreateActor(ctx context.Context, actorID string, payload Payload) error {
now := s.now()
return runTransactionWithOptions(s.db, func(tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, "INSERT INTO actors (id, actor_id, payload, created_at, name, domain) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT DO NOTHING", actorRowID, actorID, payload, now, name, domain)
if err != nil {
return errors.NewInsertQueryFailedError(err)
}
_, err = tx.ExecContext(ctx, "INSERT INTO keys (id, key_id, public_key, created_at) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", keyRowID, keyID, pem, now)
return errors.WrapInsertQueryFailedError(err)
})
_, err := s.db.ExecContext(ctx, "INSERT INTO actors (id, actor_id, payload, created_at, updated_at) VALUES ($1, $2, $3, $4, $4) ON CONFLICT ON CONSTRAINT actors_actor_id DO UPDATE SET payload = $3, updated_at = $4", NewV4(), actorID, payload, now)
return errors.WrapInsertQueryFailedError(err)
}
func (s pgStorage) RecordActorAlias(ctx context.Context, actorID, alias string) error {
func (s pgStorage) RecordActorAlias(ctx context.Context, actorID uuid.UUID, alias string, aliasType ActorAliasType) error {
now := s.now()
_, err := s.db.ExecContext(ctx, `UPDATE actors SET aliases = array_append(aliases, $2), updated_at = $3 WHERE actor_id = $1 AND NOT(aliases @> ARRAY[$2]::varchar[])`, actorID, alias, now)
return errors.WrapUpdateQueryFailedError(err)
_, err := s.db.ExecContext(ctx, "INSERT INTO actor_aliases (id, actor_id, alias, alias_type, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $5) ON CONFLICT DO NOTHING", NewV4(), actorID, alias, aliasType, now)
return errors.WrapInsertQueryFailedError(err)
}
func (s pgStorage) ActorPayloadsByActorID(ctx context.Context, actorIDs []string) ([]Payload, error) {
var results []Payload
func (s pgStorage) RecordActorKey(ctx context.Context, actorID uuid.UUID, keyID, pem string) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "INSERT INTO actor_keys (id, actor_id, key_id, pem, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $5) ON CONFLICT ON CONSTRAINT actor_keys_lookup DO UPDATE SET pem = $4, updated_at = $5", NewV4(), actorID, keyID, pem, now)
return errors.WrapInsertQueryFailedError(err)
}
func (s pgStorage) ActorsByActorID(ctx context.Context, actorIDs []string) ([]*Actor, error) {
if len(actorIDs) == 0 {
return results, nil
return nil, nil
}
valuesPlaceholder := strings.Join(common.DollarForEach(len(actorIDs)), ",")
query := actorsSelectQuery("", []string{
fmt.Sprintf("a.actor_id IN (%s)", valuesPlaceholder),
})
return s.getActors(s.db, ctx, query, common.StringsToInterfaces(actorIDs)...)
}
params := make([]string, len(actorIDs))
args := make([]interface{}, len(actorIDs))
for i, id := range actorIDs {
params[i] = fmt.Sprintf("$%d", i+1)
args[i] = id
func (s pgStorage) ActorRowIDForActorID(ctx context.Context, actorID string) (uuid.UUID, error) {
var actorRowID uuid.UUID
err := s.db.QueryRowContext(ctx, `SELECT id FROM actors WHERE actor_id = $1`, actorID).Scan(&actorRowID)
if err != nil {
if err == sql.ErrNoRows {
return uuid.Nil, errors.NewNotFoundError(err)
}
return uuid.Nil, errors.NewQueryFailedError(err)
}
return actorRowID, nil
}
query := fmt.Sprintf("SELECT payload FROM actors WHERE actor_id IN (%s)", strings.Join(params, ", "))
rows, err := s.db.QueryContext(ctx, query, args...)
func (s pgStorage) ActorSubjects(ctx context.Context, actors []uuid.UUID) ([]ActorAlias, error) {
if len(actors) == 0 {
return nil, nil
}
valuesPlaceholder := fmt.Sprintf("aa.actor_id IN (%s)", strings.Join(common.DollarForEach(len(actors)), ","))
query := actorAliasesSelectQuery("", []string{valuesPlaceholder})
rows, err := s.db.QueryContext(ctx, query, common.UUIDsToInterfaces(actors)...)
if err != nil {
return nil, errors.NewSelectQueryFailedError(err)
}
defer rows.Close()
var results []ActorAlias
for rows.Next() {
var data string
if err := rows.Scan(&data); err != nil {
var actorSubject ActorAlias
if err := rows.Scan(&actorSubject.ID, &actorSubject.Actor, &actorSubject.Alias, &actorSubject.AliasType, &actorSubject.CreatedAt, &actorSubject.UpdatedAt); err != nil {
return nil, errors.NewSelectQueryFailedError(err)
}
payload, err := PayloadFromString(data)
if err != nil {
return nil, errors.NewSelectQueryFailedError(err)
}
results = append(results, payload)
results = append(results, actorSubject)
}
return results, nil
}
func NewActorID(name, domain string) ActorID {
return ActorID(fmt.Sprintf("https://%s/users/%s", domain, name))
}
func KeyFromActor(actor Payload) (string, string, error) {
if actor == nil {
return "", "", fmt.Errorf("unable to get key from actor: actor nil")
func CollectActorSubjectsActorToSubject(actorSubjects []ActorAlias) map[uuid.UUID]string {
results := make(map[uuid.UUID]string)
for _, as := range actorSubjects {
if as.AliasType == ActorAliasSubject {
results[as.Actor] = as.Alias
}
}
publicKey, ok := JSONMap(actor, "publicKey")
if !ok {
return "", "", fmt.Errorf("unable to get key from actor: public key object missing")
}
keyID, ok := JSONString(publicKey, "id")
if !ok {
return "", "", fmt.Errorf("unable to get key from actor: public key object id invalid")
}
pem, ok := JSONString(publicKey, "publicKeyPem")
if !ok {
return "", "", fmt.Errorf("unable to get key from actor: public key object id invalid")
}
return keyID, pem, nil
}
func InboxFromActor(actor Payload) (string, error) {
inbox, ok := JSONString(actor, "inbox")
if !ok {
return "", fmt.Errorf("unable to get inbox from actor: inbox field")
}
return inbox, nil
}
func (a *actor) GetDecodedPublicKey() (*rsa.PublicKey, error) {
block, _ := pem.Decode([]byte(a.PublicKey))
if block == nil {
return nil, fmt.Errorf("invalid RSA PEM")
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("invalid RSA PEM")
}
rsaPublicKey, ok := pub.(*rsa.PublicKey)
if !ok {
return nil, fmt.Errorf("invalid RSA PEM")
}
return rsaPublicKey, nil
}
func (a *actor) GetID() string {
return a.ActorID
}
func (a *actor) GetInbox() string {
return a.Inbox
}
func (a *actor) GetKeyID() string {
return a.KeyID
}
func (a *actor) GetPublicKey() string {
return a.PublicKey
}
func (a *actor) init() error {
p, err := PayloadFromString(a.payload)
if err != nil {
return err
}
a.Inbox, err = InboxFromActor(p)
if err != nil {
return err
}
a.KeyID, a.PublicKey, err = KeyFromActor(p)
if err != nil {
return err
}
return nil
}
func (l LocalActor) GetID() string {
return string(l.ActorID)
}
func (l LocalActor) GetInbox() string {
return l.ActorID.Inbox()
}
func (l LocalActor) GetPublicKey() string {
return l.User.PublicKey
}
func (l LocalActor) GetKeyID() string {
return l.ActorID.MainKey()
}
func (l LocalActor) GetDecodedPublicKey() (*rsa.PublicKey, error) {
return l.GetDecodedPublicKey()
}
func (k *Key) GetDecodedPublicKey() (*rsa.PublicKey, error) {
block, _ := pem.Decode([]byte(k.PublicKey))
if block == nil {
return nil, fmt.Errorf("invalid RSA PEM")
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("invalid RSA PEM")
}
rsaPublicKey, ok := pub.(*rsa.PublicKey)
if !ok {
return nil, fmt.Errorf("invalid RSA PEM")
}
return rsaPublicKey, nil
}
func (ID ActorID) Followers() string {
return fmt.Sprintf("%s/followers", ID)
}
func (ID ActorID) FollowersPage(page int) string {
return fmt.Sprintf("%s/followers?page=%d", ID, page)
}
func (ID ActorID) Following() string {
return fmt.Sprintf("%s/following", ID)
}
func (ID ActorID) FollowingPage(page int) string {
return fmt.Sprintf("%s/following?page=%d", ID, page)
}
func (ID ActorID) Outbox() string {
return fmt.Sprintf("%s/outbox", ID)
}
func (ID ActorID) OutboxPage(page int) string {
return fmt.Sprintf("%s/outbox?page=%d", ID, page)
}
func (ID ActorID) Inbox() string {
return fmt.Sprintf("%s/inbox", ID)
}
func (ID ActorID) MainKey() string {
return fmt.Sprintf("%s#main-key", ID)
return results
}

129
storage/actor_id.go Normal file
View File

@ -0,0 +1,129 @@
package storage
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
)
// type Actor interface {
// GetID() string
// GetInbox() string
// GetPublicKey() string
// GetKeyID() string
// GetDecodedPublicKey() (*rsa.PublicKey, error)
// }
type ActorID string
type LocalActor struct {
User *User
Actor *Actor
ActorID ActorID
}
func NewActorID(name, domain string) ActorID {
return ActorID(fmt.Sprintf("https://%s/users/%s", domain, name))
}
func KeyFromActor(actor Payload) (string, string, error) {
if actor == nil {
return "", "", fmt.Errorf("unable to get key from actor: actor nil")
}
publicKey, ok := JSONMap(actor, "publicKey")
if !ok {
return "", "", fmt.Errorf("unable to get key from actor: public key object missing")
}
keyID, ok := JSONString(publicKey, "id")
if !ok {
return "", "", fmt.Errorf("unable to get key from actor: public key object id invalid")
}
publicKeyPem, ok := JSONString(publicKey, "publicKeyPem")
if !ok {
return "", "", fmt.Errorf("unable to get key from actor: public key object id invalid")
}
return keyID, publicKeyPem, nil
}
func (a *Actor) GetDecodedPublicKey() (*rsa.PublicKey, error) {
if a.CurrentKey == nil {
return nil, fmt.Errorf("actor has no key")
}
return DecodePublicKey(a.CurrentKey.PEM)
}
func (a *Actor) GetID() string {
return a.ActorID
}
func (a *Actor) GetInbox() string {
return a.Inbox
}
func (a *Actor) GetKeyID() string {
if a.CurrentKey != nil {
return a.CurrentKey.KeyID
}
return ""
}
func (a *Actor) GetPublicKey() string {
if a.CurrentKey != nil {
return a.CurrentKey.PEM
}
return ""
}
func (k *Key) GetDecodedPublicKey() (*rsa.PublicKey, error) {
return DecodePublicKey(k.PEM)
}
func (ID ActorID) Followers() string {
return fmt.Sprintf("%s/followers", ID)
}
func (ID ActorID) FollowersPage(page int) string {
return fmt.Sprintf("%s/followers?page=%d", ID, page)
}
func (ID ActorID) Following() string {
return fmt.Sprintf("%s/following", ID)
}
func (ID ActorID) FollowingPage(page int) string {
return fmt.Sprintf("%s/following?page=%d", ID, page)
}
func (ID ActorID) Outbox() string {
return fmt.Sprintf("%s/outbox", ID)
}
func (ID ActorID) OutboxPage(page int) string {
return fmt.Sprintf("%s/outbox?page=%d", ID, page)
}
func (ID ActorID) Inbox() string {
return fmt.Sprintf("%s/inbox", ID)
}
func (ID ActorID) MainKey() string {
return fmt.Sprintf("%s#main-key", ID)
}
func DecodePublicKey(data string) (*rsa.PublicKey, error) {
block, _ := pem.Decode([]byte(data))
if block == nil {
return nil, fmt.Errorf("invalid RSA PEM")
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("invalid RSA PEM")
}
rsaPublicKey, ok := pub.(*rsa.PublicKey)
if !ok {
return nil, fmt.Errorf("invalid RSA PEM")
}
return rsaPublicKey, nil
}

View File

@ -2,11 +2,13 @@ package storage
import (
"context"
"fmt"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/errors"
)
@ -15,88 +17,80 @@ type FollowerStorage interface {
ListAcceptedFollowing(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error)
ListPendingFollowers(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error)
ListPendingFollowing(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error)
CreatePendingFollowing(ctx context.Context, userID uuid.UUID, actor, activity string) error
CreatePendingFollower(ctx context.Context, userID uuid.UUID, actor, activity string) error
ActivityForFollowing(ctx context.Context, userID uuid.UUID, actor string) (Payload, error)
ActivityForFollower(ctx context.Context, userID uuid.UUID, actor string) (Payload, error)
UpdateFollowingAccepted(ctx context.Context, userID uuid.UUID, actor string) error
UpdateFollowingRejected(ctx context.Context, userID uuid.UUID, actor string) error
UpdateFollowerApproved(ctx context.Context, userID uuid.UUID, actor string) error
RemoveFollowing(ctx context.Context, userID uuid.UUID, actor string) error
RemoveFollower(ctx context.Context, userID uuid.UUID, actor string) error
IsFollowing(ctx context.Context, userID uuid.UUID, actor string) (bool, error)
IsFollower(ctx context.Context, userID uuid.UUID, actor string) (bool, error)
CreatePendingFollowing(ctx context.Context, userID, actorID uuid.UUID, activity Payload) error
CreatePendingFollower(ctx context.Context, userID, actorID uuid.UUID, activity Payload) error
ActivityForFollowing(ctx context.Context, userID, actorID uuid.UUID) (Payload, error)
ActivityForFollower(ctx context.Context, userID, actorID uuid.UUID) (Payload, error)
UpdateFollowingAccepted(ctx context.Context, userID, actorID uuid.UUID) error
UpdateFollowingRejected(ctx context.Context, userID, actorID uuid.UUID) error
UpdateFollowerApproved(ctx context.Context, userID, actorID uuid.UUID) error
RemoveFollowing(ctx context.Context, userID, actorID uuid.UUID) error
RemoveFollower(ctx context.Context, userID, actorID uuid.UUID) error
IsFollowing(ctx context.Context, userID, actorID uuid.UUID) (bool, error)
IsFollower(ctx context.Context, userID, actorID uuid.UUID) (bool, error)
}
type Follower struct {
ID uuid.UUID
UserID uuid.UUID
Actor string
RequestActivity string // Saved so we can reply with approved or rejected
Status int
CreatedAt time.Time
UpdatedAt time.Time
type NetworkRelationship struct {
ID uuid.UUID
RelationshipType RelationshipType
UserID uuid.UUID
ActorID uuid.UUID
RequestActivity Payload // Saved so we can undo, accept, and reject
Status int
CreatedAt time.Time
UpdatedAt time.Time
}
type Following struct {
ID uuid.UUID
UserID uuid.UUID
Actor string
RequestActivity string // Saved so we can reply send undo
Status int
CreatedAt time.Time
UpdatedAt time.Time
var relationshipGraphFields = []string{
"id",
"user_id",
"actor_id",
"activity",
"relationship_type",
"relationship_status",
"created_at",
"updated_at",
}
type FollowingStatus int
type RelationshipType int16
const (
PendingFollowingStatus = 0
AcceptedFollowingStatus = 1
RejectedFollowingStatus = 2
UserFollowsRelationship RelationshipType = iota
UserFollowedByRelationship
)
type FollowerStatus int
type RelationshipStatus int16
const (
PendingFollowerStatus = 0
ApprovedFollowerStatus = 1
// PendingRelationshipStatus indicates either a) an incoming (remote actor
// to local actor) follow request has not been accepted or rejected by the
// local actor OR b) that the outgoing (local actor to remote actor)
// request has not been accepted or rejected by the remote actor.
PendingRelationshipStatus RelationshipStatus = iota
// AcceptRelationshipStatus indicates either a) an incoming (remote actor
// to local actor) follow request was accepted by the local actor OR b)
// that an outgoing (local actor to remote actor) follow request was
// accepted by the remote actor.
AcceptRelationshipStatus
// RejectRelationshipStatus indicates that an outgoing (local actor to
// remote actor) follow request was accepted by the remote actor. Incoming
// follow requests (remote actor to local actor) that rejected by the user
// (local actor) are deleted.
RejectRelationshipStatus
)
var _ FollowerStorage = &pgStorage{}
func (s pgStorage) ListAcceptedFollowers(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error) {
return s.networkQuery(ctx, "followers", `user_id = $3 AND status = $4 ORDER BY created_at ASC LIMIT $1 OFFSET $2`, limit, offset, userID, AcceptedFollowingStatus)
}
func (s pgStorage) ListAcceptedFollowing(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error) {
return s.networkQuery(ctx, "following", `user_id = $3 AND status = $4 ORDER BY created_at ASC LIMIT $1 OFFSET $2`, limit, offset, userID, ApprovedFollowerStatus)
}
func (s pgStorage) ListPendingFollowers(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error) {
return s.networkQuery(ctx, "followers", `user_id = $3 AND status = $4 ORDER BY created_at ASC LIMIT $1 OFFSET $2`, limit, offset, userID, PendingFollowingStatus)
}
func (s pgStorage) ListPendingFollowing(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error) {
return s.networkQuery(ctx, "following", `user_id = $3 AND status = $4 ORDER BY created_at ASC LIMIT $1 OFFSET $2`, limit, offset, userID, PendingFollowerStatus)
}
func (s pgStorage) networkQuery(ctx context.Context, table, where string, args ...interface{}) ([]string, error) {
var actors []string
var query strings.Builder
query.WriteString("SELECT actor FROM ")
query.WriteString(table)
if len(where) > 0 {
query.WriteString(" WHERE ")
query.WriteString(where)
}
rows, err := s.db.QueryContext(ctx, query.String(), args...)
func (s pgStorage) networkGraphQuery(qc QueryContext, ctx context.Context, userID uuid.UUID, relationshipType RelationshipType, relationshipStatus RelationshipStatus, limit, offset int) ([]string, error) {
query := `SELECT a.actor_id FROM network_graph n INNER JOIN actors a ON a.id = n.actor_id WHERE n.user_id = $3 AND n.relationship_type = $4 AND n.relationship_status = $5 ORDER BY n.created_at ASC LIMIT $1 OFFSET $2`
rows, err := qc.QueryContext(ctx, query, limit, offset, userID, relationshipType, relationshipStatus)
if err != nil {
return nil, errors.NewSelectQueryFailedError(err)
}
defer rows.Close()
var actors []string
for rows.Next() {
var actor string
if err := rows.Scan(&actor); err != nil {
@ -108,69 +102,89 @@ func (s pgStorage) networkQuery(ctx context.Context, table, where string, args .
return actors, nil
}
func (s pgStorage) CreatePendingFollowing(ctx context.Context, userID uuid.UUID, actor, activity string) error {
_, err := s.db.ExecContext(ctx, "INSERT INTO following (id, user_id, actor, request_activity, status, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $6) ON CONFLICT ON CONSTRAINT following_user_actor DO UPDATE SET request_activity = $4", NewV4(), userID, actor, activity, PendingFollowingStatus, s.now())
func (s pgStorage) ListAcceptedFollowers(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error) {
return s.networkGraphQuery(s.db, ctx, userID, UserFollowedByRelationship, AcceptRelationshipStatus, limit, offset)
}
func (s pgStorage) ListAcceptedFollowing(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error) {
return s.networkGraphQuery(s.db, ctx, userID, UserFollowsRelationship, AcceptRelationshipStatus, limit, offset)
}
func (s pgStorage) ListPendingFollowers(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error) {
return s.networkGraphQuery(s.db, ctx, userID, UserFollowedByRelationship, PendingRelationshipStatus, limit, offset)
}
func (s pgStorage) ListPendingFollowing(ctx context.Context, userID uuid.UUID, limit, offset int) ([]string, error) {
return s.networkGraphQuery(s.db, ctx, userID, UserFollowsRelationship, PendingRelationshipStatus, limit, offset)
}
func (s pgStorage) createRelationshipGraphRecord(ec QueryExecute, ctx context.Context, userID, actorID uuid.UUID, activity Payload, relationshipType RelationshipType, relationshipStatus RelationshipStatus) error {
fields := strings.Join(relationshipGraphFields, ",")
valuesPlaceholder := strings.Join(common.DollarForEach(len(relationshipGraphFields)), ",")
query := fmt.Sprintf(`INSERT INTO network_graph (%s) VALUES (%s) ON CONFLICT ON CONSTRAINT relationship_graph_user_actor_rel DO UPDATE SET activity = $4, relationship_status = $6`, fields, valuesPlaceholder)
rowID := NewV4()
now := s.now()
_, err := s.db.ExecContext(ctx, query, rowID, userID, actorID, activity, relationshipType, relationshipStatus, now, now)
return errors.WrapInsertQueryFailedError(err)
}
func (s pgStorage) CreatePendingFollower(ctx context.Context, userID uuid.UUID, actor, activity string) error {
_, err := s.db.ExecContext(ctx, "INSERT INTO followers (id, user_id, actor, request_activity, status, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $6) ON CONFLICT ON CONSTRAINT followers_user_actor DO UPDATE SET request_activity = $4", NewV4(), userID, actor, activity, PendingFollowerStatus, s.now())
return errors.WrapInsertQueryFailedError(err)
func (s pgStorage) CreatePendingFollowing(ctx context.Context, userID, actorID uuid.UUID, activity Payload) error {
return s.createRelationshipGraphRecord(s.db, ctx, userID, actorID, activity, UserFollowsRelationship, PendingRelationshipStatus)
}
func (s pgStorage) ActivityForFollowing(ctx context.Context, userID uuid.UUID, actor string) (Payload, error) {
var activity string
err := s.db.
QueryRowContext(ctx, `SELECT request_activity FROM following WHERE user_id = $1 AND actor = $2`, userID, actor).
Scan(&activity)
func (s pgStorage) CreatePendingFollower(ctx context.Context, userID, actorID uuid.UUID, activity Payload) error {
return s.createRelationshipGraphRecord(s.db, ctx, userID, actorID, activity, UserFollowedByRelationship, PendingRelationshipStatus)
}
func (s pgStorage) ActivityForFollowing(ctx context.Context, userID, actorID uuid.UUID) (Payload, error) {
return s.networkGraphActivity(s.db, ctx, userID, actorID, UserFollowsRelationship)
}
func (s pgStorage) ActivityForFollower(ctx context.Context, userID, actorID uuid.UUID) (Payload, error) {
return s.networkGraphActivity(s.db, ctx, userID, actorID, UserFollowedByRelationship)
}
func (s pgStorage) networkGraphActivity(qc QueryContext, ctx context.Context, userID, actorID uuid.UUID, relationshipType RelationshipType) (Payload, error) {
var payload Payload
err := qc.QueryRowContext(ctx, `SELECT activity FROM network_graph WHERE user_id = $1 AND actor_id = $2 AND relationship_type = $3`, userID, actorID, relationshipType).
Scan(&payload)
if err != nil {
return nil, errors.WrapQueryFailedError(err)
}
return PayloadFromString(activity)
return payload, nil
}
func (s pgStorage) ActivityForFollower(ctx context.Context, userID uuid.UUID, actor string) (Payload, error) {
var activity string
err := s.db.
QueryRowContext(ctx, `SELECT request_activity FROM followers WHERE user_id = $1 AND actor = $2`, userID, actor).
Scan(&activity)
if err != nil {
return nil, errors.WrapQueryFailedError(err)
}
return PayloadFromString(activity)
}
func (s pgStorage) UpdateFollowingAccepted(ctx context.Context, userID uuid.UUID, actor string) error {
_, err := s.db.ExecContext(ctx, `UPDATE following SET status = $3 WHERE user_id = $1 AND actor = $2`, userID, actor, AcceptedFollowingStatus)
func (s pgStorage) UpdateFollowingAccepted(ctx context.Context, userID, actorID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, `UPDATE network_graph SET relationship_status = $4 WHERE user_id = $1 AND actor_id = $2 AND relationship_type = $3`, userID, actorID, UserFollowedByRelationship, AcceptRelationshipStatus)
return errors.WrapUpdateQueryFailedError(err)
}
func (s pgStorage) UpdateFollowingRejected(ctx context.Context, userID uuid.UUID, actor string) error {
_, err := s.db.ExecContext(ctx, `UPDATE following SET status = $3 WHERE user_id = $1 AND actor = $2`, userID, actor, RejectedFollowingStatus)
func (s pgStorage) UpdateFollowingRejected(ctx context.Context, userID, actorID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, `UPDATE network_graph SET relationship_status = $4 WHERE user_id = $1 AND actor_id = $2 AND relationship_type = $3`, userID, actorID, UserFollowedByRelationship, RejectRelationshipStatus)
return errors.WrapUpdateQueryFailedError(err)
}
func (s pgStorage) UpdateFollowerApproved(ctx context.Context, userID uuid.UUID, actor string) error {
_, err := s.db.ExecContext(ctx, `UPDATE followers SET status = $3 WHERE user_id = $1 AND actor = $2`, userID, actor, ApprovedFollowerStatus)
func (s pgStorage) UpdateFollowerApproved(ctx context.Context, userID, actorID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, `UPDATE network_graph SET relationship_status = $4 WHERE user_id = $1 AND actor_id = $2 AND relationship_type = $3`, userID, actorID, UserFollowsRelationship, AcceptRelationshipStatus)
return errors.WrapUpdateQueryFailedError(err)
}
func (s pgStorage) RemoveFollowing(ctx context.Context, userID uuid.UUID, actor string) error {
_, err := s.db.ExecContext(ctx, `DELETE FROM following WHERE user_id = $1 AND actor = $2`, userID, actor)
func (s pgStorage) RemoveFollowing(ctx context.Context, userID, actorID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, `DELETE FROM network_graph WHERE user_id = $1 AND actor_id = $2 AND relationship_type = $3`, userID, actorID, UserFollowsRelationship)
return errors.WrapDeleteQueryFailedError(err)
}
func (s pgStorage) RemoveFollower(ctx context.Context, userID uuid.UUID, actor string) error {
_, err := s.db.ExecContext(ctx, `DELETE FROM followers WHERE user_id = $1 AND actor = $2`, userID, actor)
func (s pgStorage) RemoveFollower(ctx context.Context, userID, actorID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, `DELETE FROM network_graph WHERE user_id = $1 AND actor_id = $2 AND relationship_type = $3`, userID, actorID, UserFollowedByRelationship)
return errors.WrapDeleteQueryFailedError(err)
}
func (s pgStorage) IsFollowing(ctx context.Context, userID uuid.UUID, actor string) (bool, error) {
c, err := s.RowCount(ctx, `SELECT COUNT(*) FROM following WHERE user_id = $1 AND actor = $2`, userID, actor)
func (s pgStorage) IsFollowing(ctx context.Context, userID, actorID uuid.UUID) (bool, error) {
c, err := s.RowCount(ctx, `SELECT COUNT(*) FROM network_graph WHERE user_id = $1 AND actor_id = $2 AND relationship_type = $3`, userID, actorID, UserFollowsRelationship)
return c == 1, err
}
func (s pgStorage) IsFollower(ctx context.Context, userID uuid.UUID, actor string) (bool, error) {
c, err := s.RowCount(ctx, `SELECT COUNT(*) FROM followers WHERE user_id = $1 AND actor = $2`, userID, actor)
func (s pgStorage) IsFollower(ctx context.Context, userID, actorID uuid.UUID) (bool, error) {
c, err := s.RowCount(ctx, `SELECT COUNT(*) FROM network_graph WHERE user_id = $1 AND actor_id = $2 AND relationship_type = $3`, userID, actorID, UserFollowedByRelationship)
return c == 1, err
}

View File

@ -2,7 +2,9 @@ package storage
import (
"bytes"
"database/sql/driver"
"encoding/json"
"errors"
"io"
"strings"
)
@ -21,6 +23,19 @@ func (p Payload) Bytes() []byte {
return buf.Bytes()
}
func (p Payload) Value() (driver.Value, error) {
return json.Marshal(p)
}
func (p *Payload) Scan(value interface{}) error {
b, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(b, &p)
}
func EmptyPayload() Payload {
return make(map[string]interface{})
}

View File

@ -34,6 +34,7 @@ type QueryRowCount interface {
type QueryContext interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
func DefaultStorage(db *sql.DB) Storage {

View File

@ -46,18 +46,7 @@ type User struct {
DisplayName string
About string
AcceptFollowers bool
}
func (u *User) GetID() string {
return ""
}
func (u *User) GetInbox() string {
return ""
}
func (u *User) GetKeyID() string {
return ""
ActorID uuid.UUID
}
func (u *User) GetPrivateKey() (*rsa.PrivateKey, error) {
@ -75,20 +64,7 @@ func (u *User) GetPrivateKey() (*rsa.PrivateKey, error) {
}
func (u *User) GetDecodedPublicKey() (*rsa.PublicKey, error) {
block, _ := pem.Decode([]byte(u.PublicKey))
if block == nil {
return nil, fmt.Errorf("invalid RSA PEM")
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("invalid RSA PEM")
}
rsaPublicKey, ok := pub.(*rsa.PublicKey)
if !ok {
return nil, fmt.Errorf("invalid RSA PEM")
}
return rsaPublicKey, nil
return DecodePublicKey(u.PublicKey)
}
var userFields = []string{
@ -106,6 +82,7 @@ var userFields = []string{
"display_name",
"about",
"accept_followers",
"actor_id",
}
func (s pgStorage) GetUserBySession(ctx context.Context, session sessions.Session) (*User, error) {
@ -137,13 +114,15 @@ func (s pgStorage) GetUserWithQuery(ctx context.Context, query string, args ...i
&user.UpdatedAt,
&user.LastAuthAt,
&user.Location,
&user.MuteEmail, &user.Locale,
&user.MuteEmail,
&user.Locale,
&user.PublicKey,
&user.PrivateKey,
&user.Name,
&user.DisplayName,
&user.About,
&user.AcceptFollowers)
&user.AcceptFollowers,
&user.ActorID)
if err != nil {
if err == sql.ErrNoRows {
return nil, errors.NewUserNotFoundError(err)
@ -154,9 +133,9 @@ func (s pgStorage) GetUserWithQuery(ctx context.Context, query string, args ...i
return user, nil
}
func (s pgStorage) CreateUser(ctx context.Context, userID uuid.UUID, email, locale, name, displayName, about, publicKey, privateKey string, password []byte) error {
func (s pgStorage) CreateUser(ctx context.Context, actorID uuid.UUID, email, locale, name, displayName, about, publicKey, privateKey string, password []byte) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "INSERT INTO users (id, email, password, created_at, updated_at, last_auth_at, locale, public_key, private_key, name, display_name, about) VALUES ($1, $2, $3, $4, $4, $4, $5, $6, $7, $8, $9, $10)", userID, email, password, now, locale, publicKey, privateKey, name, displayName, about)
_, err := s.db.ExecContext(ctx, "INSERT INTO users (id, email, password, created_at, updated_at, last_auth_at, locale, public_key, private_key, name, display_name, about, actor_id) VALUES ($1, $2, $3, $4, $4, $4, $5, $6, $7, $8, $9, $10, $11)", NewV4(), email, password, now, locale, publicKey, privateKey, name, displayName, about, actorID)
if err != nil {
return errors.NewCreateUserFailedError(err)
}

View File

@ -109,7 +109,13 @@ func (h handler) actorInboxAccept(c *gin.Context, user *storage.User, payload st
return
}
err := h.storage.UpdateFollowingAccepted(c.Request.Context(), user.ID, target)
remoteActorID, err := h.storage.ActorRowIDForActorID(c.Request.Context(), target)
if err != nil {
h.internalServerErrorJSON(c, err)
return
}
err = h.storage.UpdateFollowingAccepted(c.Request.Context(), user.ID, remoteActorID)
if err != nil {
h.internalServerErrorJSON(c, err)
return
@ -140,7 +146,13 @@ func (h handler) actorInboxReject(c *gin.Context, user *storage.User, payload st
return
}
err := h.storage.UpdateFollowingRejected(c.Request.Context(), user.ID, target)
remoteActorID, err := h.storage.ActorRowIDForActorID(c.Request.Context(), target)
if err != nil {
h.internalServerErrorJSON(c, err)
return
}
err = h.storage.UpdateFollowingRejected(c.Request.Context(), user.ID, remoteActorID)
if err != nil {
h.internalServerErrorJSON(c, err)
return
@ -178,7 +190,13 @@ func (h handler) actorInboxFollow(c *gin.Context, user *storage.User, payload st
return
}
err = h.storage.CreatePendingFollower(ctx, user.ID, target, string(body))
remoteActorID, err := h.storage.ActorRowIDForActorID(c.Request.Context(), target)
if err != nil {
h.internalServerErrorJSON(c, err)
return
}
err = h.storage.CreatePendingFollower(ctx, user.ID, remoteActorID, payload)
if err != nil {
h.internalServerErrorJSON(c, err)
return
@ -207,7 +225,7 @@ func (h handler) actorInboxFollow(c *gin.Context, user *storage.User, payload st
return
}
err = h.storage.UpdateFollowerApproved(ctx, user.ID, target)
err = h.storage.UpdateFollowerApproved(ctx, user.ID, remoteActorID)
if err != nil {
h.internalServerErrorJSON(c, err)
return
@ -240,7 +258,13 @@ func (h handler) actorInboxUndoFollow(c *gin.Context, user *storage.User, payloa
return
}
err := h.storage.RemoveFollower(ctx, user.ID, target)
remoteActorID, err := h.storage.ActorRowIDForActorID(c.Request.Context(), target)
if err != nil {
h.internalServerErrorJSON(c, err)
return
}
err = h.storage.RemoveFollower(ctx, user.ID, remoteActorID)
if err != nil {
h.internalServerErrorJSON(c, err)
return
@ -482,7 +506,13 @@ func skipActorInbox(j storage.Payload) bool {
func (h handler) isActivityRelevant(ctx context.Context, activity storage.Payload, user *storage.User) (bool, error) {
actor, ok := storage.JSONString(activity, "actor")
if ok {
isFollowing, err := h.storage.IsFollowing(ctx, user.ID, actor)
remoteActorID, err := h.storage.ActorRowIDForActorID(ctx, actor)
if err != nil {
return false, err
}
isFollowing, err := h.storage.IsFollowing(ctx, user.ID, remoteActorID)
if err != nil {
return false, err
}

View File

@ -8,7 +8,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/ngerakines/tavern/avatar"
"github.com/ngerakines/tavern/errors"
)
func (h handler) avatarSVG(c *gin.Context) {
@ -60,12 +59,13 @@ func (h handler) avatar(c *gin.Context) ([]byte, error) {
return []byte(svg), nil
}
exists, err := h.storage.RowCount(c.Request.Context(), `SELECT COUNT(*) FROM actors WHERE name = $1 AND domain = $2`, name, domain)
exists, err := h.storage.RowCount(c.Request.Context(), `SELECT COUNT(*) FROM actor_aliases WHERE alias = $1 AND alias_type = 0`, fmt.Sprintf("acct:%s@%s", name, domain))
if err != nil {
return nil, err
}
if exists == 0 {
return nil, errors.NewNotFoundError(nil)
svg := avatar.AvatarSVG("unknown", size, false)
return []byte(svg), nil
}
id := fmt.Sprintf("@%s@%s", name, domain)

View File

@ -167,7 +167,7 @@ func (h handler) createNote(c *gin.Context) {
}
}
mentionedActors := make(map[string]storage.Actor)
mentionedActors := make(map[string]*storage.Actor)
var mentionedActorNames []string
@ -186,7 +186,7 @@ func (h handler) createNote(c *gin.Context) {
foundActor, err := fed.GetOrFetchActor(ctx, h.storage, h.logger, common.DefaultHTTPClient(), mentionedActor)
if err == nil {
mentionedActors[mentionedActor] = foundActor
to = append(to, foundActor.GetID())
to = append(to, foundActor.ActorID)
}
}
@ -466,7 +466,7 @@ func (h handler) announceNote(c *gin.Context) {
announce["to"] = to
announce["cc"] = cc
announce["object"] = objectID
announcePayload := announce.Bytes()
// announcePayload := announce.Bytes()
objectEventID, err := h.storage.RecordObjectEvent(ctx, announceID, objectID, string(actor), storage.AnnounceNoteObjectEvent, now, to, cc)
if err != nil {
@ -480,15 +480,15 @@ func (h handler) announceNote(c *gin.Context) {
return
}
nc := fed.ActorClient{
HTTPClient: common.DefaultHTTPClient(),
Logger: h.logger,
}
err = nc.Broadcast(ctx, h.storage, storage.LocalActor{User: user, ActorID: storage.NewActorID(user.Name, h.domain)}, announcePayload)
if err != nil {
h.flashErrorOrFail(c, h.url("feed_recent"), err)
return
}
// nc := fed.ActorClient{
// HTTPClient: common.DefaultHTTPClient(),
// Logger: h.logger,
// }
// err = nc.Broadcast(ctx, h.storage, storage.LocalActor{User: user, ActorID: storage.NewActorID(user.Name, h.domain)}, announcePayload)
// if err != nil {
// h.flashErrorOrFail(c, h.url("feed_recent"), err)
// return
// }
c.Redirect(http.StatusFound, h.url("feed_mine"))
}

View File

@ -124,12 +124,21 @@ func (h handler) viewFeed(c *gin.Context) {
data["latest"] = uf[0].CreatedAt.UTC().Unix()
}
actors, err := h.storage.ActorPayloadsByActorID(ctx, append(actorIDs, vf.actorIDs...))
actors, err := h.storage.ActorsByActorID(ctx, append(actorIDs, vf.actorIDs...))
if err != nil {
h.hardFail(c, err)
return
}
allActors := h.gatherActors(actors)
var actorRowIDs []uuid.UUID
for _, actor := range actors {
actorRowIDs = append(actorRowIDs, actor.ID)
}
actorSubjects, err := h.storage.ActorSubjects(ctx, actorRowIDs)
if err != nil {
h.hardFail(c, err)
return
}
allActors := h.gatherActors(actors, actorSubjects)
var pages []int
for i := page - 3; i <= page+3; i++ {
@ -226,12 +235,21 @@ func (h handler) viewMyFeed(c *gin.Context) {
data["latest"] = uf[0].CreatedAt.UTC().Unix()
}
actors, err := h.storage.ActorPayloadsByActorID(ctx, append(actorIDs, vf.actorIDs...))
actors, err := h.storage.ActorsByActorID(ctx, append(actorIDs, vf.actorIDs...))
if err != nil {
h.hardFail(c, err)
return
}
allActors := h.gatherActors(actors)
var actorRowIDs []uuid.UUID
for _, actor := range actors {
actorRowIDs = append(actorRowIDs, actor.ID)
}
actorSubjects, err := h.storage.ActorSubjects(ctx, actorRowIDs)
if err != nil {
h.hardFail(c, err)
return
}
allActors := h.gatherActors(actors, actorSubjects)
var pages []int
for i := page - 3; i <= page+3; i++ {
@ -302,12 +320,21 @@ func (h handler) viewConversation(c *gin.Context) {
}
data["announcements"] = announcements
actors, err := h.storage.ActorPayloadsByActorID(ctx, vf.actorIDs)
actors, err := h.storage.ActorsByActorID(ctx, vf.actorIDs)
if err != nil {
h.hardFail(c, err)
return
}
allActors := h.gatherActors(actors)
var actorRowIDs []uuid.UUID
for _, actor := range actors {
actorRowIDs = append(actorRowIDs, actor.ID)
}
actorSubjects, err := h.storage.ActorSubjects(ctx, actorRowIDs)
if err != nil {
h.hardFail(c, err)
return
}
allActors := h.gatherActors(actors, actorSubjects)
data["actors"] = actorLookup{h.domain, allActors}
@ -326,35 +353,35 @@ func (h handler) viewConversation(c *gin.Context) {
c.HTML(http.StatusOK, "feed", data)
}
func (h handler) gatherActors(actors []storage.Payload) map[string]map[string]string {
func (h handler) gatherActors(actors []*storage.Actor, actorSubjects []storage.ActorAlias) map[string]map[string]string {
results := make(map[string]map[string]string)
subjects := storage.CollectActorSubjectsActorToSubject(actorSubjects)
for _, actor := range actors {
actorID, _ := storage.JSONString(actor, "id")
summary := make(map[string]string)
subject, hasSubject := subjects[actor.ID]
if hasSubject {
trimmed := strings.TrimPrefix(subject, "acct:")
subjectParts := strings.Split(trimmed, "@")
if len(subjectParts) == 2 {
summary["at"] = trimmed
summary["icon"] = fmt.Sprintf("https://%s/avatar/png/%s/%s", h.domain, subjectParts[1], subjectParts[0])
results[actor.ActorID] = summary
continue
}
}
actorID := actor.ActorID
u, err := url.Parse(actorID)
if err != nil {
continue
}
domain := u.Hostname()
summary := make(map[string]string)
preferredUsername, ok := storage.JSONString(actor, "preferredUsername")
if !ok {
preferredUsername = "unknown"
}
name, hasName := storage.JSONString(actor, "name")
if !hasName {
name = preferredUsername
}
summary["preferred_username"] = preferredUsername
summary["name"] = name
summary["domain"] = domain
summary["at"] = fmt.Sprintf("%s@%s", preferredUsername, domain)
summary["icon"] = fmt.Sprintf("https://%s/avatar/png/%s/%s", h.domain, domain, preferredUsername)
summary["at"] = fmt.Sprintf("%s@%s", actor.Name, domain)
summary["icon"] = fmt.Sprintf("https://%s/avatar/png/%s/%s", h.domain, domain, actor.Name)
results[actorID] = summary
}

View File

@ -110,7 +110,7 @@ func (h handler) networkFollow(c *gin.Context) {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
}
isFollowing, err := h.storage.IsFollowing(ctx, user.ID, actor.GetID())
isFollowing, err := h.storage.IsFollowing(ctx, user.ID, actor.ID)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
@ -141,7 +141,7 @@ func (h handler) networkFollow(c *gin.Context) {
return
}
err = h.storage.CreatePendingFollowing(ctx, user.ID, actor.GetID(), string(payload))
err = h.storage.CreatePendingFollowing(ctx, user.ID, actor.ID, follow)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
@ -169,7 +169,13 @@ func (h handler) networkUnfollow(c *gin.Context) {
return
}
target := c.PostForm("actor")
requestActivity, err := h.storage.ActivityForFollowing(ctx, user.ID, target)
targetActor, err := fed.GetOrFetchActor(ctx, h.storage, h.logger, common.DefaultHTTPClient(), target)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
}
requestActivity, err := h.storage.ActivityForFollowing(ctx, user.ID, targetActor.ID)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
@ -200,7 +206,7 @@ func (h handler) networkUnfollow(c *gin.Context) {
return
}
err = h.storage.RemoveFollowing(ctx, user.ID, target)
err = h.storage.RemoveFollowing(ctx, user.ID, targetActor.ID)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
@ -246,7 +252,7 @@ func (h handler) networkAccept(c *gin.Context) {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
}
activity, err := h.storage.ActivityForFollower(ctx, user.ID, targetActorID)
activity, err := h.storage.ActivityForFollower(ctx, user.ID, followerActor.ID)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
@ -273,7 +279,7 @@ func (h handler) networkAccept(c *gin.Context) {
return
}
err = h.storage.UpdateFollowerApproved(ctx, user.ID, targetActorID)
err = h.storage.UpdateFollowerApproved(ctx, user.ID, followerActor.ID)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
@ -309,7 +315,7 @@ func (h handler) networkReject(c *gin.Context) {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
}
activity, err := h.storage.ActivityForFollower(ctx, user.ID, targetActorID)
activity, err := h.storage.ActivityForFollower(ctx, user.ID, followerActor.ID)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return
@ -336,7 +342,7 @@ func (h handler) networkReject(c *gin.Context) {
return
}
err = h.storage.RemoveFollower(ctx, user.ID, targetActorID)
err = h.storage.RemoveFollower(ctx, user.ID, followerActor.ID)
if err != nil {
h.flashErrorOrFail(c, "/dashboard/network", err)
return

3
web/view_actor.go Normal file
View File

@ -0,0 +1,3 @@
package web