mirror of https://gitlab.com/ngerakines/tavern.git
264 lines
6.3 KiB
Go
264 lines
6.3 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"time"
|
|
|
|
"github.com/gofrs/uuid"
|
|
_ "github.com/lib/pq"
|
|
|
|
"github.com/ngerakines/tavern/errors"
|
|
)
|
|
|
|
type Storage interface {
|
|
RowCount(ctx context.Context, query string, args ...interface{}) (int, error)
|
|
|
|
UserStorage
|
|
ActorStorage
|
|
FollowerStorage
|
|
PeerStorage
|
|
ObjectStorage
|
|
AssetStorage
|
|
InstanceStatsStorage
|
|
GroupStorage
|
|
ACLStorage
|
|
|
|
GetExecutor() QueryExecute
|
|
}
|
|
|
|
// Count pairs a string key with an integer tally, as produced by grouped
// count queries.
type Count struct {
	// Key identifies the group being counted.
	Key string
	// Count is the number associated with Key.
	Count int
}
|
|
|
|
func DefaultStorage(db QueryExecute) Storage {
|
|
return pgStorage{
|
|
db: db,
|
|
now: defaultNowFunc,
|
|
}
|
|
}
|
|
|
|
type pgStorage struct {
|
|
db QueryExecute
|
|
now nowFunc
|
|
}
|
|
|
|
type transactionScopedWork func(db QueryExecute) error
|
|
type transactionScopedStorage func(storage Storage) error
|
|
type nowFunc func() time.Time
|
|
|
|
// defaultNowFunc reports the current wall-clock time normalised to UTC.
func defaultNowFunc() time.Time {
	current := time.Now()
	return current.UTC()
}
|
|
|
|
// deprecated
|
|
func runTransactionWithOptions(db QueryExecute, txBody transactionScopedWork) error {
|
|
return txBody(db)
|
|
}
|
|
|
|
func TransactionalStorage(ctx context.Context, storage Storage, txBody transactionScopedStorage) error {
|
|
if _, ok := storage.GetExecutor().(TransactionSQLDriver); ok {
|
|
return txBody(storage)
|
|
}
|
|
|
|
executor := storage.GetExecutor()
|
|
realDB, ok := executor.(*sql.DB)
|
|
if !ok {
|
|
return txBody(storage)
|
|
}
|
|
|
|
tx, err := realDB.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return errors.NewDatabaseTransactionFailedError(err)
|
|
}
|
|
|
|
err = txBody(pgStorage{
|
|
db: TransactionSQLDriver{Driver: tx},
|
|
now: defaultNowFunc,
|
|
})
|
|
if err != nil {
|
|
if txErr := tx.Rollback(); txErr != nil {
|
|
return errors.NewDatabaseTransactionFailedError(txErr)
|
|
}
|
|
return err
|
|
}
|
|
return errors.WrapDatabaseTransactionFailedError(tx.Commit())
|
|
}
|
|
|
|
func (s pgStorage) GetExecutor() QueryExecute {
|
|
return s.db
|
|
}
|
|
|
|
func (s pgStorage) RowCount(ctx context.Context, query string, args ...interface{}) (int, error) {
|
|
return s.wrappedRowCount(errors.WrapQueryFailedError, ctx, query, args...)
|
|
}
|
|
|
|
func (s pgStorage) rowCount(ctx context.Context, query string, args ...interface{}) (int, error) {
|
|
return s.wrappedRowCount(errors.WrapQueryFailedError, ctx, query, args...)
|
|
}
|
|
|
|
func (s pgStorage) wrappedRowCount(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (int, error) {
|
|
var total int
|
|
|
|
err := s.db.QueryRowContext(ctx, query, args...).Scan(&total)
|
|
if err != nil {
|
|
return -1, ew(err)
|
|
}
|
|
return total, nil
|
|
}
|
|
|
|
func (s pgStorage) wrappedExists(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (bool, error) {
|
|
var total int
|
|
err := s.db.QueryRowContext(ctx, query, args...).Scan(&total)
|
|
if err != nil {
|
|
return false, ew(err)
|
|
}
|
|
return total != 0, nil
|
|
}
|
|
|
|
func (s pgStorage) wrappedSelectUUID(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (uuid.UUID, error) {
|
|
var id uuid.UUID
|
|
err := s.db.QueryRowContext(ctx, query, args...).Scan(&id)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return uuid.Nil, ew(errors.NewNotFoundError(err))
|
|
}
|
|
return uuid.Nil, ew(err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (s pgStorage) wrappedSelectInt(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (int, error) {
|
|
var value int
|
|
err := s.db.QueryRowContext(ctx, query, args...).Scan(&value)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return -1, ew(errors.NewNotFoundError(err))
|
|
}
|
|
return -1, ew(err)
|
|
}
|
|
return value, nil
|
|
}
|
|
|
|
func (s pgStorage) keyedCount(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) ([]Count, error) {
|
|
results := make([]Count, 0)
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var count Count
|
|
if err := rows.Scan(&count.Key, &count.Count); err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
results = append(results, count)
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (s pgStorage) keysToUUID(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (map[string]uuid.UUID, error) {
|
|
results := make(map[string]uuid.UUID)
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var key string
|
|
var id uuid.UUID
|
|
if err := rows.Scan(&key, &id); err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
results[key] = id
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (s pgStorage) uuidsToPayload(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (map[uuid.UUID]Payload, error) {
|
|
results := make(map[uuid.UUID]Payload)
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var key uuid.UUID
|
|
var payload Payload
|
|
if err := rows.Scan(&key, &payload); err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
results[key] = payload
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (s pgStorage) uuidsToUuids(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (map[uuid.UUID]uuid.UUID, error) {
|
|
results := make(map[uuid.UUID]uuid.UUID)
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var key uuid.UUID
|
|
var value uuid.UUID
|
|
if err := rows.Scan(&key, &value); err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
results[key] = value
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (s pgStorage) toUUIDMultiMap(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (map[uuid.UUID][]uuid.UUID, error) {
|
|
results := make(map[uuid.UUID][]uuid.UUID)
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var key uuid.UUID
|
|
var value uuid.UUID
|
|
if err := rows.Scan(&key, &value); err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
values, hasValues := results[key]
|
|
if !hasValues {
|
|
values = make([]uuid.UUID, 0)
|
|
}
|
|
results[key] = append(values, value)
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (s pgStorage) selectStrings(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) ([]string, error) {
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var results []string
|
|
|
|
for rows.Next() {
|
|
var value string
|
|
if err := rows.Scan(&value); err != nil {
|
|
return nil, ew(err)
|
|
}
|
|
results = append(results, value)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func CountMap(counts []Count) map[string]int {
|
|
results := make(map[string]int)
|
|
for _, c := range counts {
|
|
results[c.Key] = c.Count
|
|
}
|
|
return results
|
|
}
|