mirror of https://github.com/coder/coder.git
feat: add panic recovery middleware (#3687)
This commit is contained in:
parent
3cf17d34e7
commit
053fe6ff61
|
@ -1,9 +1,7 @@
|
||||||
package coderd
|
package coderd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -125,11 +123,8 @@ func New(options *Options) *API {
|
||||||
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
|
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
|
||||||
|
|
||||||
r.Use(
|
r.Use(
|
||||||
func(next http.Handler) http.Handler {
|
httpmw.Recover(api.Logger),
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httpmw.Logger(api.Logger),
|
||||||
next.ServeHTTP(middleware.NewWrapResponseWriter(w, r.ProtoMajor), r)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
httpmw.Prometheus(options.PrometheusRegistry),
|
httpmw.Prometheus(options.PrometheusRegistry),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -159,7 +154,6 @@ func New(options *Options) *API {
|
||||||
r.Use(
|
r.Use(
|
||||||
// Specific routes can specify smaller limits.
|
// Specific routes can specify smaller limits.
|
||||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||||
debugLogRequest(api.Logger),
|
|
||||||
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
|
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
|
||||||
)
|
)
|
||||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -438,15 +432,6 @@ func (api *API) Close() error {
|
||||||
return api.workspaceAgentCache.Close()
|
return api.workspaceAgentCache.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func debugLogRequest(log slog.Logger) func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
||||||
log.Debug(context.Background(), fmt.Sprintf("%s %s", r.Method, r.URL.Path))
|
|
||||||
next.ServeHTTP(rw, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func compressHandler(h http.Handler) http.Handler {
|
func compressHandler(h http.Handler) http.Handler {
|
||||||
cmp := middleware.NewCompressor(5,
|
cmp := middleware.NewCompressor(5,
|
||||||
"text/*",
|
"text/*",
|
||||||
|
|
|
@ -59,6 +59,18 @@ func Forbidden(rw http.ResponseWriter) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InternalServerError(rw http.ResponseWriter, err error) {
|
||||||
|
var details string
|
||||||
|
if err != nil {
|
||||||
|
details = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
|
Message: "An internal server error occurred.",
|
||||||
|
Detail: details,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Write outputs a standardized format to an HTTP response body.
|
// Write outputs a standardized format to an HTTP response body.
|
||||||
func Write(rw http.ResponseWriter, status int, response interface{}) {
|
func Write(rw http.ResponseWriter, status int, response interface{}) {
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
|
|
@ -10,11 +10,46 @@ import (
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/coder/coder/coderd/httpapi"
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
"github.com/coder/coder/codersdk"
|
"github.com/coder/coder/codersdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestInternalServerError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("NoError", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
httpapi.InternalServerError(w, nil)
|
||||||
|
|
||||||
|
var resp codersdk.Response
|
||||||
|
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
require.NotEmpty(t, resp.Message)
|
||||||
|
require.Empty(t, resp.Detail)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithError", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
httpErr = xerrors.New("error!")
|
||||||
|
)
|
||||||
|
|
||||||
|
httpapi.InternalServerError(w, httpErr)
|
||||||
|
|
||||||
|
var resp codersdk.Response
|
||||||
|
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
require.NotEmpty(t, resp.Message)
|
||||||
|
require.Equal(t, httpErr.Error(), resp.Detail)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestWrite(t *testing.T) {
|
func TestWrite(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("NoErrors", func(t *testing.T) {
|
t.Run("NoErrors", func(t *testing.T) {
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
package httpapi
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// XForwardedHostHeader is a header used by proxies to indicate the
|
||||||
|
// original host of the request.
|
||||||
|
XForwardedHostHeader = "X-Forwarded-Host"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestHost returns the name of the host from the request. It prioritizes
|
||||||
|
// 'X-Forwarded-Host' over r.Host since most requests are being proxied.
|
||||||
|
func RequestHost(r *http.Request) string {
|
||||||
|
host := r.Header.Get(XForwardedHostHeader)
|
||||||
|
if host != "" {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsWebsocketUpgrade(r *http.Request) bool {
|
||||||
|
vs := r.Header.Values("Upgrade")
|
||||||
|
for _, v := range vs {
|
||||||
|
if v == "websocket" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
package httpapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ http.ResponseWriter = (*StatusWriter)(nil)
|
||||||
|
var _ http.Hijacker = (*StatusWriter)(nil)
|
||||||
|
|
||||||
|
// StatusWriter intercepts the status of the request and the response body up
|
||||||
|
// to maxBodySize if Status >= 400. It is guaranteed to be the ResponseWriter
|
||||||
|
// directly downstream from Middleware.
|
||||||
|
type StatusWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
Status int
|
||||||
|
Hijacked bool
|
||||||
|
responseBody []byte
|
||||||
|
|
||||||
|
wroteHeader bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *StatusWriter) WriteHeader(status int) {
|
||||||
|
if !w.wroteHeader {
|
||||||
|
w.Status = status
|
||||||
|
w.wroteHeader = true
|
||||||
|
}
|
||||||
|
w.ResponseWriter.WriteHeader(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||||
|
const maxBodySize = 4096
|
||||||
|
|
||||||
|
if !w.wroteHeader {
|
||||||
|
w.Status = http.StatusOK
|
||||||
|
w.wroteHeader = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Status >= http.StatusBadRequest {
|
||||||
|
// This is technically wrong as multiple calls to write
|
||||||
|
// will simply overwrite w.ResponseBody but given that
|
||||||
|
// we typically only write to the response body once
|
||||||
|
// and this field is only used for logging I'm leaving
|
||||||
|
// this as-is.
|
||||||
|
w.responseBody = make([]byte, minInt(len(b), maxBodySize))
|
||||||
|
copy(w.responseBody, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func minInt(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter)
|
||||||
|
}
|
||||||
|
w.Hijacked = true
|
||||||
|
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *StatusWriter) ResponseBody() []byte {
|
||||||
|
return w.responseBody
|
||||||
|
}
|
|
@ -0,0 +1,129 @@
|
||||||
|
package httpapi_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStatusWriter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("WriteHeader", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||||
|
)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
require.Equal(t, http.StatusOK, w.Status)
|
||||||
|
// Validate that the code is written to the underlying Response.
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WriteHeaderTwice", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||||
|
code = http.StatusNotFound
|
||||||
|
)
|
||||||
|
|
||||||
|
w.WriteHeader(code)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
// Validate that we only record the first status code.
|
||||||
|
require.Equal(t, code, w.Status)
|
||||||
|
// Validate that the code is written to the underlying Response.
|
||||||
|
require.Equal(t, code, rec.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WriteNoHeader", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||||
|
body = []byte("hello")
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := w.Write(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should set the status to OK.
|
||||||
|
require.Equal(t, http.StatusOK, w.Status)
|
||||||
|
// We don't record the body for codes <400.
|
||||||
|
require.Equal(t, []byte(nil), w.ResponseBody())
|
||||||
|
require.Equal(t, body, rec.Body.Bytes())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WriteAfterHeader", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||||
|
body = []byte("hello")
|
||||||
|
code = http.StatusInternalServerError
|
||||||
|
)
|
||||||
|
|
||||||
|
w.WriteHeader(code)
|
||||||
|
_, err := w.Write(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, code, w.Status)
|
||||||
|
require.Equal(t, body, w.ResponseBody())
|
||||||
|
require.Equal(t, body, rec.Body.Bytes())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WriteMaxBody", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||||
|
// 8kb body.
|
||||||
|
body = make([]byte, 8<<10)
|
||||||
|
code = http.StatusInternalServerError
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := rand.Read(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
w.WriteHeader(code)
|
||||||
|
_, err = w.Write(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, code, w.Status)
|
||||||
|
require.Equal(t, body, rec.Body.Bytes())
|
||||||
|
require.Equal(t, body[:4096], w.ResponseBody())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Hijack", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
)
|
||||||
|
|
||||||
|
w := &httpapi.StatusWriter{ResponseWriter: hijacker{rec}}
|
||||||
|
|
||||||
|
_, _, err := w.Hijack()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "hijacked", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type hijacker struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return nil, nil, xerrors.New("hijacked")
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
package httpmw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Logger(log slog.Logger) func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
sw := &httpapi.StatusWriter{ResponseWriter: w}
|
||||||
|
|
||||||
|
httplog := log.With(
|
||||||
|
slog.F("host", httpapi.RequestHost(r)),
|
||||||
|
slog.F("path", r.URL.Path),
|
||||||
|
slog.F("proto", r.Proto),
|
||||||
|
slog.F("remote_addr", r.RemoteAddr),
|
||||||
|
)
|
||||||
|
|
||||||
|
next.ServeHTTP(sw, r)
|
||||||
|
|
||||||
|
// Don't log successful health check requests.
|
||||||
|
if r.URL.Path == "/api/v2" && sw.Status == 200 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
httplog = httplog.With(
|
||||||
|
slog.F("took", time.Since(start)),
|
||||||
|
slog.F("status_code", sw.Status),
|
||||||
|
slog.F("latency_ms", float64(time.Since(start)/time.Millisecond)),
|
||||||
|
)
|
||||||
|
|
||||||
|
// For status codes 400 and higher we
|
||||||
|
// want to log the response body.
|
||||||
|
if sw.Status >= 400 {
|
||||||
|
httplog = httplog.With(
|
||||||
|
slog.F("response_body", string(sw.ResponseBody())),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
logLevelFn := httplog.Debug
|
||||||
|
if sw.Status >= 400 {
|
||||||
|
logLevelFn = httplog.Warn
|
||||||
|
}
|
||||||
|
if sw.Status >= 500 {
|
||||||
|
// Server errors should be treated as an ERROR
|
||||||
|
// log level.
|
||||||
|
logLevelFn = httplog.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
logLevelFn(r.Context(), r.Method)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,7 +6,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
chimw "github.com/go-chi/chi/v5/middleware"
|
|
||||||
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
@ -66,9 +67,9 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||||
rctx = chi.RouteContext(r.Context())
|
rctx = chi.RouteContext(r.Context())
|
||||||
)
|
)
|
||||||
|
|
||||||
sw, ok := w.(chimw.WrapResponseWriter)
|
sw, ok := w.(*httpapi.StatusWriter)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("dev error: http.ResponseWriter is not chimw.WrapResponseWriter")
|
panic("dev error: http.ResponseWriter is not *httpapi.StatusWriter")
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -76,7 +77,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||||
distOpts []string
|
distOpts []string
|
||||||
)
|
)
|
||||||
// We want to count WebSockets separately.
|
// We want to count WebSockets separately.
|
||||||
if isWebsocketUpgrade(r) {
|
if httpapi.IsWebsocketUpgrade(r) {
|
||||||
websocketsConcurrent.Inc()
|
websocketsConcurrent.Inc()
|
||||||
defer websocketsConcurrent.Dec()
|
defer websocketsConcurrent.Dec()
|
||||||
|
|
||||||
|
@ -93,20 +94,10 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||||
|
|
||||||
path := rctx.RoutePattern()
|
path := rctx.RoutePattern()
|
||||||
distOpts = append(distOpts, path)
|
distOpts = append(distOpts, path)
|
||||||
statusStr := strconv.Itoa(sw.Status())
|
statusStr := strconv.Itoa(sw.Status)
|
||||||
|
|
||||||
requestsProcessed.WithLabelValues(statusStr, method, path).Inc()
|
requestsProcessed.WithLabelValues(statusStr, method, path).Inc()
|
||||||
dist.WithLabelValues(distOpts...).Observe(float64(time.Since(start)) / 1e6)
|
dist.WithLabelValues(distOpts...).Observe(float64(time.Since(start)) / 1e6)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func isWebsocketUpgrade(r *http.Request) bool {
|
|
||||||
vs := r.Header.Values("Upgrade")
|
|
||||||
for _, v := range vs {
|
|
||||||
if v == "websocket" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,10 +7,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
chimw "github.com/go-chi/chi/v5/middleware"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
"github.com/coder/coder/coderd/httpmw"
|
"github.com/coder/coder/coderd/httpmw"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ func TestPrometheus(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||||
res := chimw.NewWrapResponseWriter(httptest.NewRecorder(), 0)
|
res := &httpapi.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||||
reg := prometheus.NewRegistry()
|
reg := prometheus.NewRegistry()
|
||||||
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
package httpmw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"runtime/debug"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Recover(log slog.Logger) func(h http.Handler) http.Handler {
|
||||||
|
return func(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r != nil {
|
||||||
|
log.Warn(context.Background(),
|
||||||
|
"panic serving http request (recovered)",
|
||||||
|
slog.F("panic", r),
|
||||||
|
slog.F("stack", string(debug.Stack())),
|
||||||
|
)
|
||||||
|
|
||||||
|
var hijacked bool
|
||||||
|
if sw, ok := w.(*httpapi.StatusWriter); ok {
|
||||||
|
hijacked = sw.Hijacked
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only try to write errors on
|
||||||
|
// non-hijacked responses.
|
||||||
|
if !hijacked {
|
||||||
|
httpapi.InternalServerError(w, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
package httpmw_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
|
"github.com/coder/coder/coderd/httpmw"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecover(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
handler := func(isPanic, hijack bool) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if isPanic {
|
||||||
|
panic("Oh no!")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
Name string
|
||||||
|
Code int
|
||||||
|
Panic bool
|
||||||
|
Hijack bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "OK",
|
||||||
|
Code: http.StatusOK,
|
||||||
|
Panic: false,
|
||||||
|
Hijack: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Panic",
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
Panic: true,
|
||||||
|
Hijack: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Hijack",
|
||||||
|
Code: 0,
|
||||||
|
Panic: true,
|
||||||
|
Hijack: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
c := c
|
||||||
|
|
||||||
|
t.Run(c.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
log = slogtest.Make(t, nil)
|
||||||
|
r = httptest.NewRequest("GET", "/", nil)
|
||||||
|
w = &httpapi.StatusWriter{
|
||||||
|
ResponseWriter: httptest.NewRecorder(),
|
||||||
|
Hijacked: c.Hijack,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
httpmw.Recover(log)(handler(c.Panic, c.Hijack)).ServeHTTP(w, r)
|
||||||
|
|
||||||
|
require.Equal(t, c.Code, w.Status)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,11 +4,12 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi/middleware"
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||||
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
|
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPMW adds tracing to http routes.
|
// HTTPMW adds tracing to http routes.
|
||||||
|
@ -25,13 +26,15 @@ func HTTPMW(tracerProvider *sdktrace.TracerProvider, name string) func(http.Hand
|
||||||
defer span.End()
|
defer span.End()
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
wrw := middleware.NewWrapResponseWriter(rw, r.ProtoMajor)
|
sw, ok := rw.(*httpapi.StatusWriter)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("ResponseWriter not a *httpapi.StatusWriter; got %T", rw))
|
||||||
|
}
|
||||||
|
|
||||||
// pass the span through the request context and serve the request to the next middleware
|
// pass the span through the request context and serve the request to the next middleware
|
||||||
next.ServeHTTP(wrw, r)
|
next.ServeHTTP(sw, r)
|
||||||
|
|
||||||
// capture response data
|
// capture response data
|
||||||
EndHTTPSpan(r, wrw.Status())
|
EndHTTPSpan(r, sw.Status)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -512,7 +512,7 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Check if the name was already in use.
|
// Check if the name was already in use.
|
||||||
if database.IsUniqueViolation(err, database.UniqueWorkspacesOwnerIDLowerIndex) {
|
if database.IsUniqueViolation(err) {
|
||||||
httpapi.Write(rw, http.StatusConflict, codersdk.Response{
|
httpapi.Write(rw, http.StatusConflict, codersdk.Response{
|
||||||
Message: fmt.Sprintf("Workspace %q already exists.", req.Name),
|
Message: fmt.Sprintf("Workspace %q already exists.", req.Name),
|
||||||
Validations: []codersdk.ValidationError{{
|
Validations: []codersdk.ValidationError{{
|
||||||
|
|
Loading…
Reference in New Issue