coder/enterprise/tailnet/workspaceproxy.go

158 lines
4.3 KiB
Go

package tailnet
import (
"bytes"
"context"
"encoding/json"
"errors"
"net"
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
)
// ClientService wraps the AGPL tailnet ClientService by embedding a pointer to
// it, so the embedded service's exported fields and methods are available
// directly on this type. This package adds ServeMultiAgentClient on top of it.
type ClientService struct {
*agpl.ClientService
}
// NewClientService constructs a ClientService backed by an AGPL client service
// built from the supplied arguments. The Coordinator pointer is passed through
// unchanged; it is loaded on each processed connection.
//
// Any error from the AGPL constructor is returned as-is with a nil service.
func NewClientService(
	logger slog.Logger,
	coordPtr *atomic.Pointer[agpl.Coordinator],
	derpMapUpdateFrequency time.Duration,
	derpMapFn func() *tailcfg.DERPMap,
) (*ClientService, error) {
	inner, err := agpl.NewClientService(logger, coordPtr, derpMapUpdateFrequency, derpMapFn)
	if err != nil {
		return nil, err
	}

	svc := &ClientService{ClientService: inner}
	return svc, nil
}
// ServeMultiAgentClient serves a multi-agent client on conn, choosing the
// protocol from the major component of version:
//
//   - major 1: obtains a multi-agent connection for id from the currently
//     loaded coordinator and serves it with ServeWorkspaceProxy.
//   - major 2: hands conn to ServeConnV2 with a StreamID derived from id and
//     SingleTailnetTunnelAuth.
//
// A version that cannot be parsed, or whose major component is neither 1 nor 2,
// is logged at warn level and reported as an error.
func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID) error {
	majorVersion, _, parseErr := apiversion.Parse(version)
	if parseErr != nil {
		s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(parseErr))
		return parseErr
	}

	if majorVersion == 1 {
		// Dereference the coordinator that is current at the time of this call.
		coordinator := *(s.CoordPtr.Load())
		multiAgent := coordinator.ServeMultiAgent(id)
		return ServeWorkspaceProxy(ctx, conn, multiAgent)
	}

	if majorVersion == 2 {
		return s.ServeConnV2(ctx, conn, agpl.StreamID{
			Name: id.String(),
			ID:   id,
			Auth: agpl.SingleTailnetTunnelAuth{},
		})
	}

	s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
	return xerrors.New("unsupported version")
}
// ServeWorkspaceProxy exchanges JSON coordinate messages with a workspace proxy
// over conn on behalf of the multi-agent connection ma.
//
// It starts a goroutine that runs forwardNodesToWorkspaceProxy to push updates
// from ma to conn; if that forwarding returns an error the goroutine closes
// conn. The goroutine is not joined when this function returns.
//
// The calling goroutine then decodes wsproxysdk.CoordinateMessage values from
// conn in a loop and applies each one to ma (subscribe, unsubscribe, or node
// update).
//
// It returns nil when reading fails with net.ErrClosed. Any other read error,
// any failure while applying a message, and any unknown message type are
// returned as a wrapped error.
func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
	go func() {
		if err := forwardNodesToWorkspaceProxy(ctx, conn, ma); err != nil {
			// Best-effort close; presumably this also makes the blocked Decode
			// below return. The close error itself carries no useful signal.
			_ = conn.Close()
		}
	}()

	decoder := json.NewDecoder(conn)
	for {
		var msg wsproxysdk.CoordinateMessage
		if err := decoder.Decode(&msg); err != nil {
			// A closed connection is treated as a normal end of the session.
			if errors.Is(err, net.ErrClosed) {
				return nil
			}
			return xerrors.Errorf("read json: %w", err)
		}

		switch msg.Type {
		case wsproxysdk.CoordinateMessageTypeSubscribe:
			if err := ma.SubscribeAgent(msg.AgentID); err != nil {
				return xerrors.Errorf("subscribe agent: %w", err)
			}
		case wsproxysdk.CoordinateMessageTypeUnsubscribe:
			if err := ma.UnsubscribeAgent(msg.AgentID); err != nil {
				return xerrors.Errorf("unsubscribe agent: %w", err)
			}
		case wsproxysdk.CoordinateMessageTypeNodeUpdate:
			pn, err := agpl.NodeToProto(msg.Node)
			if err != nil {
				// Wrapped like every other failure in this loop so the caller
				// can tell which step failed.
				return xerrors.Errorf("convert node: %w", err)
			}
			if err := ma.UpdateSelf(pn); err != nil {
				return xerrors.Errorf("update self: %w", err)
			}
		default:
			return xerrors.Errorf("unknown message type %q", msg.Type)
		}
	}
}
// forwardNodesToWorkspaceProxy relays updates from ma to conn until something
// fails. Each update is reduced to its node updates, marshaled to JSON as a
// wsproxysdk.CoordinateNodes value and written to conn. A payload that is
// byte-for-byte equal to the last one successfully written is skipped.
//
// The function only returns on failure: when ma.NextUpdate reports no further
// update, when conversion or marshaling fails, or when setting a deadline or
// writing to conn fails.
func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
	// sendWithDeadline writes one payload under a write deadline and clears the
	// deadline again afterwards.
	sendWithDeadline := func(payload []byte) error {
		// Bound the write so that a hung connection cannot put back pressure on
		// the system. Node updates are tiny, so any connection that is not hung
		// can handle them within the timeout.
		if err := conn.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout)); err != nil {
			// Often this just means the connection is closed or broken; the
			// error is handed back to the caller unchanged.
			return err
		}
		if _, err := conn.Write(payload); err != nil {
			// Same as above: commonly a closed or broken connection.
			return err
		}
		// nhooyr.io/websocket has a bugged implementation of deadlines on a
		// websocket net.Conn. A deadline is *supposed* to bound subsequent
		// writes, failing the Write() call if it passes. Instead, that library
		// sets a timer which, on firing, expires the websocket write context,
		// so the next write fails *even if a new write deadline has been set*.
		// After a successful write the deadline must therefore be reset before
		// the timer fires.
		return conn.SetWriteDeadline(time.Time{})
	}

	var previous []byte
	for {
		update, more := ma.NextUpdate(ctx)
		if !more {
			return xerrors.New("multiagent is closed")
		}

		nodeUpdates, err := agpl.OnlyNodeUpdates(update)
		if err != nil {
			return xerrors.Errorf("failed to convert response: %w", err)
		}

		payload, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodeUpdates})
		if err != nil {
			return err
		}

		// Nothing changed since the last successful write; skip the send.
		if bytes.Equal(previous, payload) {
			continue
		}

		if err := sendWithDeadline(payload); err != nil {
			return err
		}
		previous = payload
	}
}