From 1f5a6d59ba73f86f8de4ecf14e3ffd50dab40064 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 9 Feb 2024 11:39:08 +0400 Subject: [PATCH] chore: consolidate websocketNetConn implementations (#12065) Consolidates websocketNetConn from multiple packages in favor of a central one in codersdk --- coderd/workspaceagents.go | 50 +----------- coderd/workspaceagentsrpc.go | 2 +- codersdk/agentsdk/agentsdk.go | 46 +---------- codersdk/provisionerdaemons.go | 45 +---------- codersdk/websocket.go | 53 ++++++++++++ codersdk/websocket_test.go | 80 +++++++++++++++++++ codersdk/workspaceagents.go | 2 +- codersdk/workspaceagents_internal_test.go | 2 +- enterprise/coderd/provisionerdaemons.go | 47 +---------- enterprise/coderd/workspaceproxycoordinate.go | 2 +- 10 files changed, 145 insertions(+), 184 deletions(-) create mode 100644 codersdk/websocket.go create mode 100644 codersdk/websocket_test.go diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 415193ff06..c254eb9e40 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/url" "sort" @@ -544,7 +543,7 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) { } go httpapi.Heartbeat(ctx, conn) - ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText) + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) defer wsNetConn.Close() // Also closes conn. // The Go stdlib JSON encoder appends a newline character after message write. @@ -881,7 +880,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { }) return } - ctx, nconn := websocketNetConn(ctx, ws, websocket.MessageBinary) + ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary) defer nconn.Close() // Slurp all packets from the connection into io.Discard so pongs get sent @@ -990,7 +989,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request return } - ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary) defer wsNetConn.Close() closeCtx, closeCtxCancel := context.WithCancel(ctx) @@ -1077,7 +1076,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R }) return } - ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary) defer wsNetConn.Close() go httpapi.Heartbeat(ctx, conn) @@ -2108,47 +2107,6 @@ func createExternalAuthResponse(typ, token string, extra pqtype.NullRawMessage) return resp, err } -// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func -// is called if a read or write error is encountered. -type wsNetConn struct { - cancel context.CancelFunc - net.Conn -} - -func (c *wsNetConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - if err != nil { - c.cancel() - } - return n, err -} - -func (c *wsNetConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - if err != nil { - c.cancel() - } - return n, err -} - -func (c *wsNetConn) Close() error { - defer c.cancel() - return c.Conn.Close() -} - -// websocketNetConn wraps websocket.NetConn and returns a context that -// is tied to the parent context and the lifetime of the conn. Any error -// during read or write will cancel the context, but not close the -// conn. Close should be called to release context resources. -func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { - ctx, cancel := context.WithCancel(ctx) - nc := websocket.NetConn(ctx, conn, msgType) - return ctx, &wsNetConn{ - cancel: cancel, - Conn: nc, - } -} - func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.WorkspaceAgentLog { sdk := make([]codersdk.WorkspaceAgentLog, 0, len(logs)) for _, logEntry := range logs { diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index bdd244b4a7..a62286a9c9 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -100,7 +100,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { return } - ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary) defer wsNetConn.Close() ycfg := yamux.DefaultConfig() diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 22cc0faab6..e96cd58b9d 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -203,7 +203,7 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) { return nil, codersdk.ReadBodyAsError(res) } - _, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) + _, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary) netConn := &closeNetConn{ Conn: wsNetConn, @@ -596,50 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext return authResp, json.NewDecoder(res.Body).Decode(&authResp) } -// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func -// is called if a read or write error is encountered. -type wsNetConn struct { - cancel context.CancelFunc - net.Conn -} - -func (c *wsNetConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - if err != nil { - c.cancel() - } - return n, err -} - -func (c *wsNetConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - if err != nil { - c.cancel() - } - return n, err -} - -func (c *wsNetConn) Close() error { - defer c.cancel() - return c.Conn.Close() -} - -// websocketNetConn wraps websocket.NetConn and returns a context that -// is tied to the parent context and the lifetime of the conn. Any error -// during read or write will cancel the context, but not close the -// conn. Close should be called to release context resources. -func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { - // Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than - // the default because some of our protocols can include large messages like startup scripts. - conn.SetReadLimit(1 << 22) - ctx, cancel := context.WithCancel(ctx) - nc := websocket.NetConn(ctx, conn, msgType) - return ctx, &wsNetConn{ - cancel: cancel, - Conn: nc, - } -} - // LogsNotifyChannel returns the channel name responsible for notifying // of new logs. func LogsNotifyChannel(agentID uuid.UUID) string { diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index 5457ba6991..e8f8ed8eb6 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/http/cookiejar" "time" @@ -248,7 +247,7 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione config := yamux.DefaultConfig() config.LogOutput = io.Discard // Use background context because caller should close the client. - _, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary) + _, wsNetConn := WebsocketNetConn(context.Background(), conn, websocket.MessageBinary) session, err := yamux.Client(wsNetConn, config) if err != nil { _ = conn.Close(websocket.StatusGoingAway, "") @@ -257,45 +256,3 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione } return proto.NewDRPCProvisionerDaemonClient(drpc.MultiplexedConn(session)), nil } - -// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func -// is called if a read or write error is encountered. -// @typescript-ignore wsNetConn -type wsNetConn struct { - cancel context.CancelFunc - net.Conn -} - -func (c *wsNetConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - if err != nil { - c.cancel() - } - return n, err -} - -func (c *wsNetConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - if err != nil { - c.cancel() - } - return n, err -} - -func (c *wsNetConn) Close() error { - defer c.cancel() - return c.Conn.Close() -} - -// websocketNetConn wraps websocket.NetConn and returns a context that -// is tied to the parent context and the lifetime of the conn. Any error -// during read or write will cancel the context, but not close the -// conn. Close should be called to release context resources. -func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { - ctx, cancel := context.WithCancel(ctx) - nc := websocket.NetConn(ctx, conn, msgType) - return ctx, &wsNetConn{ - cancel: cancel, - Conn: nc, - } -} diff --git a/codersdk/websocket.go b/codersdk/websocket.go new file mode 100644 index 0000000000..a872c20197 --- /dev/null +++ b/codersdk/websocket.go @@ -0,0 +1,53 @@ +package codersdk + +import ( + "context" + "net" + + "nhooyr.io/websocket" +) + +// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func +// is called if a read or write error is encountered. +// @typescript-ignore wsNetConn +type wsNetConn struct { + cancel context.CancelFunc + net.Conn +} + +func (c *wsNetConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if err != nil { + c.cancel() + } + return n, err +} + +func (c *wsNetConn) Close() error { + defer c.cancel() + return c.Conn.Close() +} + +// WebsocketNetConn wraps websocket.NetConn and returns a context that +// is tied to the parent context and the lifetime of the conn. Any error +// during read or write will cancel the context, but not close the +// conn. Close should be called to release context resources. +func WebsocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { + // Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than + // the default because some of our protocols can include large messages like startup scripts. + conn.SetReadLimit(1 << 22) + ctx, cancel := context.WithCancel(ctx) + nc := websocket.NetConn(ctx, conn, msgType) + return ctx, &wsNetConn{ + cancel: cancel, + Conn: nc, + } +} diff --git a/codersdk/websocket_test.go b/codersdk/websocket_test.go new file mode 100644 index 0000000000..861f9e9705 --- /dev/null +++ b/codersdk/websocket_test.go @@ -0,0 +1,80 @@ +package codersdk_test + +import ( + "crypto/rand" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "nhooyr.io/websocket" + + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// TestWebsocketNetConn_LargeWrites tests that we can write large amounts of data thru the netconn +// in a single write. Without specifically setting the read limit, the websocket library limits +// the amount of data that can be read in a single message to 32kiB. Even after raising the limit, +// curiously, it still only reads 32kiB per Read(), but allows the large write to go thru. +func TestWebsocketNetConn_LargeWrites(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + n := 4 * 1024 * 1024 // 4 MiB + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sws, err := websocket.Accept(w, r, nil) + if !assert.NoError(t, err) { + return + } + _, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary) + defer nc.Close() + + // Although the writes are all in one go, the reads get broken up by + // the library. + j := 0 + b := make([]byte, n) + for j < n { + k, err := nc.Read(b[j:]) + if !assert.NoError(t, err) { + return + } + j += k + t.Logf("server read %d bytes, total %d", k, j) + } + assert.Equal(t, n, j) + j, err = nc.Write(b) + assert.Equal(t, n, j) + if !assert.NoError(t, err) { + return + } + })) + + // use of random data is worst case scenario for compression + cb := make([]byte, n) + rk, err := rand.Read(cb) + require.NoError(t, err) + require.Equal(t, n, rk) + + // nolint: bodyclose + cws, _, err := websocket.Dial(ctx, svr.URL, nil) + require.NoError(t, err) + _, cnc := codersdk.WebsocketNetConn(ctx, cws, websocket.MessageBinary) + ck, err := cnc.Write(cb) + require.NoError(t, err) + require.Equal(t, n, ck) + + cb2 := make([]byte, n) + j := 0 + for j < n { + k, err := cnc.Read(cb2[j:]) + if !assert.NoError(t, err) { + return + } + j += k + t.Logf("client read %d bytes, total %d", k, j) + } + require.NoError(t, err) + require.Equal(t, n, j) + require.Equal(t, cb, cb2) +} diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 7109bd747d..eec42cdc3f 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -844,7 +844,7 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID, } logChunks := make(chan []WorkspaceAgentLog, 1) closed := make(chan struct{}) - ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText) + ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText) decoder := json.NewDecoder(wsNetConn) go func() { defer close(closed) diff --git a/codersdk/workspaceagents_internal_test.go b/codersdk/workspaceagents_internal_test.go index 38854114c1..c71f7d440c 100644 --- a/codersdk/workspaceagents_internal_test.go +++ b/codersdk/workspaceagents_internal_test.go @@ -50,7 +50,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { if !assert.NoError(t, err) { return } - ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary) + ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary) err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{ Name: "client", ID: clientID, diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 92f034e352..f81f17befd 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -7,13 +7,10 @@ import ( "errors" "fmt" "io" - "net" "net/http" "strings" "time" - "github.com/coder/coder/v2/provisionersdk" - "github.com/google/uuid" "github.com/hashicorp/yamux" "github.com/moby/moby/pkg/namesgenerator" @@ -37,6 +34,7 @@ import ( "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionerd/proto" + "github.com/coder/coder/v2/provisionersdk" ) func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler { @@ -297,7 +295,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) // the same connection. config := yamux.DefaultConfig() config.LogOutput = io.Discard - ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary) defer wsNetConn.Close() session, err := yamux.Server(wsNetConn, config) if err != nil { @@ -360,44 +358,3 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) } _ = conn.Close(websocket.StatusGoingAway, "") } - -// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func -// is called if a read or write error is encountered. -type wsNetConn struct { - cancel context.CancelFunc - net.Conn -} - -func (c *wsNetConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - if err != nil { - c.cancel() - } - return n, err -} - -func (c *wsNetConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - if err != nil { - c.cancel() - } - return n, err -} - -func (c *wsNetConn) Close() error { - defer c.cancel() - return c.Conn.Close() -} - -// websocketNetConn wraps websocket.NetConn and returns a context that -// is tied to the parent context and the lifetime of the conn. Any error -// during read or write will cancel the context, but not close the -// conn. Close should be called to release context resources. -func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) { - ctx, cancel := context.WithCancel(ctx) - nc := websocket.NetConn(ctx, conn, msgType) - return ctx, &wsNetConn{ - cancel: cancel, - Conn: nc, - } -} diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index 725019e251..a85cc0488e 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -57,7 +57,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request return } - ctx, nc := websocketNetConn(ctx, conn, msgType) + ctx, nc := codersdk.WebsocketNetConn(ctx, conn, msgType) defer nc.Close() id := uuid.New()