From e5ba586e305ae13b277a76829e633715a7e2e097 Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Mon, 5 Feb 2024 14:01:37 +0400
Subject: [PATCH] fix: fix graceful disconnect in DialWorkspaceAgent (#11993)

I noticed in testing that the CLI wasn't correctly sending the disconnect
message when it shut down, so agents were seeing this as a "lost" peer
rather than a "disconnected" one.

What was happening is that we used a single context for everything from the
netconn to the RPCs, and when that context was canceled we failed to send
the disconnect message because the context was already done.

So, this PR splits things into two contexts, with a graceful one set to
last up to 1 second longer than the main one.
---
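Note, not part of the commit: below is a minimal standalone sketch of the
two-context pattern this patch introduces. The names gracefulContext, stop,
and mainCtx are illustrative, not the codersdk API. The graceful context
stays usable for up to a grace period after the main context is canceled, so
a final disconnect message is not cut off mid-send; the stop func ends the
grace period early once that message is out. In the patch itself,
manageGracefulTimeout plays the role of the goroutine here, with tac.closed
standing in for stop.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// gracefulContext returns a context that remains live for up to grace after
// parent is canceled. The returned stop func ends the grace period early,
// e.g. once the disconnect message has actually been sent.
func gracefulContext(parent context.Context, grace time.Duration) (context.Context, func()) {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	var once sync.Once
	stop := func() { once.Do(func() { close(done) }) }
	go func() {
		defer cancel()
		<-parent.Done() // main context ended
		select {
		case <-done: // graceful work finished; cancel immediately
		case <-time.After(grace): // or the grace period ran out
		}
	}()
	return ctx, stop
}

func main() {
	mainCtx, cancel := context.WithCancel(context.Background())
	gctx, stop := gracefulContext(mainCtx, time.Second)
	defer stop()

	cancel() // simulate the caller shutting down
	select {
	case <-gctx.Done():
		fmt.Println("grace period already over")
	case <-time.After(100 * time.Millisecond):
		fmt.Println("still time to send a disconnect on gctx")
	}
	stop() // graceful work done; gctx is canceled promptly
	<-gctx.Done()
	fmt.Println("graceful context closed")
}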
 codersdk/workspaceagents.go               | 120 +++++++++++++++-------
 codersdk/workspaceagents_internal_test.go | 106 +++++++++++++++++++
 tailnet/coordinator.go                    |   3 +-
 3 files changed, 190 insertions(+), 39 deletions(-)
 create mode 100644 codersdk/workspaceagents_internal_test.go

diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go
index f6c7fbdab6..7109bd747d 100644
--- a/codersdk/workspaceagents.go
+++ b/codersdk/workspaceagents.go
@@ -12,10 +12,9 @@ import (
 	"net/netip"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
-	"golang.org/x/sync/errgroup"
-
 	"github.com/google/uuid"
 	"golang.org/x/xerrors"
 	"nhooyr.io/websocket"
@@ -360,6 +359,15 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
 	return agentConn, nil
 }
 
+// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
+// included so that we can fake it in testing.
+//
+// @typescript-ignore tailnetConn
+type tailnetConn interface {
+	tailnet.Coordinatee
+	SetDERPMap(derpMap *tailcfg.DERPMap)
+}
+
 // tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
 //
 // 1) run the Coordinate API and pass node information back and forth
@@ -370,13 +378,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
 //
 // @typescript-ignore tailnetAPIConnector
 type tailnetAPIConnector struct {
-	ctx           context.Context
+	// We keep track of two contexts: the main context from the caller, and a "graceful" context
+	// that we keep open slightly longer than the main context to give a chance to send the
+	// Disconnect message to the coordinator. That tells the coordinator that we really meant to
+	// disconnect instead of just losing network connectivity.
+	ctx               context.Context
+	gracefulCtx       context.Context
+	cancelGracefulCtx context.CancelFunc
+
 	logger        slog.Logger
 	agentID       uuid.UUID
 	coordinateURL string
 	dialOptions   *websocket.DialOptions
-	conn          *tailnet.Conn
+	conn          tailnetConn
 
 	connected chan error
 	isFirst   bool
@@ -387,7 +402,7 @@ type tailnetAPIConnector struct {
 func runTailnetAPIConnector(
 	ctx context.Context, logger slog.Logger,
 	agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
-	conn *tailnet.Conn,
+	conn tailnetConn,
 ) *tailnetAPIConnector {
 	tac := &tailnetAPIConnector{
 		ctx:           ctx,
@@ -399,10 +414,23 @@
 		connected:     make(chan error, 1),
 		closed:        make(chan struct{}),
 	}
+	tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
+	go tac.manageGracefulTimeout()
 	go tac.run()
 	return tac
 }
 
+// manageGracefulTimeout allows the gracefulCtx to last 1 second longer than the main context
+// to allow a graceful disconnect.
+func (tac *tailnetAPIConnector) manageGracefulTimeout() {
+	defer tac.cancelGracefulCtx()
+	<-tac.ctx.Done()
+	select {
+	case <-tac.closed:
+	case <-time.After(time.Second):
+	}
+}
+
 func (tac *tailnetAPIConnector) run() {
 	tac.isFirst = true
 	defer close(tac.closed)
@@ -437,7 +465,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
 		return nil, err
 	}
 	client, err := tailnet.NewDRPCClient(
-		websocket.NetConn(tac.ctx, ws, websocket.MessageBinary),
+		websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
 		tac.logger,
 	)
 	if err != nil {
@@ -464,65 +492,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
 			<-conn.Closed()
 		}
 	}()
-	eg, egCtx := errgroup.WithContext(tac.ctx)
-	eg.Go(func() error {
-		return tac.coordinate(egCtx, client)
-	})
-	eg.Go(func() error {
-		return tac.derpMap(egCtx, client)
-	})
-	err := eg.Wait()
-	if err != nil &&
-		!xerrors.Is(err, io.EOF) &&
-		!xerrors.Is(err, context.Canceled) &&
-		!xerrors.Is(err, context.DeadlineExceeded) {
-		tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
-	}
+	wg := sync.WaitGroup{}
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		tac.coordinate(client)
+	}()
+	go func() {
+		defer wg.Done()
+		dErr := tac.derpMap(client)
+		if dErr != nil && tac.ctx.Err() == nil {
+			// The main context is still active, meaning that we want the tailnet data plane to stay
+			// up, even though we hit some error getting DERP maps on the control plane. That means
+			// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
+			// close the underlying connection. This will trigger a retry of the control plane in
+			// run().
+			client.DRPCConn().Close()
+			// Note that derpMap() logs its own errors; we don't bother here.
+		}
+	}()
+	wg.Wait()
 }
 
-func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
-	coord, err := client.Coordinate(ctx)
+func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
+	// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
+	coord, err := client.Coordinate(tac.gracefulCtx)
 	if err != nil {
-		return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
+		tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
+		return
 	}
 	defer func() {
 		cErr := coord.Close()
 		if cErr != nil {
-			tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
+			tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
 		}
 	}()
 	coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
-	tac.logger.Debug(ctx, "serving coordinator")
-	err = <-coordination.Error()
-	if err != nil &&
-		!xerrors.Is(err, io.EOF) &&
-		!xerrors.Is(err, context.Canceled) &&
-		!xerrors.Is(err, context.DeadlineExceeded) {
-		return xerrors.Errorf("remote coordination error: %w", err)
+	tac.logger.Debug(tac.ctx, "serving coordinator")
+	select {
+	case <-tac.ctx.Done():
+		tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
+		crdErr := coordination.Close()
+		if crdErr != nil {
+			tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(crdErr))
+		}
+	case err = <-coordination.Error():
+		if err != nil &&
+			!xerrors.Is(err, io.EOF) &&
+			!xerrors.Is(err, context.Canceled) &&
+			!xerrors.Is(err, context.DeadlineExceeded) {
+			tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
+		}
 	}
-	return nil
 }
 
-func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
-	s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
+func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
+	s, err := client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
 	if err != nil {
 		return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
 	}
 	defer func() {
 		cErr := s.Close()
 		if cErr != nil {
-			tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
+			tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
 		}
 	}()
 	for {
 		dmp, err := s.Recv()
 		if err != nil {
-			if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
+			if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
 				return nil
 			}
-			return xerrors.Errorf("error receiving DERP Map: %w", err)
+			tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
+			return err
 		}
-		tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
+		tac.logger.Debug(tac.ctx, "got new DERP Map", slog.F("derp_map", dmp))
 		dm := tailnet.DERPMapFromProto(dmp)
 		tac.conn.SetDERPMap(dm)
 	}
 }
diff --git a/codersdk/workspaceagents_internal_test.go b/codersdk/workspaceagents_internal_test.go
new file mode 100644
index 0000000000..38854114c1
--- /dev/null
+++ b/codersdk/workspaceagents_internal_test.go
@@ -0,0 +1,106 @@
+package codersdk
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"nhooyr.io/websocket"
+	"tailscale.com/tailcfg"
+
+	"cdr.dev/slog"
+	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
+	"github.com/coder/coder/v2/tailnet/tailnettest"
+	"github.com/coder/coder/v2/testutil"
+)
+
+func TestTailnetAPIConnector_Disconnects(t *testing.T) {
+	t.Parallel()
+	testCtx := testutil.Context(t, testutil.WaitShort)
+	ctx, cancel := context.WithCancel(testCtx)
+	logger := slogtest.Make(t, &slogtest.Options{
+		// we get EOF when we simulate a DERPMap error
+		IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
+	}).Leveled(slog.LevelDebug)
+	agentID := uuid.UUID{0x55}
+	clientID := uuid.UUID{0x66}
+	fCoord := tailnettest.NewFakeCoordinator()
+	var coord tailnet.Coordinator = fCoord
+	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
+	coordPtr.Store(&coord)
+	derpMapCh := make(chan *tailcfg.DERPMap)
+	defer close(derpMapCh)
+	svc, err := tailnet.NewClientService(
+		logger, &coordPtr,
+		time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
+	)
+	require.NoError(t, err)
+
+	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		sws, err := websocket.Accept(w, r, nil)
+		if !assert.NoError(t, err) {
+			return
+		}
+		ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
+		err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
+			Name: "client",
+			ID:   clientID,
+			Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
+		})
+		assert.NoError(t, err)
+	}))
+
+	fConn := newFakeTailnetConn()
+
+	uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
+
+	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
+	reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.NotNil(t, reqTun.AddTunnel)
+
+	_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
+
+	// simulate a problem with DERPMaps by sending nil
+	testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
+
+	// this should cause the coordinate call to hang up WITHOUT disconnecting
+	reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.Nil(t, reqNil)
+
+	// ...and then reconnect
+	call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
+	reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
+	require.NotNil(t, reqTun.AddTunnel)
+
+	// canceling the context should trigger the disconnect message
+	cancel()
+	reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
+	require.NotNil(t, reqDisc)
+	require.NotNil(t, reqDisc.Disconnect)
+}
+
+type fakeTailnetConn struct{}
+
+func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
+	// TODO implement me
+	panic("implement me")
+}
+
+func (*fakeTailnetConn) SetAllPeersLost() {}
+
+func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
+
+func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
+
+func newFakeTailnetConn() *fakeTailnetConn {
+	return &fakeTailnetConn{}
+}
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index 31e2a9dded..530b42aea3 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -134,6 +134,7 @@ func (c *remoteCoordination) Close() (retErr error) {
 	if err != nil {
 		return xerrors.Errorf("send disconnect: %w", err)
 	}
+	c.logger.Debug(context.Background(), "sent disconnect")
 	return nil
 }
 
@@ -167,7 +168,7 @@ func (c *remoteCoordination) respLoop() {
 	}
 }
 
-// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a
+// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
 // Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
 // a client---agents should NOT set this!).
 func NewRemoteCoordination(logger slog.Logger,
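
Illustrative appendix, not part of the patch: the coordinateAndDERPMap
change above encodes a deliberate policy. If the DERP map stream fails while
the main context is still live, the connector closes the shared dRPC
connection to force a control-plane reconnect rather than sending a graceful
disconnect. Below is a standalone sketch of that fan-out with fake types in
place of the real dRPC client and tailnet API; runBoth and fakeConn are
invented names.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

// fakeConn stands in for the shared dRPC connection both RPC loops use.
type fakeConn struct {
	once   sync.Once
	closed chan struct{}
}

func (c *fakeConn) Close() { c.once.Do(func() { close(c.closed) }) }

// runBoth mirrors the fan-out: the coordinate loop chooses between a
// graceful disconnect (main context canceled) and a plain reconnect (conn
// closed by the DERP-map loop after a control-plane error).
func runBoth(ctx context.Context, conn *fakeConn) {
	var wg sync.WaitGroup
	wg.Add(2)
	go func() { // stands in for coordinate()
		defer wg.Done()
		select {
		case <-ctx.Done():
			fmt.Println("coordinate: send graceful disconnect")
		case <-conn.closed:
			fmt.Println("coordinate: conn closed, caller retries control plane")
		}
	}()
	go func() { // stands in for the derpMap() goroutine
		defer wg.Done()
		err := errors.New("DERP map stream failed") // simulated stream error
		if err != nil && ctx.Err() == nil {
			// Main context still live: keep the data plane up and force a
			// control-plane reconnect instead of a graceful disconnect.
			conn.Close()
		}
	}()
	wg.Wait()
}

func main() {
	conn := &fakeConn{closed: make(chan struct{})}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	runBoth(ctx, conn) // prints the reconnect branch, since ctx is still live
}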