diff --git a/coderd/coderd.go b/coderd/coderd.go index f7b8603367..cf5d70d0a3 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1,9 +1,7 @@ package coderd import ( - "context" "crypto/x509" - "fmt" "io" "net/http" "net/url" @@ -125,11 +123,8 @@ func New(options *Options) *API { apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false) r.Use( - func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(middleware.NewWrapResponseWriter(w, r.ProtoMajor), r) - }) - }, + httpmw.Recover(api.Logger), + httpmw.Logger(api.Logger), httpmw.Prometheus(options.PrometheusRegistry), ) @@ -159,7 +154,6 @@ func New(options *Options) *API { r.Use( // Specific routes can specify smaller limits. httpmw.RateLimitPerMinute(options.APIRateLimit), - debugLogRequest(api.Logger), tracing.HTTPMW(api.TracerProvider, "coderd.http"), ) r.Get("/", func(w http.ResponseWriter, r *http.Request) { @@ -438,15 +432,6 @@ func (api *API) Close() error { return api.workspaceAgentCache.Close() } -func debugLogRequest(log slog.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - log.Debug(context.Background(), fmt.Sprintf("%s %s", r.Method, r.URL.Path)) - next.ServeHTTP(rw, r) - }) - } -} - func compressHandler(h http.Handler) http.Handler { cmp := middleware.NewCompressor(5, "text/*", diff --git a/coderd/httpapi/httpapi.go b/coderd/httpapi/httpapi.go index b42d2257b4..5393a79bfc 100644 --- a/coderd/httpapi/httpapi.go +++ b/coderd/httpapi/httpapi.go @@ -59,6 +59,18 @@ func Forbidden(rw http.ResponseWriter) { }) } +func InternalServerError(rw http.ResponseWriter, err error) { + var details string + if err != nil { + details = err.Error() + } + + Write(rw, http.StatusInternalServerError, codersdk.Response{ + Message: "An internal server error occurred.", + Detail: details, + }) +} + // Write outputs a standardized format to an HTTP response body. func Write(rw http.ResponseWriter, status int, response interface{}) { buf := &bytes.Buffer{} diff --git a/coderd/httpapi/httpapi_test.go b/coderd/httpapi/httpapi_test.go index 35ed403ba4..79a26d54a2 100644 --- a/coderd/httpapi/httpapi_test.go +++ b/coderd/httpapi/httpapi_test.go @@ -10,11 +10,46 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" ) +func TestInternalServerError(t *testing.T) { + t.Parallel() + + t.Run("NoError", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + httpapi.InternalServerError(w, nil) + + var resp codersdk.Response + err := json.NewDecoder(w.Body).Decode(&resp) + require.NoError(t, err) + require.Equal(t, http.StatusInternalServerError, w.Code) + require.NotEmpty(t, resp.Message) + require.Empty(t, resp.Detail) + }) + + t.Run("WithError", func(t *testing.T) { + t.Parallel() + var ( + w = httptest.NewRecorder() + httpErr = xerrors.New("error!") + ) + + httpapi.InternalServerError(w, httpErr) + + var resp codersdk.Response + err := json.NewDecoder(w.Body).Decode(&resp) + require.NoError(t, err) + require.Equal(t, http.StatusInternalServerError, w.Code) + require.NotEmpty(t, resp.Message) + require.Equal(t, httpErr.Error(), resp.Detail) + }) +} + func TestWrite(t *testing.T) { t.Parallel() t.Run("NoErrors", func(t *testing.T) { diff --git a/coderd/httpapi/request.go b/coderd/httpapi/request.go new file mode 100644 index 0000000000..6a07ede6dc --- /dev/null +++ b/coderd/httpapi/request.go @@ -0,0 +1,30 @@ +package httpapi + +import "net/http" + +const ( + // XForwardedHostHeader is a header used by proxies to indicate the + // original host of the request. + XForwardedHostHeader = "X-Forwarded-Host" +) + +// RequestHost returns the name of the host from the request. It prioritizes +// 'X-Forwarded-Host' over r.Host since most requests are being proxied. +func RequestHost(r *http.Request) string { + host := r.Header.Get(XForwardedHostHeader) + if host != "" { + return host + } + + return r.Host +} + +func IsWebsocketUpgrade(r *http.Request) bool { + vs := r.Header.Values("Upgrade") + for _, v := range vs { + if v == "websocket" { + return true + } + } + return false +} diff --git a/coderd/httpapi/status_writer.go b/coderd/httpapi/status_writer.go new file mode 100644 index 0000000000..dcdb3345c6 --- /dev/null +++ b/coderd/httpapi/status_writer.go @@ -0,0 +1,74 @@ +package httpapi + +import ( + "bufio" + "net" + "net/http" + + "golang.org/x/xerrors" +) + +var _ http.ResponseWriter = (*StatusWriter)(nil) +var _ http.Hijacker = (*StatusWriter)(nil) + +// StatusWriter intercepts the status of the request and the response body up +// to maxBodySize if Status >= 400. It is guaranteed to be the ResponseWriter +// directly downstream from Middleware. +type StatusWriter struct { + http.ResponseWriter + Status int + Hijacked bool + responseBody []byte + + wroteHeader bool +} + +func (w *StatusWriter) WriteHeader(status int) { + if !w.wroteHeader { + w.Status = status + w.wroteHeader = true + } + w.ResponseWriter.WriteHeader(status) +} + +func (w *StatusWriter) Write(b []byte) (int, error) { + const maxBodySize = 4096 + + if !w.wroteHeader { + w.Status = http.StatusOK + w.wroteHeader = true + } + + if w.Status >= http.StatusBadRequest { + // This is technically wrong as multiple calls to write + // will simply overwrite w.ResponseBody but given that + // we typically only write to the response body once + // and this field is only used for logging I'm leaving + // this as-is. + w.responseBody = make([]byte, minInt(len(b), maxBodySize)) + copy(w.responseBody, b) + } + + return w.ResponseWriter.Write(b) +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := w.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter) + } + w.Hijacked = true + + return hijacker.Hijack() +} + +func (w *StatusWriter) ResponseBody() []byte { + return w.responseBody +} diff --git a/coderd/httpapi/status_writer_test.go b/coderd/httpapi/status_writer_test.go new file mode 100644 index 0000000000..ee713ac555 --- /dev/null +++ b/coderd/httpapi/status_writer_test.go @@ -0,0 +1,129 @@ +package httpapi_test + +import ( + "bufio" + "crypto/rand" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/httpapi" +) + +func TestStatusWriter(t *testing.T) { + t.Parallel() + + t.Run("WriteHeader", func(t *testing.T) { + t.Parallel() + + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + ) + + w.WriteHeader(http.StatusOK) + require.Equal(t, http.StatusOK, w.Status) + // Validate that the code is written to the underlying Response. + require.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("WriteHeaderTwice", func(t *testing.T) { + t.Parallel() + + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + code = http.StatusNotFound + ) + + w.WriteHeader(code) + w.WriteHeader(http.StatusOK) + // Validate that we only record the first status code. + require.Equal(t, code, w.Status) + // Validate that the code is written to the underlying Response. + require.Equal(t, code, rec.Code) + }) + + t.Run("WriteNoHeader", func(t *testing.T) { + t.Parallel() + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + body = []byte("hello") + ) + + _, err := w.Write(body) + require.NoError(t, err) + + // Should set the status to OK. + require.Equal(t, http.StatusOK, w.Status) + // We don't record the body for codes <400. + require.Equal(t, []byte(nil), w.ResponseBody()) + require.Equal(t, body, rec.Body.Bytes()) + }) + + t.Run("WriteAfterHeader", func(t *testing.T) { + t.Parallel() + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + body = []byte("hello") + code = http.StatusInternalServerError + ) + + w.WriteHeader(code) + _, err := w.Write(body) + require.NoError(t, err) + + require.Equal(t, code, w.Status) + require.Equal(t, body, w.ResponseBody()) + require.Equal(t, body, rec.Body.Bytes()) + }) + + t.Run("WriteMaxBody", func(t *testing.T) { + t.Parallel() + var ( + rec = httptest.NewRecorder() + w = &httpapi.StatusWriter{ResponseWriter: rec} + // 8kb body. + body = make([]byte, 8<<10) + code = http.StatusInternalServerError + ) + + _, err := rand.Read(body) + require.NoError(t, err) + + w.WriteHeader(code) + _, err = w.Write(body) + require.NoError(t, err) + + require.Equal(t, code, w.Status) + require.Equal(t, body, rec.Body.Bytes()) + require.Equal(t, body[:4096], w.ResponseBody()) + }) + + t.Run("Hijack", func(t *testing.T) { + t.Parallel() + var ( + rec = httptest.NewRecorder() + ) + + w := &httpapi.StatusWriter{ResponseWriter: hijacker{rec}} + + _, _, err := w.Hijack() + require.Error(t, err) + require.Equal(t, "hijacked", err.Error()) + }) +} + +type hijacker struct { + http.ResponseWriter +} + +func (hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, xerrors.New("hijacked") +} diff --git a/coderd/httpmw/logger.go b/coderd/httpmw/logger.go new file mode 100644 index 0000000000..6f3a700bc5 --- /dev/null +++ b/coderd/httpmw/logger.go @@ -0,0 +1,58 @@ +package httpmw + +import ( + "net/http" + "time" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/httpapi" +) + +func Logger(log slog.Logger) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + sw := &httpapi.StatusWriter{ResponseWriter: w} + + httplog := log.With( + slog.F("host", httpapi.RequestHost(r)), + slog.F("path", r.URL.Path), + slog.F("proto", r.Proto), + slog.F("remote_addr", r.RemoteAddr), + ) + + next.ServeHTTP(sw, r) + + // Don't log successful health check requests. + if r.URL.Path == "/api/v2" && sw.Status == 200 { + return + } + + httplog = httplog.With( + slog.F("took", time.Since(start)), + slog.F("status_code", sw.Status), + slog.F("latency_ms", float64(time.Since(start)/time.Millisecond)), + ) + + // For status codes 400 and higher we + // want to log the response body. + if sw.Status >= 400 { + httplog = httplog.With( + slog.F("response_body", string(sw.ResponseBody())), + ) + } + + logLevelFn := httplog.Debug + if sw.Status >= 400 { + logLevelFn = httplog.Warn + } + if sw.Status >= 500 { + // Server errors should be treated as an ERROR + // log level. + logLevelFn = httplog.Error + } + + logLevelFn(r.Context(), r.Method) + }) + } +} diff --git a/coderd/httpmw/prometheus.go b/coderd/httpmw/prometheus.go index acc57071f0..d954adc2cc 100644 --- a/coderd/httpmw/prometheus.go +++ b/coderd/httpmw/prometheus.go @@ -6,7 +6,8 @@ import ( "time" "github.com/go-chi/chi/v5" - chimw "github.com/go-chi/chi/v5/middleware" + + "github.com/coder/coder/coderd/httpapi" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -66,9 +67,9 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler rctx = chi.RouteContext(r.Context()) ) - sw, ok := w.(chimw.WrapResponseWriter) + sw, ok := w.(*httpapi.StatusWriter) if !ok { - panic("dev error: http.ResponseWriter is not chimw.WrapResponseWriter") + panic("dev error: http.ResponseWriter is not *httpapi.StatusWriter") } var ( @@ -76,7 +77,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler distOpts []string ) // We want to count WebSockets separately. - if isWebsocketUpgrade(r) { + if httpapi.IsWebsocketUpgrade(r) { websocketsConcurrent.Inc() defer websocketsConcurrent.Dec() @@ -93,20 +94,10 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler path := rctx.RoutePattern() distOpts = append(distOpts, path) - statusStr := strconv.Itoa(sw.Status()) + statusStr := strconv.Itoa(sw.Status) requestsProcessed.WithLabelValues(statusStr, method, path).Inc() dist.WithLabelValues(distOpts...).Observe(float64(time.Since(start)) / 1e6) }) } } - -func isWebsocketUpgrade(r *http.Request) bool { - vs := r.Header.Values("Upgrade") - for _, v := range vs { - if v == "websocket" { - return true - } - } - return false -} diff --git a/coderd/httpmw/prometheus_test.go b/coderd/httpmw/prometheus_test.go index 97c5575406..9514183471 100644 --- a/coderd/httpmw/prometheus_test.go +++ b/coderd/httpmw/prometheus_test.go @@ -7,10 +7,10 @@ import ( "testing" "github.com/go-chi/chi/v5" - chimw "github.com/go-chi/chi/v5/middleware" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" ) @@ -20,7 +20,7 @@ func TestPrometheus(t *testing.T) { t.Parallel() req := httptest.NewRequest("GET", "/", nil) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext())) - res := chimw.NewWrapResponseWriter(httptest.NewRecorder(), 0) + res := &httpapi.StatusWriter{ResponseWriter: httptest.NewRecorder()} reg := prometheus.NewRegistry() httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/coderd/httpmw/recover.go b/coderd/httpmw/recover.go new file mode 100644 index 0000000000..a25c063c5f --- /dev/null +++ b/coderd/httpmw/recover.go @@ -0,0 +1,40 @@ +package httpmw + +import ( + "context" + "net/http" + "runtime/debug" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/httpapi" +) + +func Recover(log slog.Logger) func(h http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + r := recover() + if r != nil { + log.Warn(context.Background(), + "panic serving http request (recovered)", + slog.F("panic", r), + slog.F("stack", string(debug.Stack())), + ) + + var hijacked bool + if sw, ok := w.(*httpapi.StatusWriter); ok { + hijacked = sw.Hijacked + } + + // Only try to write errors on + // non-hijacked responses. + if !hijacked { + httpapi.InternalServerError(w, nil) + } + } + }() + + h.ServeHTTP(w, r) + }) + } +} diff --git a/coderd/httpmw/recover_test.go b/coderd/httpmw/recover_test.go new file mode 100644 index 0000000000..f4b043f0ba --- /dev/null +++ b/coderd/httpmw/recover_test.go @@ -0,0 +1,74 @@ +package httpmw_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" +) + +func TestRecover(t *testing.T) { + t.Parallel() + + handler := func(isPanic, hijack bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isPanic { + panic("Oh no!") + } + + w.WriteHeader(http.StatusOK) + }) + } + + cases := []struct { + Name string + Code int + Panic bool + Hijack bool + }{ + { + Name: "OK", + Code: http.StatusOK, + Panic: false, + Hijack: false, + }, + { + Name: "Panic", + Code: http.StatusInternalServerError, + Panic: true, + Hijack: false, + }, + { + Name: "Hijack", + Code: 0, + Panic: true, + Hijack: true, + }, + } + + for _, c := range cases { + c := c + + t.Run(c.Name, func(t *testing.T) { + t.Parallel() + + var ( + log = slogtest.Make(t, nil) + r = httptest.NewRequest("GET", "/", nil) + w = &httpapi.StatusWriter{ + ResponseWriter: httptest.NewRecorder(), + Hijacked: c.Hijack, + } + ) + + httpmw.Recover(log)(handler(c.Panic, c.Hijack)).ServeHTTP(w, r) + + require.Equal(t, c.Code, w.Status) + }) + } +} diff --git a/coderd/tracing/httpmw.go b/coderd/tracing/httpmw.go index b11c62500e..bf854f5c64 100644 --- a/coderd/tracing/httpmw.go +++ b/coderd/tracing/httpmw.go @@ -4,11 +4,12 @@ import ( "fmt" "net/http" - "github.com/go-chi/chi/middleware" "github.com/go-chi/chi/v5" sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" "go.opentelemetry.io/otel/trace" + + "github.com/coder/coder/coderd/httpapi" ) // HTTPMW adds tracing to http routes. @@ -25,13 +26,15 @@ func HTTPMW(tracerProvider *sdktrace.TracerProvider, name string) func(http.Hand defer span.End() r = r.WithContext(ctx) - wrw := middleware.NewWrapResponseWriter(rw, r.ProtoMajor) + sw, ok := rw.(*httpapi.StatusWriter) + if !ok { + panic(fmt.Sprintf("ResponseWriter not a *httpapi.StatusWriter; got %T", rw)) + } // pass the span through the request context and serve the request to the next middleware - next.ServeHTTP(wrw, r) - + next.ServeHTTP(sw, r) // capture response data - EndHTTPSpan(r, wrw.Status()) + EndHTTPSpan(r, sw.Status) }) } } diff --git a/coderd/workspaces.go b/coderd/workspaces.go index af8c4eddec..d524dc6340 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -512,7 +512,7 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) { return } // Check if the name was already in use. - if database.IsUniqueViolation(err, database.UniqueWorkspacesOwnerIDLowerIndex) { + if database.IsUniqueViolation(err) { httpapi.Write(rw, http.StatusConflict, codersdk.Response{ Message: fmt.Sprintf("Workspace %q already exists.", req.Name), Validations: []codersdk.ValidationError{{