tavern/web/handler.go

175 lines
5.1 KiB
Go

package web
import (
"net/http"
"strconv"
"strings"
"github.com/getsentry/sentry-go"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.uber.org/zap"
"github.com/ngerakines/tavern/asset"
"github.com/ngerakines/tavern/common"
"github.com/ngerakines/tavern/config"
"github.com/ngerakines/tavern/errors"
"github.com/ngerakines/tavern/storage"
)
// handler carries the configuration and collaborators that the web package's
// request helpers (hardFail, writeJSONError, userActor, ...) read from.
type handler struct {
// storage is the storage.Storage implementation available to handlers.
storage storage.Storage
// logger receives the error logs written by hardFail, writeJSONError and
// flashErrorOrFail.
logger *zap.Logger
// domain is paired with a user name to build a local actor ID in userActor.
domain string
// sentryConfig gates error reporting: exceptions are captured by Sentry only
// when sentryConfig.Enabled is true.
sentryConfig config.SentryConfig
// fedConfig and groupConfig are not read in this part of the file.
// NOTE(review): presumably federation and group feature settings; confirm
// against the config package.
fedConfig config.FedConfig
groupConfig config.GroupConfig
// webFingerQueue and crawlQueue are string queues; their consumers are not
// visible in this part of the file.
webFingerQueue common.StringQueue
crawlQueue common.StringQueue
// adminUser is not read in this part of the file.
// NOTE(review): presumably the name of the administrative user; confirm.
adminUser string
// url builds a string from the given parts.
// NOTE(review): presumably an absolute URL on this server; confirm against
// where the field is assigned.
url func(parts ...interface{}) string
// svgConverter, assetStorage and assetQueue are not read in this part of the
// file. NOTE(review): presumably used for media/asset handling; confirm.
svgConverter SVGConverter
assetStorage asset.Storage
assetQueue common.StringQueue
// httpClient is the common.HTTPClient available to handlers.
httpClient common.HTTPClient
// metricFactory is a Prometheus promauto.Factory available to handlers.
metricFactory promauto.Factory
// publisherClient is not read in this part of the file.
publisherClient *publisherClient
// serverActorRowID is a UUID that is not read in this part of the file.
// NOTE(review): presumably the storage row ID of the server's own actor;
// confirm.
serverActorRowID uuid.UUID
}
// hardFail terminates an HTML request that cannot be completed.
//
// In order, it: captures err with Sentry (only when sentryConfig.Enabled),
// looks up the "trans" value in the gin context, logs err together with its
// error chain and any extra fields, renders the "error" template with a 500
// status, and aborts the gin context.
//
// It panics when "trans" is absent from the context; in that case the error
// has already been sent to Sentry (if enabled) but nothing is logged or
// rendered.
func (h handler) hardFail(ctx *gin.Context, err error, fields ...zap.Field) {
	if h.sentryConfig.Enabled {
		// Work on a clone so the request is attached to this report only.
		reportHub := sentry.CurrentHub().Clone()
		reportHub.Scope().SetRequest(sentry.Request{}.FromHTTPRequest(ctx.Request))
		reportHub.CaptureException(err)
	}

	translator, found := ctx.Get("trans")
	if !found {
		panic("trans not found in context")
	}

	chain := errors.ErrorChain(err)
	fields = append(fields, zap.Error(err), zap.Strings("error_chain", chain))
	h.logger.Error("request hard failed", fields...)

	page := gin.H{
		"error": err.Error(),
		"Trans": translator,
	}
	ctx.HTML(http.StatusInternalServerError, "error", page)
	ctx.Abort()
}
// unauthorizedJSON reports err through writeJSONError with HTTP status
// 401 Unauthorized. Extra fields are forwarded to the error log entry.
func (h handler) unauthorizedJSON(ctx *gin.Context, err error, fields ...zap.Field) {
	const status = http.StatusUnauthorized
	h.writeJSONError(ctx, status, err, fields...)
}
// badRequestJSON reports err through writeJSONError with HTTP status
// 400 Bad Request. Extra fields are forwarded to the error log entry.
func (h handler) badRequestJSON(ctx *gin.Context, err error, fields ...zap.Field) {
	const status = http.StatusBadRequest
	h.writeJSONError(ctx, status, err, fields...)
}
// internalServerErrorJSON reports err through writeJSONError with HTTP status
// 500 Internal Server Error. Extra fields are forwarded to the error log entry.
func (h handler) internalServerErrorJSON(ctx *gin.Context, err error, fields ...zap.Field) {
	const status = http.StatusInternalServerError
	h.writeJSONError(ctx, status, err, fields...)
}
// notFoundJSON reports err through writeJSONError with HTTP status
// 404 Not Found. Extra fields are forwarded to the error log entry.
func (h handler) notFoundJSON(ctx *gin.Context, err error, fields ...zap.Field) {
	const status = http.StatusNotFound
	h.writeJSONError(ctx, status, err, fields...)
}
// writeJSONError writes a JSON response with the given status code whose body
// is wrapJSONError(err).
//
// When err is non-nil it is first captured by Sentry (only when
// sentryConfig.Enabled) and logged with its error chain plus any extra fields.
// When err is nil nothing is reported or logged and only the response is
// written. The gin context is not aborted here.
func (h handler) writeJSONError(ctx *gin.Context, statusCode int, err error, fields ...zap.Field) {
	if err == nil {
		ctx.JSON(statusCode, wrapJSONError(err))
		return
	}

	if h.sentryConfig.Enabled {
		// Work on a clone so the request is attached to this report only.
		reportHub := sentry.CurrentHub().Clone()
		reportHub.Scope().SetRequest(sentry.Request{}.FromHTTPRequest(ctx.Request))
		reportHub.CaptureException(err)
	}

	chain := errors.ErrorChain(err)
	fields = append(fields, zap.Error(err), zap.Strings("error_chain", chain))
	h.logger.Error("error processing request", fields...)

	ctx.JSON(statusCode, wrapJSONError(err))
}
// writeJRD sets the Content-Type header to "application/jrd+json" and the
// Cache-Control and Pragma headers to "no-cache", then writes data with the
// given status code through c.JSON.
func (h handler) writeJRD(c *gin.Context, statusCode int, data interface{}) {
	headers := c.Writer.Header()
	headers.Set("Content-Type", "application/jrd+json")
	headers.Set("Cache-Control", "no-cache")
	headers.Set("Pragma", "no-cache")

	c.JSON(statusCode, data)
}
// writeJSONLD sets the Content-Type header to "application/activity+json" and
// the Cache-Control and Pragma headers to "no-cache", then writes data with
// the given status code through c.JSON.
func (h handler) writeJSONLD(c *gin.Context, statusCode int, data interface{}) {
	headers := c.Writer.Header()
	headers.Set("Content-Type", "application/activity+json")
	headers.Set("Cache-Control", "no-cache")
	headers.Set("Pragma", "no-cache")

	c.JSON(statusCode, data)
}
// writeJSONLDProfile sets the Content-Type header to "application/ld+json"
// with the ActivityStreams profile parameter and the Cache-Control and Pragma
// headers to "no-cache", then writes data with the given status code through
// c.JSON.
func (h handler) writeJSONLDProfile(c *gin.Context, statusCode int, data interface{}) {
	headers := c.Writer.Header()
	headers.Set("Content-Type", `application/ld+json; profile="https://www.w3.org/ns/activitystreams"`)
	headers.Set("Cache-Control", "no-cache")
	headers.Set("Pragma", "no-cache")

	c.JSON(statusCode, data)
}
// flashErrorOrFail logs err, hands err.Error() to appendFlashError for the
// current session, and redirects to location with 302 Found.
//
// If appendFlashError fails, the request is hard-failed with a
// cannot-save-session error instead of being redirected.
func (h handler) flashErrorOrFail(c *gin.Context, location string, err error) {
	sess := sessions.Default(c)
	h.logger.Error("error", zap.Error(err))

	saveErr := appendFlashError(sess, err.Error())
	if saveErr != nil {
		h.hardFail(c, errors.NewCannotSaveSessionError(saveErr))
		return
	}

	c.Redirect(http.StatusFound, location)
}
// flashSuccessOrFail hands message to appendFlashSuccess for the current
// session and redirects to location with 302 Found.
//
// If appendFlashSuccess fails, the request is hard-failed with a
// cannot-save-session error instead of being redirected.
func (h handler) flashSuccessOrFail(c *gin.Context, location string, message string) {
	saveErr := appendFlashSuccess(sessions.Default(c), message)
	if saveErr != nil {
		h.hardFail(c, errors.NewCannotSaveSessionError(saveErr))
		return
	}

	c.Redirect(http.StatusFound, location)
}
// userActor pairs user and actor in a storage.LocalActor. The ActorID comes
// from storage.NewActorID called with the user's name and the handler's
// domain. user must be non-nil because its Name field is read.
func (h handler) userActor(user *storage.User, actor *storage.Actor) storage.LocalActor {
	actorID := storage.NewActorID(user.Name, h.domain)
	return storage.LocalActor{
		User:    user,
		Actor:   actor,
		ActorID: actorID,
	}
}
// intParam reads the query parameter name as a non-negative integer.
//
// defaultValue is returned when the parameter is missing or empty, is not a
// valid integer, or is negative.
func intParam(c *gin.Context, name string, defaultValue int) int {
	raw := c.Query(name)
	if raw == "" {
		return defaultValue
	}

	parsed, err := strconv.Atoi(raw)
	if err != nil || parsed < 0 {
		return defaultValue
	}
	return parsed
}
// wrapJSONError converts err into a JSON response body: an object with a
// single "error" key holding err.Error(), or nil when err is nil.
func wrapJSONError(err error) interface{} {
	if err != nil {
		return gin.H{"error": err.Error()}
	}
	return nil
}
// requireAccept reports whether the request's Accept header lists contentType
// or the "*/*" wildcard.
//
// Media type parameters such as q-values are ignored, because parseAccept
// strips them. Type and subtype are compared case-insensitively, as HTTP
// defines media types to be case-insensitive; a client sending
// "Application/Activity+JSON" is therefore treated the same as one sending the
// lowercase form. Partial wildcards such as "application/*" are not expanded.
//
// A missing or empty Accept header yields false.
func requireAccept(c *gin.Context, contentType string) bool {
	for _, accept := range parseAccept(c.GetHeader("Accept")) {
		if accept == "*/*" || strings.EqualFold(accept, contentType) {
			return true
		}
	}
	return false
}
// parseAccept splits an Accept header value into its media ranges.
//
// The header is split on commas; for each entry everything from the first ';'
// onwards (parameters such as "q=0.8") is dropped and surrounding whitespace
// is trimmed. Entries that end up empty are skipped. The original order and
// letter case are preserved. The result is never nil.
func parseAccept(acceptHeader string) []string {
	entries := strings.Split(acceptHeader, ",")
	mediaRanges := make([]string, 0, len(entries))
	for _, entry := range entries {
		// Keep only the media range that precedes any parameters.
		if semi := strings.IndexByte(entry, ';'); semi >= 0 {
			entry = entry[:semi]
		}
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		mediaRanges = append(mediaRanges, entry)
	}
	return mediaRanges
}