tavern/storage/user.go

177 lines
5.5 KiB
Go

package storage
import (
"context"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/pem"
"fmt"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gofrs/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/ngerakines/tavern/errors"
)
// UserStorage persists local user accounts and their per-user settings.
type UserStorage interface {
	// AuthenticateUser checks password against the stored hash of the user
	// registered with email and returns that user's id on success.
	AuthenticateUser(ctx context.Context, email string, password []byte) (uuid.UUID, error)
	// CreateUser inserts a new user. actorID is recorded as the user's actor
	// id (the pgStorage implementation generates a separate id for the user
	// row itself). password is stored as given; AuthenticateUser compares it
	// with bcrypt, so callers should pass a bcrypt hash.
	CreateUser(ctx context.Context, actorID uuid.UUID, email, locale, name, displayName, about, publicKey, privateKey string, password []byte) error
	// GetUser loads the user with the given id.
	GetUser(ctx context.Context, userID uuid.UUID) (*User, error)
	// GetUserByName loads the user with the given name.
	GetUserByName(ctx context.Context, name string) (*User, error)
	// GetUserBySession loads the user whose id is recorded in session.
	GetUserBySession(ctx context.Context, session sessions.Session) (*User, error)
	// UpdateUserLastAuth records the current time as the user's last
	// authentication time.
	UpdateUserLastAuth(ctx context.Context, userID uuid.UUID) error
	// UpdateUserAutoAcceptFollowers sets whether new followers are accepted
	// automatically.
	UpdateUserAutoAcceptFollowers(ctx context.Context, userID uuid.UUID, value bool) error
	// UpdateUserReplyCollectionUpdates sets the user's
	// reply_collection_updates flag.
	UpdateUserReplyCollectionUpdates(ctx context.Context, userID uuid.UUID, value bool) error
}
// User is a row of the users table as loaded by this package.
type User struct {
// ID is the users.id column; CreateUser generates it with NewV4.
ID uuid.UUID
Email string
// Password is not populated by GetUser, GetUserByName or GetUserBySession:
// userFields does not select the password column.
Password []byte
CreatedAt time.Time
UpdatedAt time.Time
// LastAuthAt is refreshed by UpdateUserLastAuth.
LastAuthAt time.Time
Location string
MuteEmail bool
Locale string
// PrivateKey holds a PEM-encoded RSA private key; see GetPrivateKey.
PrivateKey string
// PublicKey is decoded with DecodePublicKey; see GetDecodedPublicKey.
PublicKey string
Name string
DisplayName string
About string
// AcceptFollowers is the accept_followers column, changed by
// UpdateUserAutoAcceptFollowers.
AcceptFollowers bool
// ActorID is the actor_id column; CreateUser stores its uuid argument here.
ActorID uuid.UUID
// ReplyCollectionUpdates is changed by UpdateUserReplyCollectionUpdates.
ReplyCollectionUpdates bool
}
// GetPrivateKey decodes the user's PEM-encoded RSA private key.
//
// The DER bytes of the first PEM block in PrivateKey are parsed as PKCS #1
// ("RSA PRIVATE KEY"). If that fails they are parsed as PKCS #8
// ("PRIVATE KEY"), which must then hold an RSA key.
//
// It returns an error in three cases:
//   - PrivateKey contains no PEM block.
//   - Neither encoding parses; the PKCS #1 parse error is returned, as before.
//   - A PKCS #8 key is not RSA.
func (u *User) GetPrivateKey() (*rsa.PrivateKey, error) {
	block, _ := pem.Decode([]byte(u.PrivateKey))
	if block == nil {
		return nil, errors.New("invalid RSA PEM")
	}
	key, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if pkcs1Err == nil {
		return key, nil
	}
	// Fall back to PKCS #8 so that RSA keys exported in that container format
	// are accepted as well.
	parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if pkcs8Err != nil {
		// Keep the original PKCS #1 error for callers that only store PKCS #1.
		return nil, pkcs1Err
	}
	rsaKey, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		return nil, errors.New("private key is not an RSA key")
	}
	return rsaKey, nil
}
// GetDecodedPublicKey parses the user's stored PublicKey text with
// DecodePublicKey and returns whatever key and error it yields.
func (u *User) GetDecodedPublicKey() (*rsa.PublicKey, error) {
	encoded := u.PublicKey
	pub, err := DecodePublicKey(encoded)
	return pub, err
}
// userFields lists, in order, the users columns selected by GetUser and
// GetUserByName. The order must match the Scan destinations in
// GetUserWithQuery. The password column is not included.
var userFields = []string{
	"id", "email",
	"created_at", "updated_at", "last_auth_at",
	"location", "mute_email", "locale",
	"public_key", "private_key",
	"name", "display_name", "about",
	"accept_followers", "actor_id", "reply_collection_updates",
}
// GetUserBySession resolves the user id stored in session via
// userIDFromSession and loads that user with GetUser. A failure to read the
// id from the session is returned unchanged.
func (s pgStorage) GetUserBySession(ctx context.Context, session sessions.Session) (*User, error) {
	sessionUserID, lookupErr := userIDFromSession(session)
	if lookupErr == nil {
		return s.GetUser(ctx, sessionUserID)
	}
	return nil, lookupErr
}
// GetUserByName loads the user whose name column equals name, selecting the
// userFields columns. Errors, including not-found, come from
// GetUserWithQuery.
func (s pgStorage) GetUserByName(ctx context.Context, name string) (*User, error) {
	columns := strings.Join(userFields, ", ")
	byName := fmt.Sprintf("SELECT %s FROM users WHERE name = $1", columns)
	return s.GetUserWithQuery(ctx, byName, name)
}
// GetUser loads the user whose id column equals userID, selecting the
// userFields columns. Errors, including not-found, come from
// GetUserWithQuery.
func (s pgStorage) GetUser(ctx context.Context, userID uuid.UUID) (*User, error) {
	columns := strings.Join(userFields, ", ")
	byID := fmt.Sprintf("SELECT %s FROM users WHERE id = $1", columns)
	return s.GetUserWithQuery(ctx, byID, userID)
}
func (s pgStorage) GetUserWithQuery(ctx context.Context, query string, args ...interface{}) (*User, error) {
user := &User{}
err := s.db.
QueryRowContext(ctx, query, args...).
Scan(&user.ID,
&user.Email,
&user.CreatedAt,
&user.UpdatedAt,
&user.LastAuthAt,
&user.Location,
&user.MuteEmail,
&user.Locale,
&user.PublicKey,
&user.PrivateKey,
&user.Name,
&user.DisplayName,
&user.About,
&user.AcceptFollowers,
&user.ActorID,
&user.ReplyCollectionUpdates)
if err != nil {
if err == sql.ErrNoRows {
return nil, errors.NewUserNotFoundError(errors.NewUserQueryFailedError(err))
}
return nil, errors.NewUserQueryFailedError(err)
}
return user, nil
}
func (s pgStorage) CreateUser(ctx context.Context, actorID uuid.UUID, email, locale, name, displayName, about, publicKey, privateKey string, password []byte) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "INSERT INTO users (id, email, password, created_at, updated_at, last_auth_at, locale, public_key, private_key, name, display_name, about, actor_id) VALUES ($1, $2, $3, $4, $4, $4, $5, $6, $7, $8, $9, $10, $11)", NewV4(), email, password, now, locale, publicKey, privateKey, name, displayName, about, actorID)
return errors.WrapUserInsertFailedError(err)
}
// AuthenticateUser looks up the user registered with email and checks
// password against the stored bcrypt hash, returning the user's id on
// success.
//
// An unknown email and a wrong password both yield a user-not-found error.
// Both paths also perform one bcrypt operation, so neither the error nor the
// response time reveals whether the email is registered. Other query
// failures are returned as user-query-failed errors.
func (s pgStorage) AuthenticateUser(ctx context.Context, email string, password []byte) (uuid.UUID, error) {
	var userID uuid.UUID
	var currentPassword []byte
	err := s.db.QueryRowContext(ctx, "SELECT id, password FROM users WHERE email = $1", email).Scan(&userID, &currentPassword)
	if err != nil {
		if err == sql.ErrNoRows {
			// Without this, an unknown email returns immediately while a known
			// one pays for a bcrypt comparison, which lets an attacker
			// enumerate accounts by timing. One bcrypt derivation at the same
			// cost takes about as long as CompareHashAndPassword. The result is
			// intentionally discarded. Assumes stored hashes use
			// bcrypt.DefaultCost — TODO confirm against the hashing call site.
			_, _ = bcrypt.GenerateFromPassword([]byte("tavern-unknown-user"), bcrypt.DefaultCost)
			return uuid.Nil, errors.NewUserNotFoundError(err)
		}
		return uuid.Nil, errors.NewUserQueryFailedError(err)
	}
	if cmpErr := bcrypt.CompareHashAndPassword(currentPassword, password); cmpErr != nil {
		return uuid.Nil, errors.NewUserNotFoundError(cmpErr)
	}
	return userID, nil
}
func (s pgStorage) UpdateUserLastAuth(ctx context.Context, userID uuid.UUID) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "UPDATE users SET last_auth_at = $2, updated_at = $2 WHERE id = $1", userID, now)
return errors.WrapUserUpdateFailedError(err)
}
func (s pgStorage) UpdateUserAutoAcceptFollowers(ctx context.Context, userID uuid.UUID, value bool) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "UPDATE users SET accept_followers = $3, updated_at = $2 WHERE id = $1", userID, now, value)
return errors.WrapUserUpdateFailedError(err)
}
func (s pgStorage) UpdateUserReplyCollectionUpdates(ctx context.Context, userID uuid.UUID, value bool) error {
now := s.now()
_, err := s.db.ExecContext(ctx, "UPDATE users SET reply_collection_updates = $3, updated_at = $2 WHERE id = $1", userID, now, value)
return errors.WrapUserUpdateFailedError(err)
}