coder/coderd/tracing/status_writer.go

92 lines
1.9 KiB
Go

package tracing
import (
"bufio"
"net"
"net/http"
"golang.org/x/xerrors"
)
// Compile-time checks that *StatusWriter satisfies every net/http interface
// it provides methods for. http.Flusher is included because StatusWriter
// defines Flush; without the assertion a signature drift would go unnoticed.
var (
	_ http.ResponseWriter = (*StatusWriter)(nil)
	_ http.Hijacker       = (*StatusWriter)(nil)
	_ http.Flusher        = (*StatusWriter)(nil)
)
// StatusWriter wraps an http.ResponseWriter and records the status code of
// the response, plus the response body up to maxBodySize (a constant local to
// Write) if Status >= 400. It is guaranteed to be the ResponseWriter
// directly downstream from Middleware.
type StatusWriter struct {
// ResponseWriter is the wrapped writer that calls are forwarded to.
http.ResponseWriter
// Status is the status code recorded by WriteHeader or Write. It is zero
// until one of them records a status.
Status int
// Hijacked is set to true by Hijack when the wrapped writer implements
// http.Hijacker.
Hijacked bool
// responseBody holds the bytes captured by Write for responses with
// Status >= 400. It is exposed through ResponseBody.
responseBody []byte
// wroteHeader reports whether a status has already been recorded, so that
// later calls do not overwrite Status.
wroteHeader bool
}
// StatusWriterMiddleware returns a handler that wraps the incoming
// http.ResponseWriter in a fresh *StatusWriter before invoking next, so that
// everything next writes passes through the StatusWriter.
func StatusWriterMiddleware(next http.Handler) http.Handler {
	serve := func(rw http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(&StatusWriter{ResponseWriter: rw}, r)
	}
	return http.HandlerFunc(serve)
}
// WriteHeader records the response status and forwards the call to the
// wrapped ResponseWriter.
//
// Only the first final status is recorded in Status. net/http permits any
// number of 1xx informational headers (other than 101 Switching Protocols) to
// be written before the final header, so those are forwarded without being
// latched; otherwise a 103 Early Hints followed by a 500 would be recorded as
// 103.
func (w *StatusWriter) WriteHeader(status int) {
	informational := status >= 100 && status <= 199 && status != http.StatusSwitchingProtocols
	if !informational && !w.wroteHeader {
		w.Status = status
		w.wroteHeader = true
	}
	// Always forward, including repeated calls, so the wrapped writer sees
	// exactly what the handler did.
	w.ResponseWriter.WriteHeader(status)
}
// Write forwards b to the wrapped ResponseWriter and returns its result. For
// error responses (Status >= 400) it also retains a copy of the leading bytes
// of the body so they can be logged later via ResponseBody.
//
// If no status has been recorded yet, Status is set to 200 OK, mirroring the
// implicit WriteHeader(http.StatusOK) that net/http performs on first Write.
// Captured bytes accumulate across calls until maxBodySize bytes are held;
// anything beyond that is forwarded but not retained.
func (w *StatusWriter) Write(b []byte) (int, error) {
	const maxBodySize = 4096

	if !w.wroteHeader {
		w.Status = http.StatusOK
		w.wroteHeader = true
	}

	if w.Status >= http.StatusBadRequest {
		// Append instead of overwriting so that an error body emitted in
		// several writes is captured from its start, up to the cap. append
		// copies the bytes, so b is not retained after Write returns.
		if room := maxBodySize - len(w.responseBody); room > 0 {
			w.responseBody = append(w.responseBody, b[:minInt(len(b), room)]...)
		}
	}

	return w.ResponseWriter.Write(b)
}
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
// Hijack implements http.Hijacker by delegating to the wrapped
// ResponseWriter. It returns an error if the wrapped writer does not
// implement http.Hijacker. Hijacked is set to true before delegating, so it
// is set whether or not the delegated Hijack call succeeds.
func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
		w.Hijacked = true
		return hj.Hijack()
	}
	return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter)
}
// ResponseBody returns the response body bytes captured by Write, or nil if
// nothing has been captured. The returned slice is the writer's internal
// buffer, not a copy.
func (w *StatusWriter) ResponseBody() []byte {
	captured := w.responseBody
	return captured
}
// Flush implements http.Flusher by delegating to the wrapped ResponseWriter.
// It panics if the wrapped writer does not implement http.Flusher.
func (w *StatusWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
		return
	}
	panic("http.ResponseWriter is not http.Flusher")
}