chore: consolidate websocketNetConn implementations (#12065)

Consolidates websocketNetConn from multiple packages in favor of a central one in codersdk
This commit is contained in:
Spike Curtis 2024-02-09 11:39:08 +04:00 committed by GitHub
parent ec8e41f516
commit 1f5a6d59ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 145 additions and 184 deletions

View File

@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"sort" "sort"
@ -544,7 +543,7 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
} }
go httpapi.Heartbeat(ctx, conn) go httpapi.Heartbeat(ctx, conn)
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText) ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
defer wsNetConn.Close() // Also closes conn. defer wsNetConn.Close() // Also closes conn.
// The Go stdlib JSON encoder appends a newline character after message write. // The Go stdlib JSON encoder appends a newline character after message write.
@ -881,7 +880,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
}) })
return return
} }
ctx, nconn := websocketNetConn(ctx, ws, websocket.MessageBinary) ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary)
defer nconn.Close() defer nconn.Close()
// Slurp all packets from the connection into io.Discard so pongs get sent // Slurp all packets from the connection into io.Discard so pongs get sent
@ -990,7 +989,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
return return
} }
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close() defer wsNetConn.Close()
closeCtx, closeCtxCancel := context.WithCancel(ctx) closeCtx, closeCtxCancel := context.WithCancel(ctx)
@ -1077,7 +1076,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
}) })
return return
} }
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close() defer wsNetConn.Close()
go httpapi.Heartbeat(ctx, conn) go httpapi.Heartbeat(ctx, conn)
@ -2108,47 +2107,6 @@ func createExternalAuthResponse(typ, token string, extra pqtype.NullRawMessage)
return resp, err return resp, err
} }
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}
func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.WorkspaceAgentLog { func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.WorkspaceAgentLog {
sdk := make([]codersdk.WorkspaceAgentLog, 0, len(logs)) sdk := make([]codersdk.WorkspaceAgentLog, 0, len(logs))
for _, logEntry := range logs { for _, logEntry := range logs {

View File

@ -100,7 +100,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
return return
} }
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close() defer wsNetConn.Close()
ycfg := yamux.DefaultConfig() ycfg := yamux.DefaultConfig()

View File

@ -203,7 +203,7 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
return nil, codersdk.ReadBodyAsError(res) return nil, codersdk.ReadBodyAsError(res)
} }
_, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) _, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
netConn := &closeNetConn{ netConn := &closeNetConn{
Conn: wsNetConn, Conn: wsNetConn,
@ -596,50 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
return authResp, json.NewDecoder(res.Body).Decode(&authResp) return authResp, json.NewDecoder(res.Body).Decode(&authResp)
} }
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
// the default because some of our protocols can include large messages like startup scripts.
conn.SetReadLimit(1 << 22)
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}
// LogsNotifyChannel returns the channel name responsible for notifying // LogsNotifyChannel returns the channel name responsible for notifying
// of new logs. // of new logs.
func LogsNotifyChannel(agentID uuid.UUID) string { func LogsNotifyChannel(agentID uuid.UUID) string {

View File

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"time" "time"
@ -248,7 +247,7 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
config := yamux.DefaultConfig() config := yamux.DefaultConfig()
config.LogOutput = io.Discard config.LogOutput = io.Discard
// Use background context because caller should close the client. // Use background context because caller should close the client.
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary) _, wsNetConn := WebsocketNetConn(context.Background(), conn, websocket.MessageBinary)
session, err := yamux.Client(wsNetConn, config) session, err := yamux.Client(wsNetConn, config)
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "") _ = conn.Close(websocket.StatusGoingAway, "")
@ -257,45 +256,3 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
} }
return proto.NewDRPCProvisionerDaemonClient(drpc.MultiplexedConn(session)), nil return proto.NewDRPCProvisionerDaemonClient(drpc.MultiplexedConn(session)), nil
} }
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

53
codersdk/websocket.go Normal file
View File

@ -0,0 +1,53 @@
package codersdk
import (
"context"
"net"
"nhooyr.io/websocket"
)
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// WebsocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func WebsocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
// the default because some of our protocols can include large messages like startup scripts.
conn.SetReadLimit(1 << 22)
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

View File

@ -0,0 +1,80 @@
package codersdk_test
import (
"crypto/rand"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nhooyr.io/websocket"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestWebsocketNetConn_LargeWrites tests that we can write large amounts of data thru the netconn
// in a single write. Without specifically setting the read limit, the websocket library limits
// the amount of data that can be read in a single message to 32kiB. Even after raising the limit,
// curiously, it still only reads 32kiB per Read(), but allows the large write to go thru.
func TestWebsocketNetConn_LargeWrites(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
n := 4 * 1024 * 1024 // 4 MiB
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
_, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
defer nc.Close()
// Although the writes are all in one go, the reads get broken up by
// the library.
j := 0
b := make([]byte, n)
for j < n {
k, err := nc.Read(b[j:])
if !assert.NoError(t, err) {
return
}
j += k
t.Logf("server read %d bytes, total %d", k, j)
}
assert.Equal(t, n, j)
j, err = nc.Write(b)
assert.Equal(t, n, j)
if !assert.NoError(t, err) {
return
}
}))
// use of random data is worst case scenario for compression
cb := make([]byte, n)
rk, err := rand.Read(cb)
require.NoError(t, err)
require.Equal(t, n, rk)
// nolint: bodyclose
cws, _, err := websocket.Dial(ctx, svr.URL, nil)
require.NoError(t, err)
_, cnc := codersdk.WebsocketNetConn(ctx, cws, websocket.MessageBinary)
ck, err := cnc.Write(cb)
require.NoError(t, err)
require.Equal(t, n, ck)
cb2 := make([]byte, n)
j := 0
for j < n {
k, err := cnc.Read(cb2[j:])
if !assert.NoError(t, err) {
return
}
j += k
t.Logf("client read %d bytes, total %d", k, j)
}
require.NoError(t, err)
require.Equal(t, n, j)
require.Equal(t, cb, cb2)
}

View File

@ -844,7 +844,7 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID,
} }
logChunks := make(chan []WorkspaceAgentLog, 1) logChunks := make(chan []WorkspaceAgentLog, 1)
closed := make(chan struct{}) closed := make(chan struct{})
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText) ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText)
decoder := json.NewDecoder(wsNetConn) decoder := json.NewDecoder(wsNetConn)
go func() { go func() {
defer close(closed) defer close(closed)

View File

@ -50,7 +50,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary) ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{ err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client", Name: "client",
ID: clientID, ID: clientID,

View File

@ -7,13 +7,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/coder/coder/v2/provisionersdk"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hashicorp/yamux" "github.com/hashicorp/yamux"
"github.com/moby/moby/pkg/namesgenerator" "github.com/moby/moby/pkg/namesgenerator"
@ -37,6 +34,7 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
) )
func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler { func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler {
@ -297,7 +295,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
// the same connection. // the same connection.
config := yamux.DefaultConfig() config := yamux.DefaultConfig()
config.LogOutput = io.Discard config.LogOutput = io.Discard
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close() defer wsNetConn.Close()
session, err := yamux.Server(wsNetConn, config) session, err := yamux.Server(wsNetConn, config)
if err != nil { if err != nil {
@ -360,44 +358,3 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
} }
_ = conn.Close(websocket.StatusGoingAway, "") _ = conn.Close(websocket.StatusGoingAway, "")
} }
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

View File

@ -57,7 +57,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
return return
} }
ctx, nc := websocketNetConn(ctx, conn, msgType) ctx, nc := codersdk.WebsocketNetConn(ctx, conn, msgType)
defer nc.Close() defer nc.Close()
id := uuid.New() id := uuid.New()