// tavern/storage/storage.go (264 lines, 6.3 KiB, Go)

package storage
import (
"context"
"database/sql"
"time"
"github.com/gofrs/uuid"
_ "github.com/lib/pq"
"github.com/ngerakines/tavern/errors"
)
// Storage is the persistence API for the application. It aggregates the
// per-domain storage interfaces with a couple of generic helpers.
type Storage interface {
	UserStorage
	ActorStorage
	FollowerStorage
	PeerStorage
	ObjectStorage
	AssetStorage
	InstanceStatsStorage
	GroupStorage
	ACLStorage

	// RowCount executes query with args and returns the single integer
	// value it yields.
	RowCount(ctx context.Context, query string, args ...interface{}) (int, error)

	// GetExecutor returns the QueryExecute that queries are issued through.
	// TransactionalStorage inspects it to decide whether to open a new
	// transaction.
	GetExecutor() QueryExecute
}
// Count pairs a string key with an integer tally. keyedCount produces one
// Count per result row, and CountMap indexes a slice of them by Key.
type Count struct {
	// Key is read from the first column of a keyed-count query row.
	Key string
	// Count is read from the second column of a keyed-count query row.
	Count int
}
// DefaultStorage returns a Storage that issues its queries through db and
// takes the current time from defaultNowFunc (UTC wall clock).
func DefaultStorage(db QueryExecute) Storage {
	var s pgStorage
	s.db = db
	s.now = defaultNowFunc
	return s
}
// pgStorage is the Storage implementation returned by DefaultStorage.
// Presumably it targets PostgreSQL (the lib/pq driver is imported by this
// file).
type pgStorage struct {
	// db runs every query. It is whatever DefaultStorage was given, or a
	// TransactionSQLDriver wrapping a *sql.Tx inside TransactionalStorage.
	db QueryExecute
	// now supplies the current time. DefaultStorage sets it to
	// defaultNowFunc.
	now nowFunc
}
type (
	// transactionScopedWork is a unit of work run against a query executor.
	transactionScopedWork func(db QueryExecute) error

	// transactionScopedStorage is a unit of work run against a Storage; it is
	// the callback type accepted by TransactionalStorage.
	transactionScopedStorage func(storage Storage) error

	// nowFunc reports the current time.
	nowFunc func() time.Time
)
// defaultNowFunc reports the current wall-clock time converted to UTC.
func defaultNowFunc() time.Time {
	current := time.Now()
	return current.UTC()
}
// runTransactionWithOptions invokes txBody with db and returns its error.
// Despite its name it neither opens a transaction nor takes any options.
//
// Deprecated: call txBody directly, or use TransactionalStorage when the work
// must run inside a database transaction.
func runTransactionWithOptions(db QueryExecute, txBody transactionScopedWork) error {
	return txBody(db)
}
// TransactionalStorage runs txBody with a Storage whose queries share a single
// database transaction.
//
// If storage's executor is already a TransactionSQLDriver, txBody receives
// storage unchanged, so its work joins the enclosing transaction. If the
// executor is not a *sql.DB, no transaction can be started here, and txBody
// again receives storage unchanged.
//
// Otherwise a transaction is begun with ctx, and txBody receives a pgStorage
// bound to it:
//   - If txBody returns nil, the transaction is committed, and the commit
//     result goes through errors.WrapDatabaseTransactionFailedError.
//   - If txBody returns an error, the transaction is rolled back and that
//     error is returned, unless the rollback itself fails.
//   - If txBody panics, the transaction is rolled back before the panic
//     propagates.
//
// Failures to begin or roll back are reported via
// errors.NewDatabaseTransactionFailedError.
func TransactionalStorage(ctx context.Context, storage Storage, txBody transactionScopedStorage) error {
	executor := storage.GetExecutor()
	if _, ok := executor.(TransactionSQLDriver); ok {
		return txBody(storage)
	}
	realDB, ok := executor.(*sql.DB)
	if !ok {
		return txBody(storage)
	}
	tx, err := realDB.BeginTx(ctx, nil)
	if err != nil {
		return errors.NewDatabaseTransactionFailedError(err)
	}
	// Once the transaction has been committed or rolled back, Rollback is a
	// no-op returning sql.ErrTxDone. This defer therefore only does real work
	// when txBody panics. In that case it releases the transaction and the
	// pooled connection it holds. The error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// Preserve the caller's clock when storage is a pgStorage; otherwise fall
	// back to the default UTC clock.
	now := nowFunc(defaultNowFunc)
	if pg, isPG := storage.(pgStorage); isPG && pg.now != nil {
		now = pg.now
	}

	err = txBody(pgStorage{
		db:  TransactionSQLDriver{Driver: tx},
		now: now,
	})
	if err != nil {
		// database/sql rolls the transaction back by itself when ctx is
		// cancelled. Rollback then returns sql.ErrTxDone, which must not mask
		// the error txBody actually returned.
		if txErr := tx.Rollback(); txErr != nil && txErr != sql.ErrTxDone {
			return errors.NewDatabaseTransactionFailedError(txErr)
		}
		return err
	}
	return errors.WrapDatabaseTransactionFailedError(tx.Commit())
}
// GetExecutor returns the QueryExecute this storage issues its queries
// through.
func (s pgStorage) GetExecutor() QueryExecute {
	executor := s.db
	return executor
}
// RowCount executes query with args and returns the integer it yields. On
// failure it returns -1 and an error wrapped by errors.WrapQueryFailedError.
func (s pgStorage) RowCount(ctx context.Context, query string, args ...interface{}) (int, error) {
	wrap := errors.WrapQueryFailedError
	return s.wrappedRowCount(wrap, ctx, query, args...)
}
// rowCount is the package-internal spelling of RowCount and behaves
// identically: its errors are wrapped by errors.WrapQueryFailedError.
func (s pgStorage) rowCount(ctx context.Context, query string, args ...interface{}) (int, error) {
	return s.RowCount(ctx, query, args...)
}
// wrappedRowCount executes query, which must yield one row with a single
// integer column, and returns that value. Any error, including
// sql.ErrNoRows, is passed through ew and paired with -1.
func (s pgStorage) wrappedRowCount(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (int, error) {
	row := s.db.QueryRowContext(ctx, query, args...)
	var n int
	if scanErr := row.Scan(&n); scanErr != nil {
		return -1, ew(scanErr)
	}
	return n, nil
}
// wrappedExists executes query, which must yield one row with a single
// integer column, and reports whether that integer is non-zero. Any error is
// passed through ew and paired with false.
func (s pgStorage) wrappedExists(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (bool, error) {
	row := s.db.QueryRowContext(ctx, query, args...)
	var n int
	if scanErr := row.Scan(&n); scanErr != nil {
		return false, ew(scanErr)
	}
	found := n != 0
	return found, nil
}
// wrappedSelectUUID executes query and scans its single result column into a
// UUID. If no row matches, the sql.ErrNoRows error is first converted with
// errors.NewNotFoundError. Every error is passed through ew and paired with
// uuid.Nil.
func (s pgStorage) wrappedSelectUUID(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (uuid.UUID, error) {
	var id uuid.UUID
	switch scanErr := s.db.QueryRowContext(ctx, query, args...).Scan(&id); {
	case scanErr == nil:
		return id, nil
	case scanErr == sql.ErrNoRows:
		return uuid.Nil, ew(errors.NewNotFoundError(scanErr))
	default:
		return uuid.Nil, ew(scanErr)
	}
}
// wrappedSelectInt executes query and scans its single result column into an
// int. If no row matches, the sql.ErrNoRows error is first converted with
// errors.NewNotFoundError. Every error is passed through ew and paired with
// -1.
func (s pgStorage) wrappedSelectInt(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (int, error) {
	var n int
	switch scanErr := s.db.QueryRowContext(ctx, query, args...).Scan(&n); {
	case scanErr == nil:
		return n, nil
	case scanErr == sql.ErrNoRows:
		return -1, ew(errors.NewNotFoundError(scanErr))
	default:
		return -1, ew(scanErr)
	}
}
// keyedCount executes query and returns one Count per result row, scanning
// the first column into Key and the second into Count. The slice is non-nil
// (empty when there are no rows). Query, scan, and iteration errors are
// passed through ew.
func (s pgStorage) keyedCount(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) ([]Count, error) {
	results := make([]Count, 0)
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, ew(err)
	}
	defer rows.Close()
	for rows.Next() {
		var count Count
		if err := rows.Scan(&count.Key, &count.Count); err != nil {
			return nil, ew(err)
		}
		results = append(results, count)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails. rows.Err tells the two apart, so a truncated result is
	// not reported as success.
	if err := rows.Err(); err != nil {
		return nil, ew(err)
	}
	return results, nil
}
// keysToUUID executes query and builds a map from each row's first column (a
// string) to its second column (a UUID). When a key repeats, the last row
// wins. Query, scan, and iteration errors are passed through ew.
func (s pgStorage) keysToUUID(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (map[string]uuid.UUID, error) {
	results := make(map[string]uuid.UUID)
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, ew(err)
	}
	defer rows.Close()
	for rows.Next() {
		var key string
		var id uuid.UUID
		if err := rows.Scan(&key, &id); err != nil {
			return nil, ew(err)
		}
		results[key] = id
	}
	// Surface errors that ended iteration early instead of returning a
	// partial map as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, ew(err)
	}
	return results, nil
}
// uuidsToPayload executes query and builds a map from each row's first column
// (a UUID) to its second column, scanned into a Payload. When a key repeats,
// the last row wins. Query, scan, and iteration errors are passed through ew.
func (s pgStorage) uuidsToPayload(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (map[uuid.UUID]Payload, error) {
	results := make(map[uuid.UUID]Payload)
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, ew(err)
	}
	defer rows.Close()
	for rows.Next() {
		var key uuid.UUID
		var payload Payload
		if err := rows.Scan(&key, &payload); err != nil {
			return nil, ew(err)
		}
		results[key] = payload
	}
	// Surface errors that ended iteration early instead of returning a
	// partial map as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, ew(err)
	}
	return results, nil
}
// uuidsToUuids executes query and builds a map from each row's first UUID
// column to its second UUID column. When a key repeats, the last row wins.
// Query, scan, and iteration errors are passed through ew.
func (s pgStorage) uuidsToUuids(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (map[uuid.UUID]uuid.UUID, error) {
	results := make(map[uuid.UUID]uuid.UUID)
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, ew(err)
	}
	defer rows.Close()
	for rows.Next() {
		var key, value uuid.UUID
		if err := rows.Scan(&key, &value); err != nil {
			return nil, ew(err)
		}
		results[key] = value
	}
	// Surface errors that ended iteration early instead of returning a
	// partial map as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, ew(err)
	}
	return results, nil
}
// toUUIDMultiMap executes query and groups the second UUID column of each row
// under the row's first UUID column. Values keep result-set order within each
// key. Query, scan, and iteration errors are passed through ew.
func (s pgStorage) toUUIDMultiMap(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) (map[uuid.UUID][]uuid.UUID, error) {
	results := make(map[uuid.UUID][]uuid.UUID)
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, ew(err)
	}
	defer rows.Close()
	for rows.Next() {
		var key, value uuid.UUID
		if err := rows.Scan(&key, &value); err != nil {
			return nil, ew(err)
		}
		// A missing key reads as a nil slice, and append on nil allocates, so
		// no separate initialisation is needed.
		results[key] = append(results[key], value)
	}
	// Surface errors that ended iteration early instead of returning a
	// partial map as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, ew(err)
	}
	return results, nil
}
// selectStrings executes query and returns its single string column, one
// element per row, in result-set order. The slice is nil when the query
// yields no rows. Query, scan, and iteration errors are passed through ew.
func (s pgStorage) selectStrings(ew errors.ErrorWrapper, ctx context.Context, query string, args ...interface{}) ([]string, error) {
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, ew(err)
	}
	defer rows.Close()
	var results []string
	for rows.Next() {
		var value string
		if err := rows.Scan(&value); err != nil {
			return nil, ew(err)
		}
		results = append(results, value)
	}
	// Surface errors that ended iteration early instead of returning a
	// truncated list as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, ew(err)
	}
	return results, nil
}
// CountMap indexes counts by Key and returns a non-nil map even for empty
// input. When several entries share a key, the last one wins.
func CountMap(counts []Count) map[string]int {
	byKey := make(map[string]int, len(counts))
	for i := range counts {
		byKey[counts[i].Key] = counts[i].Count
	}
	return byKey
}