coder/coderd/healthcheck/websocket.go

158 lines
3.4 KiB
Go

package healthcheck
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"time"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/codersdk/healthsdk"
)
type WebsocketReport healthsdk.WebsocketReport
type WebsocketReportOptions struct {
APIKey string
AccessURL *url.URL
HTTPClient *http.Client
Dismissed bool
}
func (r *WebsocketReport) Run(ctx context.Context, opts *WebsocketReportOptions) {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
r.Severity = health.SeverityOK
r.Warnings = []health.Message{}
r.Dismissed = opts.Dismissed
u, err := opts.AccessURL.Parse("/api/v2/debug/ws")
if err != nil {
r.Error = convertError(xerrors.Errorf("parse access url: %w", err))
r.Severity = health.SeverityError
return
}
if u.Scheme == "https" {
u.Scheme = "wss"
} else {
u.Scheme = "ws"
}
//nolint:bodyclose // websocket package closes this for you
c, res, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
HTTPClient: opts.HTTPClient,
HTTPHeader: http.Header{"Coder-Session-Token": []string{opts.APIKey}},
})
if res != nil {
var body string
if res.Body != nil {
b, err := io.ReadAll(res.Body)
if err == nil {
body = string(b)
}
}
r.Body = body
r.Code = res.StatusCode
}
if err != nil {
r.Error = convertError(xerrors.Errorf("websocket dial: %w", err))
r.Error = health.Errorf(health.CodeWebsocketDial, "websocket dial: %s", err)
r.Severity = health.SeverityError
return
}
defer c.Close(websocket.StatusGoingAway, "goodbye")
for i := 0; i < 3; i++ {
msg := strconv.Itoa(i)
err := c.Write(ctx, websocket.MessageText, []byte(msg))
if err != nil {
r.Error = health.Errorf(health.CodeWebsocketEcho, "write message: %s", err)
r.Severity = health.SeverityError
return
}
ty, got, err := c.Read(ctx)
if err != nil {
r.Error = health.Errorf(health.CodeWebsocketEcho, "read message: %s", err)
r.Severity = health.SeverityError
return
}
if ty != websocket.MessageText {
r.Error = health.Errorf(health.CodeWebsocketMsg, "received incorrect message type: %v", ty)
r.Severity = health.SeverityError
return
}
if string(got) != msg {
r.Error = health.Errorf(health.CodeWebsocketMsg, "received incorrect message: wanted %q, got %q", msg, string(got))
r.Severity = health.SeverityError
return
}
}
c.Close(websocket.StatusGoingAway, "goodbye")
r.Healthy = true
}
type WebsocketEchoServer struct {
Error error
Code int
}
func (s *WebsocketEchoServer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if s.Error != nil {
rw.WriteHeader(s.Code)
_, _ = rw.Write([]byte(s.Error.Error()))
return
}
ctx := r.Context()
c, err := websocket.Accept(rw, r, &websocket.AcceptOptions{})
if err != nil {
rw.WriteHeader(http.StatusBadRequest)
_, _ = rw.Write([]byte(fmt.Sprint("unable to accept:", err)))
return
}
defer c.Close(websocket.StatusGoingAway, "goodbye")
echo := func() error {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
typ, r, err := c.Reader(ctx)
if err != nil {
return xerrors.Errorf("get reader: %w", err)
}
w, err := c.Writer(ctx, typ)
if err != nil {
return xerrors.Errorf("get writer: %w", err)
}
_, err = io.Copy(w, r)
if err != nil {
return xerrors.Errorf("echo message: %w", err)
}
err = w.Close()
return err
}
for {
err := echo()
if err != nil {
return
}
}
}