diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go
index f6c7fbdab6..7109bd747d 100644
--- a/codersdk/workspaceagents.go
+++ b/codersdk/workspaceagents.go
@@ -12,10 +12,9 @@ import (
 	"net/netip"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
-	"golang.org/x/sync/errgroup"
-
 	"github.com/google/uuid"
 	"golang.org/x/xerrors"
 	"nhooyr.io/websocket"
@@ -360,6 +359,15 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
 	return agentConn, nil
 }
 
+// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
+// included so that we can fake it in testing.
+//
+// @typescript-ignore tailnetConn
+type tailnetConn interface {
+	tailnet.Coordinatee
+	SetDERPMap(derpMap *tailcfg.DERPMap)
+}
+
 // tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
 //
 // 1) run the Coordinate API and pass node information back and forth
@@ -370,13 +378,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
 //
 // @typescript-ignore tailnetAPIConnector
 type tailnetAPIConnector struct {
-	ctx    context.Context
+	// We keep track of two contexts: the main context from the caller, and a "graceful" context
+	// that we keep open slightly longer than the main context to give a chance to send the
+	// Disconnect message to the coordinator. That tells the coordinator that we really meant to
+	// disconnect instead of just losing network connectivity.
+	ctx               context.Context
+	gracefulCtx       context.Context
+	cancelGracefulCtx context.CancelFunc
+
 	logger slog.Logger
 
 	agentID       uuid.UUID
 	coordinateURL string
 	dialOptions   *websocket.DialOptions
-	conn          *tailnet.Conn
+	conn          tailnetConn
 
 	connected chan error
 	isFirst   bool
@@ -387,7 +402,7 @@ type tailnetAPIConnector struct {
 func runTailnetAPIConnector(
 	ctx context.Context, logger slog.Logger,
 	agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
-	conn *tailnet.Conn,
+	conn tailnetConn,
 ) *tailnetAPIConnector {
 	tac := &tailnetAPIConnector{
 		ctx:           ctx,
@@ -399,10 +414,23 @@ func runTailnetAPIConnector(
 		connected:     make(chan error, 1),
 		closed:        make(chan struct{}),
 	}
+	tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
+	go tac.manageGracefulTimeout()
 	go tac.run()
 	return tac
 }
 
+// manageGracefulTimeout allows the gracefulCtx to last 1 second longer than the main context
+// to allow a graceful disconnect.
+func (tac *tailnetAPIConnector) manageGracefulTimeout() {
+	defer tac.cancelGracefulCtx()
+	<-tac.ctx.Done()
+	select {
+	case <-tac.closed:
+	case <-time.After(time.Second):
+	}
+}
+
 func (tac *tailnetAPIConnector) run() {
 	tac.isFirst = true
 	defer close(tac.closed)
@@ -437,7 +465,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
 		return nil, err
 	}
 	client, err := tailnet.NewDRPCClient(
-		websocket.NetConn(tac.ctx, ws, websocket.MessageBinary),
+		websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
 		tac.logger,
 	)
 	if err != nil {
@@ -464,65 +492,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
 			<-conn.Closed()
 		}
 	}()
-	eg, egCtx := errgroup.WithContext(tac.ctx)
-	eg.Go(func() error {
-		return tac.coordinate(egCtx, client)
-	})
-	eg.Go(func() error {
-		return tac.derpMap(egCtx, client)
-	})
-	err := eg.Wait()
-	if err != nil &&
-		!xerrors.Is(err, io.EOF) &&
-		!xerrors.Is(err, context.Canceled) &&
-		!xerrors.Is(err, context.DeadlineExceeded) {
-		tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
-	}
+	wg := sync.WaitGroup{}
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		tac.coordinate(client)
+	}()
+	go func() {
+		defer wg.Done()
+		dErr := tac.derpMap(client)
+		if dErr != nil && tac.ctx.Err() == nil {
+			// The main context is still active, meaning that we want the tailnet data plane to stay
+			// up, even though we hit some error getting DERP maps on the control plane. That means
+			// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
+			// close the underlying connection. This will trigger a retry of the control plane in
+			// run().
+			client.DRPCConn().Close()
+			// Note that derpMap() logs its own errors; we don't bother here.
+		}
+	}()
+	wg.Wait()
 }
 
-func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
-	coord, err := client.Coordinate(ctx)
+func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
+	// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
+	coord, err := client.Coordinate(tac.gracefulCtx)
 	if err != nil {
-		return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
+		tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
+		return
 	}
 	defer func() {
 		cErr := coord.Close()
 		if cErr != nil {
-			tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
+			tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
 		}
 	}()
 	coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
-	tac.logger.Debug(ctx, "serving coordinator")
-	err = <-coordination.Error()
-	if err != nil &&
-		!xerrors.Is(err, io.EOF) &&
-		!xerrors.Is(err, context.Canceled) &&
-		!xerrors.Is(err, context.DeadlineExceeded) {
-		return xerrors.Errorf("remote coordination error: %w", err)
+	tac.logger.Debug(tac.ctx, "serving coordinator")
+	select {
+	case <-tac.ctx.Done():
+		tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
+		crdErr := coordination.Close()
+		if crdErr != nil {
+			tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(crdErr))
+		}
+	case err = <-coordination.Error():
+		if err != nil &&
+			!xerrors.Is(err, io.EOF) &&
+			!xerrors.Is(err, context.Canceled) &&
+			!xerrors.Is(err, context.DeadlineExceeded) {
+			tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
+		}
 	}
-	return nil
 }
 
-func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
-	s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
+func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
+	s, err := client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
 	if err != nil {
 		return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
 	}
 	defer func() {
 		cErr := s.Close()
 		if cErr != nil {
-			tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
+			tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
 		}
 	}()
 	for {
 		dmp, err := s.Recv()
 		if err != nil {
-			if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
+			if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
				return nil
 			}
-			return xerrors.Errorf("error receiving DERP Map: %w", err)
+			tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
+			return err
 		}
-		tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
+		tac.logger.Debug(tac.ctx, "got new DERP Map", slog.F("derp_map", dmp))
 		dm := tailnet.DERPMapFromProto(dmp)
 		tac.conn.SetDERPMap(dm)
 	}
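
Reviewer sketch (not part of the patch): with conn now typed as the unexported tailnetConn interface, the real *tailnet.Conn is only checked against it at the DialWorkspaceAgent call site. A compile-time assertion next to the interface definition would keep that contract explicit even if the call site changes. A minimal sketch, assuming it would sit in codersdk/workspaceagents.go where the tailnet package is already imported:

// Compile-time check (suggestion only, not in this diff) that the concrete
// *tailnet.Conn still satisfies the tailnetConn interface the connector uses.
// It compiles to nothing and fails the build if the two drift apart.
var _ tailnetConn = (*tailnet.Conn)(nil)
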
diff --git a/codersdk/workspaceagents_internal_test.go b/codersdk/workspaceagents_internal_test.go
new file mode 100644
index 0000000000..38854114c1
--- /dev/null
+++ b/codersdk/workspaceagents_internal_test.go
@@ -0,0 +1,106 @@
+package codersdk
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"nhooyr.io/websocket"
+	"tailscale.com/tailcfg"
+
+	"cdr.dev/slog"
+	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
+	"github.com/coder/coder/v2/tailnet/tailnettest"
+	"github.com/coder/coder/v2/testutil"
+)
+
+func TestTailnetAPIConnector_Disconnects(t *testing.T) {
+	t.Parallel()
+	testCtx := testutil.Context(t, testutil.WaitShort)
+	ctx, cancel := context.WithCancel(testCtx)
+	logger := slogtest.Make(t, &slogtest.Options{
+		// we get EOF when we simulate a DERPMap error
+		IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
+	}).Leveled(slog.LevelDebug)
+	agentID := uuid.UUID{0x55}
+	clientID := uuid.UUID{0x66}
+	fCoord := tailnettest.NewFakeCoordinator()
+	var coord tailnet.Coordinator = fCoord
+	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
+	coordPtr.Store(&coord)
+	derpMapCh := make(chan *tailcfg.DERPMap)
+	defer close(derpMapCh)
+	svc, err := tailnet.NewClientService(
+		logger, &coordPtr,
+		time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
+	)
+	require.NoError(t, err)
+
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		sws, err := websocket.Accept(w, r, nil)
+		if !assert.NoError(t, err) {
+			return
+		}
+		ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
+		err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
+			Name: "client",
+			ID:   clientID,
+			Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
+		})
+		assert.NoError(t, err)
+	}))
+
+	fConn := newFakeTailnetConn()
+
+	uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
+
+	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
+	reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.NotNil(t, reqTun.AddTunnel)
+
+	_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
+
+	// simulate a problem with DERPMaps by sending nil
+	testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
+
+	// this should cause the coordinate call to hang up WITHOUT disconnecting
+	reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.Nil(t, reqNil)
+
+	// ...and then reconnect
+	call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
+	reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.NotNil(t, reqTun.AddTunnel)
+
+	// canceling the context should trigger the disconnect message
+	cancel()
+	reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
+	require.NotNil(t, reqDisc)
+	require.NotNil(t, reqDisc.Disconnect)
+}
+
+type fakeTailnetConn struct{}
+
+func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
+	// TODO implement me
+	panic("implement me")
+}
+
+func (*fakeTailnetConn) SetAllPeersLost() {}
+
+func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
+
+func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
+
+func newFakeTailnetConn() *fakeTailnetConn {
+	return &fakeTailnetConn{}
+}
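
Reviewer sketch (not part of the patch): fakeTailnetConn above discards DERP maps, which is all this test needs. If a follow-up test wanted to assert that a streamed map actually reaches the conn, one option is a small recording variant. A sketch only, with hypothetical names, assuming it lives in the same internal test file (package codersdk, tailcfg already imported):

// derpRecordingConn is a hypothetical extension of fakeTailnetConn that records
// DERP maps pushed into it so a test can receive and assert on them.
type derpRecordingConn struct {
	fakeTailnetConn
	derpMaps chan *tailcfg.DERPMap
}

func newDERPRecordingConn() *derpRecordingConn {
	// Buffer one map so SetDERPMap never blocks the connector's derpMap loop.
	return &derpRecordingConn{derpMaps: make(chan *tailcfg.DERPMap, 1)}
}

func (c *derpRecordingConn) SetDERPMap(dm *tailcfg.DERPMap) {
	select {
	case c.derpMaps <- dm:
	default: // drop rather than block if the test is not reading
	}
}
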
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index 31e2a9dded..530b42aea3 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -134,6 +134,7 @@ func (c *remoteCoordination) Close() (retErr error) {
 	if err != nil {
 		return xerrors.Errorf("send disconnect: %w", err)
 	}
+	c.logger.Debug(context.Background(), "sent disconnect")
 	return nil
 }
 
@@ -167,7 +168,7 @@ func (c *remoteCoordination) respLoop() {
 	}
 }
 
-// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a
+// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
 // Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
 // a client---agents should NOT set this!).
 func NewRemoteCoordination(logger slog.Logger,
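
Reviewer sketch (not part of the patch): for readers unfamiliar with the two-context shutdown trick that manageGracefulTimeout implements, here is a standalone, stdlib-only illustration of the same pattern. All names below are illustrative; nothing in this snippet appears in the change.

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	// Main context: canceled when the caller wants to shut down.
	ctx, cancel := context.WithCancel(context.Background())
	// Graceful context: stays open until shutdown work finishes or a one-second
	// grace period elapses, whichever comes first, mirroring manageGracefulTimeout.
	gracefulCtx, cancelGraceful := context.WithCancel(context.Background())
	closed := make(chan struct{}) // closed when the worker has finished shutting down

	go func() {
		defer cancelGraceful()
		<-ctx.Done() // main context canceled: start the grace period
		select {
		case <-closed: // clean shutdown completed early
		case <-time.After(time.Second): // give up after the grace period
		}
	}()

	// Worker: on cancellation it gets a chance to "send a disconnect" while the
	// graceful context is still alive, then reports that it has closed.
	go func() {
		defer close(closed)
		<-ctx.Done()
		fmt.Println("graceful context still open for the disconnect:", gracefulCtx.Err() == nil)
	}()

	cancel()
	<-gracefulCtx.Done()
	fmt.Println("graceful context done:", gracefulCtx.Err())
}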