chore: consolidate websocketNetConn implementations (#12065)

Consolidates websocketNetConn from multiple packages in favor of a central one in codersdk
This commit is contained in:
Spike Curtis 2024-02-09 11:39:08 +04:00 committed by GitHub
parent ec8e41f516
commit 1f5a6d59ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 145 additions and 184 deletions

View File

@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sort"
@ -544,7 +543,7 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
}
go httpapi.Heartbeat(ctx, conn)
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
defer wsNetConn.Close() // Also closes conn.
// The Go stdlib JSON encoder appends a newline character after message write.
@ -881,7 +880,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
})
return
}
ctx, nconn := websocketNetConn(ctx, ws, websocket.MessageBinary)
ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary)
defer nconn.Close()
// Slurp all packets from the connection into io.Discard so pongs get sent
@ -990,7 +989,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
return
}
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
closeCtx, closeCtxCancel := context.WithCancel(ctx)
@ -1077,7 +1076,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
})
return
}
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
go httpapi.Heartbeat(ctx, conn)
@ -2108,47 +2107,6 @@ func createExternalAuthResponse(typ, token string, extra pqtype.NullRawMessage)
return resp, err
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}
func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.WorkspaceAgentLog {
sdk := make([]codersdk.WorkspaceAgentLog, 0, len(logs))
for _, logEntry := range logs {

View File

@ -100,7 +100,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
return
}
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
ycfg := yamux.DefaultConfig()

View File

@ -203,7 +203,7 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
return nil, codersdk.ReadBodyAsError(res)
}
_, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
_, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
netConn := &closeNetConn{
Conn: wsNetConn,
@ -596,50 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
return authResp, json.NewDecoder(res.Body).Decode(&authResp)
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
// the default because some of our protocols can include large messages like startup scripts.
conn.SetReadLimit(1 << 22)
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}
// LogsNotifyChannel returns the channel name responsible for notifying
// of new logs.
func LogsNotifyChannel(agentID uuid.UUID) string {

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"time"
@ -248,7 +247,7 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
// Use background context because caller should close the client.
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
_, wsNetConn := WebsocketNetConn(context.Background(), conn, websocket.MessageBinary)
session, err := yamux.Client(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
@ -257,45 +256,3 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
}
return proto.NewDRPCProvisionerDaemonClient(drpc.MultiplexedConn(session)), nil
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

53
codersdk/websocket.go Normal file
View File

@ -0,0 +1,53 @@
package codersdk
import (
"context"
"net"
"nhooyr.io/websocket"
)
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// WebsocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func WebsocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
// the default because some of our protocols can include large messages like startup scripts.
conn.SetReadLimit(1 << 22)
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

View File

@ -0,0 +1,80 @@
package codersdk_test
import (
"crypto/rand"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nhooyr.io/websocket"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestWebsocketNetConn_LargeWrites tests that we can write large amounts of data thru the netconn
// in a single write. Without specifically setting the read limit, the websocket library limits
// the amount of data that can be read in a single message to 32kiB. Even after raising the limit,
// curiously, it still only reads 32kiB per Read(), but allows the large write to go thru.
func TestWebsocketNetConn_LargeWrites(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
n := 4 * 1024 * 1024 // 4 MiB
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
_, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
defer nc.Close()
// Although the writes are all in one go, the reads get broken up by
// the library.
j := 0
b := make([]byte, n)
for j < n {
k, err := nc.Read(b[j:])
if !assert.NoError(t, err) {
return
}
j += k
t.Logf("server read %d bytes, total %d", k, j)
}
assert.Equal(t, n, j)
j, err = nc.Write(b)
assert.Equal(t, n, j)
if !assert.NoError(t, err) {
return
}
}))
// use of random data is worst case scenario for compression
cb := make([]byte, n)
rk, err := rand.Read(cb)
require.NoError(t, err)
require.Equal(t, n, rk)
// nolint: bodyclose
cws, _, err := websocket.Dial(ctx, svr.URL, nil)
require.NoError(t, err)
_, cnc := codersdk.WebsocketNetConn(ctx, cws, websocket.MessageBinary)
ck, err := cnc.Write(cb)
require.NoError(t, err)
require.Equal(t, n, ck)
cb2 := make([]byte, n)
j := 0
for j < n {
k, err := cnc.Read(cb2[j:])
if !assert.NoError(t, err) {
return
}
j += k
t.Logf("client read %d bytes, total %d", k, j)
}
require.NoError(t, err)
require.Equal(t, n, j)
require.Equal(t, cb, cb2)
}

View File

@ -844,7 +844,7 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID,
}
logChunks := make(chan []WorkspaceAgentLog, 1)
closed := make(chan struct{})
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText)
decoder := json.NewDecoder(wsNetConn)
go func() {
defer close(closed)

View File

@ -50,7 +50,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
if !assert.NoError(t, err) {
return
}
ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client",
ID: clientID,

View File

@ -7,13 +7,10 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/coder/coder/v2/provisionersdk"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"github.com/moby/moby/pkg/namesgenerator"
@ -37,6 +34,7 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
)
func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler {
@ -297,7 +295,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
// the same connection.
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
session, err := yamux.Server(wsNetConn, config)
if err != nil {
@ -360,44 +358,3 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}
_ = conn.Close(websocket.StatusGoingAway, "")
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

View File

@ -57,7 +57,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
return
}
ctx, nc := websocketNetConn(ctx, conn, msgType)
ctx, nc := codersdk.WebsocketNetConn(ctx, conn, msgType)
defer nc.Close()
id := uuid.New()