feat: change codersdk to use tailnet v2 for DERPMap updates (#11736)

fixes #10533


refactors `codersdk` workspace agent dialer to use a single websocket connection to the tailnet v2 API for both coordination and DERPMap updates, rather than separate websockets (and the v1 API for DERPMaps).
This commit is contained in:
Spike Curtis 2024-01-29 11:26:50 +04:00 committed by GitHub
parent 699a4b8dd4
commit f9fdd44510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 188 additions and 136 deletions

View File

@ -14,6 +14,8 @@ import (
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
@ -317,142 +319,28 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
q := coordinateURL.Query()
q.Add("version", proto.CurrentVersion.String())
coordinateURL.RawQuery = q.Encode()
closedCoordinator := make(chan struct{})
// Must only ever be used once, send error OR close to avoid
// reassignment race. Buffered so we don't hang in goroutine.
firstCoordinator := make(chan error, 1)
go func() {
defer close(closedCoordinator)
isFirst := true
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
options.Logger.Debug(ctx, "connecting")
// nolint:bodyclose
ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
HTTPClient: c.HTTPClient,
HTTPHeader: headers,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if isFirst {
if res != nil && res.StatusCode == http.StatusConflict {
firstCoordinator <- ReadBodyAsError(res)
return
}
isFirst = false
close(firstCoordinator)
}
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
continue
}
client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary))
if err != nil {
options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
continue
}
coordinate, err := client.Coordinate(ctx)
if err != nil {
options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
continue
}
coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID)
options.Logger.Debug(ctx, "serving coordinator")
err = <-coordination.Error()
if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusGoingAway, "")
return
}
if err != nil {
options.Logger.Debug(ctx, "error serving coordinator", slog.Error(err))
_ = ws.Close(websocket.StatusGoingAway, "")
continue
}
_ = ws.Close(websocket.StatusGoingAway, "")
}
}()
derpMapURL, err := c.URL.Parse("/api/v2/derp-map")
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
closedDerpMap := make(chan struct{})
// Must only ever be used once, send error OR close to avoid
// reassignment race. Buffered so we don't hang in goroutine.
firstDerpMap := make(chan error, 1)
go func() {
defer close(closedDerpMap)
isFirst := true
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
options.Logger.Debug(ctx, "connecting to server for derp map updates")
// nolint:bodyclose
ws, res, err := websocket.Dial(ctx, derpMapURL.String(), &websocket.DialOptions{
HTTPClient: c.HTTPClient,
HTTPHeader: headers,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if isFirst {
if res != nil && res.StatusCode == http.StatusConflict {
firstDerpMap <- ReadBodyAsError(res)
return
}
isFirst = false
close(firstDerpMap)
}
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
continue
}
var (
nconn = websocket.NetConn(ctx, ws, websocket.MessageBinary)
dec = json.NewDecoder(nconn)
)
for {
var derpMap tailcfg.DERPMap
err := dec.Decode(&derpMap)
if xerrors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusGoingAway, "")
return
}
if err != nil {
options.Logger.Debug(ctx, "failed to decode derp map", slog.Error(err))
_ = ws.Close(websocket.StatusGoingAway, "")
return
}
if !tailnet.CompareDERPMaps(conn.DERPMap(), &derpMap) {
options.Logger.Debug(ctx, "updating derp map due to detected changes")
conn.SetDERPMap(&derpMap)
}
}
}
}()
for firstCoordinator != nil || firstDerpMap != nil {
select {
case <-dialCtx.Done():
return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err())
case err = <-firstCoordinator:
if err != nil {
return nil, xerrors.Errorf("start coordinator: %w", err)
}
firstCoordinator = nil
case err = <-firstDerpMap:
if err != nil {
return nil, xerrors.Errorf("receive derp map: %w", err)
}
firstDerpMap = nil
connector := runTailnetAPIConnector(ctx, options.Logger,
agentID, coordinateURL.String(),
&websocket.DialOptions{
HTTPClient: c.HTTPClient,
HTTPHeader: headers,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
},
conn,
)
options.Logger.Debug(ctx, "running tailnet API v2+ connector")
select {
case <-dialCtx.Done():
return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err())
case err = <-connector.connected:
if err != nil {
options.Logger.Error(ctx, "failed to connect to tailnet v2+ API", slog.Error(err))
return nil, xerrors.Errorf("start connector: %w", err)
}
options.Logger.Debug(ctx, "connected to tailnet v2+ API")
}
agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{
@ -464,8 +352,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
AgentIP: WorkspaceAgentIP,
CloseFunc: func() error {
cancel()
<-closedCoordinator
<-closedDerpMap
<-connector.closed
return conn.Close()
},
})
@ -478,6 +365,171 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
return agentConn, nil
}
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
//
// 1) run the Coordinate API and pass node information back and forth
// 2) stream DERPMap updates and program the Conn
//
// These functions share the same websocket, and so are combined here so that if we hit a problem
// we tear the whole thing down and start over with a new websocket.
//
// @typescript-ignore tailnetAPIConnector
type tailnetAPIConnector struct {
ctx context.Context
logger slog.Logger
agentID uuid.UUID
coordinateURL string
dialOptions *websocket.DialOptions
conn *tailnet.Conn
connected chan error
isFirst bool
closed chan struct{}
}
// runTailnetAPIConnector creates and runs a tailnetAPIConnector
func runTailnetAPIConnector(
ctx context.Context, logger slog.Logger,
agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
conn *tailnet.Conn,
) *tailnetAPIConnector {
tac := &tailnetAPIConnector{
ctx: ctx,
logger: logger,
agentID: agentID,
coordinateURL: coordinateURL,
dialOptions: dialOptions,
conn: conn,
connected: make(chan error, 1),
closed: make(chan struct{}),
}
go tac.run()
return tac
}
func (tac *tailnetAPIConnector) run() {
tac.isFirst = true
defer close(tac.closed)
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
tailnetClient, err := tac.dial()
if err != nil {
continue
}
tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
tac.coordinateAndDERPMap(tailnetClient)
tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
}
}
func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
// nolint:bodyclose
ws, res, err := websocket.Dial(tac.ctx, tac.coordinateURL, tac.dialOptions)
if tac.isFirst {
if res != nil && res.StatusCode == http.StatusConflict {
err = ReadBodyAsError(res)
tac.connected <- err
return nil, err
}
tac.isFirst = false
close(tac.connected)
}
if err != nil {
if !errors.Is(err, context.Canceled) {
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err))
}
return nil, err
}
client, err := tailnet.NewDRPCClient(websocket.NetConn(tac.ctx, ws, websocket.MessageBinary))
if err != nil {
tac.logger.Debug(tac.ctx, "failed to create DRPCClient", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
return nil, err
}
return client, err
}
// coordinateAndDERPMap uses the provided client to coordinate and stream DERP Maps. It is combined
// into one function so that a problem with one tears down the other and triggers a retry (if
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
// fate.
func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetClient) {
defer func() {
conn := client.DRPCConn()
closeErr := conn.Close()
if closeErr != nil &&
!xerrors.Is(closeErr, io.EOF) &&
!xerrors.Is(closeErr, context.Canceled) &&
!xerrors.Is(closeErr, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr))
<-conn.Closed()
}
}()
eg, egCtx := errgroup.WithContext(tac.ctx)
eg.Go(func() error {
return tac.coordinate(egCtx, client)
})
eg.Go(func() error {
return tac.derpMap(egCtx, client)
})
err := eg.Wait()
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
}
}
func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
coord, err := client.Coordinate(ctx)
if err != nil {
return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
}
defer func() {
cErr := coord.Close()
if cErr != nil {
tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
}
}()
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
tac.logger.Debug(ctx, "serving coordinator")
err = <-coordination.Error()
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
return xerrors.Errorf("remote coordination error: %w", err)
}
return nil
}
func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
if err != nil {
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
}
defer func() {
cErr := s.Close()
if cErr != nil {
tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
}
}()
for {
dmp, err := s.Recv()
if err != nil {
if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return nil
}
return xerrors.Errorf("error receiving DERP Map: %w", err)
}
tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
dm := tailnet.DERPMapFromProto(dmp)
tac.conn.SetDERPMap(dm)
}
}
// WatchWorkspaceAgentMetadata watches the metadata of a workspace agent.
// The returned channel will be closed when the context is canceled. Exactly
// one error will be sent on the error channel. The metadata channel is never closed.