mirror of https://github.com/coder/coder.git
chore: consolidate websocketNetConn implementations (#12065)
Consolidates websocketNetConn from multiple packages in favor of a central one in codersdk
This commit is contained in:
parent
ec8e41f516
commit
1f5a6d59ba
|
@ -7,7 +7,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
|
@ -544,7 +543,7 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
go httpapi.Heartbeat(ctx, conn)
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
// The Go stdlib JSON encoder appends a newline character after message write.
|
||||
|
@ -881,7 +880,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
return
|
||||
}
|
||||
ctx, nconn := websocketNetConn(ctx, ws, websocket.MessageBinary)
|
||||
ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary)
|
||||
defer nconn.Close()
|
||||
|
||||
// Slurp all packets from the connection into io.Discard so pongs get sent
|
||||
|
@ -990,7 +989,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
closeCtx, closeCtxCancel := context.WithCancel(ctx)
|
||||
|
@ -1077,7 +1076,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
|||
})
|
||||
return
|
||||
}
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.Heartbeat(ctx, conn)
|
||||
|
@ -2108,47 +2107,6 @@ func createExternalAuthResponse(typ, token string, extra pqtype.NullRawMessage)
|
|||
return resp, err
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
||||
|
||||
func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.WorkspaceAgentLog {
|
||||
sdk := make([]codersdk.WorkspaceAgentLog, 0, len(logs))
|
||||
for _, logEntry := range logs {
|
||||
|
|
|
@ -100,7 +100,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
ycfg := yamux.DefaultConfig()
|
||||
|
|
|
@ -203,7 +203,7 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
|
|||
return nil, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
_, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
_, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
|
||||
netConn := &closeNetConn{
|
||||
Conn: wsNetConn,
|
||||
|
@ -596,50 +596,6 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
|
|||
return authResp, json.NewDecoder(res.Body).Decode(&authResp)
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
|
||||
// the default because some of our protocols can include large messages like startup scripts.
|
||||
conn.SetReadLimit(1 << 22)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
||||
|
||||
// LogsNotifyChannel returns the channel name responsible for notifying
|
||||
// of new logs.
|
||||
func LogsNotifyChannel(agentID uuid.UUID) string {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"time"
|
||||
|
@ -248,7 +247,7 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
|
|||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
// Use background context because caller should close the client.
|
||||
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
|
||||
_, wsNetConn := WebsocketNetConn(context.Background(), conn, websocket.MessageBinary)
|
||||
session, err := yamux.Client(wsNetConn, config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
|
@ -257,45 +256,3 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
|
|||
}
|
||||
return proto.NewDRPCProvisionerDaemonClient(drpc.MultiplexedConn(session)), nil
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
// @typescript-ignore wsNetConn
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
// @typescript-ignore wsNetConn
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// WebsocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func WebsocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
// Set the read limit to 4 MiB -- about the limit for protobufs. This needs to be larger than
|
||||
// the default because some of our protocols can include large messages like startup scripts.
|
||||
conn.SetReadLimit(1 << 22)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
package codersdk_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestWebsocketNetConn_LargeWrites tests that we can write large amounts of data thru the netconn
|
||||
// in a single write. Without specifically setting the read limit, the websocket library limits
|
||||
// the amount of data that can be read in a single message to 32kiB. Even after raising the limit,
|
||||
// curiously, it still only reads 32kiB per Read(), but allows the large write to go thru.
|
||||
func TestWebsocketNetConn_LargeWrites(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
n := 4 * 1024 * 1024 // 4 MiB
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
_, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
defer nc.Close()
|
||||
|
||||
// Although the writes are all in one go, the reads get broken up by
|
||||
// the library.
|
||||
j := 0
|
||||
b := make([]byte, n)
|
||||
for j < n {
|
||||
k, err := nc.Read(b[j:])
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
j += k
|
||||
t.Logf("server read %d bytes, total %d", k, j)
|
||||
}
|
||||
assert.Equal(t, n, j)
|
||||
j, err = nc.Write(b)
|
||||
assert.Equal(t, n, j)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
}))
|
||||
|
||||
// use of random data is worst case scenario for compression
|
||||
cb := make([]byte, n)
|
||||
rk, err := rand.Read(cb)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, rk)
|
||||
|
||||
// nolint: bodyclose
|
||||
cws, _, err := websocket.Dial(ctx, svr.URL, nil)
|
||||
require.NoError(t, err)
|
||||
_, cnc := codersdk.WebsocketNetConn(ctx, cws, websocket.MessageBinary)
|
||||
ck, err := cnc.Write(cb)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, ck)
|
||||
|
||||
cb2 := make([]byte, n)
|
||||
j := 0
|
||||
for j < n {
|
||||
k, err := cnc.Read(cb2[j:])
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
j += k
|
||||
t.Logf("client read %d bytes, total %d", k, j)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, j)
|
||||
require.Equal(t, cb, cb2)
|
||||
}
|
|
@ -844,7 +844,7 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID,
|
|||
}
|
||||
logChunks := make(chan []WorkspaceAgentLog, 1)
|
||||
closed := make(chan struct{})
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
|
||||
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
decoder := json.NewDecoder(wsNetConn)
|
||||
go func() {
|
||||
defer close(closed)
|
||||
|
|
|
@ -50,7 +50,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
|||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: clientID,
|
||||
|
|
|
@ -7,13 +7,10 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
|
@ -37,6 +34,7 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
)
|
||||
|
||||
func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler {
|
||||
|
@ -297,7 +295,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||
// the same connection.
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
session, err := yamux.Server(wsNetConn, config)
|
||||
if err != nil {
|
||||
|
@ -360,44 +358,3 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
ctx, nc := websocketNetConn(ctx, conn, msgType)
|
||||
ctx, nc := codersdk.WebsocketNetConn(ctx, conn, msgType)
|
||||
defer nc.Close()
|
||||
|
||||
id := uuid.New()
|
||||
|
|
Loading…
Reference in New Issue