From 520b12e1a26c3feddef073a458811c0abea40577 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 31 Jan 2024 00:38:19 +0400 Subject: [PATCH] fix: close MultiAgentConn when coordinator closes (#11941) Fixes an issue where a MultiAgentConn isn't closed properly when the coordinator it is connected to is closed. Since servertailnet checks whether the conn is closed before reinitializing, it is important that we check this, otherwise servertailnet can get stuck if the coordinator closes (e.g. when we switch from AGPL to PGCoordinator after decoding a license). --- enterprise/tailnet/connio.go | 27 +-- enterprise/tailnet/multiagent_test.go | 226 +++++++++----------------- tailnet/coordinator.go | 8 +- tailnet/coordinator_test.go | 18 ++ tailnet/tailnettest/tailnettest.go | 124 ++++++++++++++ 5 files changed, 236 insertions(+), 167 deletions(-) diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go index 45d9c71c3e..6e98dfec4c 100644 --- a/enterprise/tailnet/connio.go +++ b/enterprise/tailnet/connio.go @@ -2,7 +2,6 @@ package tailnet import ( "context" - "io" "sync" "sync/atomic" "time" @@ -104,19 +103,21 @@ func (c *connIO) recvLoop() { }() defer c.Close() for { - req, err := agpl.RecvCtx(c.peerCtx, c.requests) - if err != nil { - if xerrors.Is(err, context.Canceled) || - xerrors.Is(err, context.DeadlineExceeded) || - xerrors.Is(err, io.EOF) { - c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err)) - } else { - c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err)) + select { + case <-c.coordCtx.Done(): + c.logger.Debug(c.coordCtx, "exiting io recvLoop; coordinator exit") + return + case <-c.peerCtx.Done(): + c.logger.Debug(c.peerCtx, "exiting io recvLoop; peer context canceled") + return + case req, ok := <-c.requests: + if !ok { + c.logger.Debug(c.peerCtx, "exiting io recvLoop; requests chan closed") + return + } + if err := c.handleRequest(req); err != nil { + return } - return - } - if err := c.handleRequest(req); err != nil { - return } } } diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go index c9f8f73fe9..bbb3c55735 100644 --- a/enterprise/tailnet/multiagent_test.go +++ b/enterprise/tailnet/multiagent_test.go @@ -4,17 +4,14 @@ import ( "context" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/require" - "golang.org/x/exp/slices" - "tailscale.com/types/key" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/tailnet" agpl "github.com/coder/coder/v2/tailnet" - "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" ) @@ -42,25 +39,48 @@ func TestPGCoordinator_MultiAgent(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord1) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.id) } +func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() + + err = coord1.Close() + require.NoError(t, err) + + ma1.RequireEventuallyClosed(ctx) +} + // TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with // a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe // with the MultiAgent closing. @@ -86,20 +106,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord1) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - ma1.unsubscribeAgent(agent1.id) - ma1.close() + ma1.RequireUnsubscribeAgent(agent1.id) + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -131,35 +151,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord1) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - ma1.unsubscribeAgent(agent1.id) + ma1.RequireUnsubscribeAgent(agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() - ma1.sendNodeWithDERP(9) + ma1.SendNodeWithDERP(9) assertNeverHasDERPs(ctx, t, agent1, 9) }() func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() agent1.sendNode(&agpl.Node{PreferredDERP: 8}) - ma1.assertNeverHasDERPs(ctx, 8) + ma1.RequireNeverHasDERPs(ctx, 8) }() - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -196,19 +216,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord2) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord2) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -246,19 +266,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord2) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord2) + defer ma1.Close() - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -305,129 +325,29 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { defer agent1.close() agent2.sendNode(&agpl.Node{PreferredDERP: 6}) - ma1 := newTestMultiAgent(t, coord3) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord3) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.subscribeAgent(agent2.id) - ma1.assertEventuallyHasDERPs(ctx, 6) + ma1.RequireSubscribeAgent(agent2.id) + ma1.RequireEventuallyHasDERPs(ctx, 6) agent2.sendNode(&agpl.Node{PreferredDERP: 2}) - ma1.assertEventuallyHasDERPs(ctx, 2) + ma1.RequireEventuallyHasDERPs(ctx, 2) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent2, 3) - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) require.NoError(t, agent2.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.id) } - -type testMultiAgent struct { - t testing.TB - id uuid.UUID - a agpl.MultiAgentConn - nodeKey []byte - discoKey string -} - -func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent { - nk, err := key.NewNode().Public().MarshalBinary() - require.NoError(t, err) - dk, err := key.NewDisco().Public().MarshalText() - require.NoError(t, err) - m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)} - m.a = coord.ServeMultiAgent(m.id) - return m -} - -func (m *testMultiAgent) sendNodeWithDERP(derp int32) { - m.t.Helper() - err := m.a.UpdateSelf(&proto.Node{ - Key: m.nodeKey, - Disco: m.discoKey, - PreferredDerp: derp, - }) - require.NoError(m.t, err) -} - -func (m *testMultiAgent) close() { - m.t.Helper() - err := m.a.Close() - require.NoError(m.t, err) -} - -func (m *testMultiAgent) subscribeAgent(id uuid.UUID) { - m.t.Helper() - err := m.a.SubscribeAgent(id) - require.NoError(m.t, err) -} - -func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) { - m.t.Helper() - err := m.a.UnsubscribeAgent(id) - require.NoError(m.t, err) -} - -func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) { - m.t.Helper() - for { - resp, ok := m.a.NextUpdate(ctx) - require.True(m.t, ok) - nodes, err := agpl.OnlyNodeUpdates(resp) - require.NoError(m.t, err) - if len(nodes) != len(expected) { - m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { - if !slices.Contains(derps, e) { - m.t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} - -func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) { - m.t.Helper() - for { - resp, ok := m.a.NextUpdate(ctx) - if !ok { - return - } - nodes, err := agpl.OnlyNodeUpdates(resp) - require.NoError(m.t, err) - if len(nodes) != len(expected) { - m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { - if !slices.Contains(derps, e) { - m.t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index a5d9241a85..31e2a9dded 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -1017,7 +1017,13 @@ func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logge } func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) { - defer cancel() + defer func() { + cErr := q.Close() + if cErr != nil { + logger.Info(ctx, "error closing response Queue", slog.Error(cErr)) + } + cancel() + }() for { resp, err := RecvCtx(ctx, resps) if err != nil { diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index ab38f91bd0..72a6591051 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -383,6 +383,24 @@ func TestCoordinator_Lost(t *testing.T) { test.LostTest(ctx, t, coordinator) } +func TestCoordinator_MultiAgent_CoordClose(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + coord1 := tailnet.NewCoordinator(logger.Named("coord1")) + defer coord1.Close() + + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() + + err := coord1.Close() + require.NoError(t, err) + + ma1.RequireEventuallyClosed(ctx) +} + func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) { t.Helper() sc := make(chan net.Conn, 1) diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index 794aee549c..e3c66a23ab 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -1,6 +1,7 @@ package tailnettest import ( + "context" "crypto/tls" "fmt" "html" @@ -8,7 +9,11 @@ import ( "net/http" "net/http/httptest" "testing" + "time" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/stun/stuntest" @@ -19,6 +24,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/testutil" ) //go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn @@ -125,3 +132,120 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap { }, } } + +type TestMultiAgent struct { + t testing.TB + id uuid.UUID + a tailnet.MultiAgentConn + nodeKey []byte + discoKey string +} + +func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent { + nk, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + dk, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + m := &TestMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)} + m.a = coord.ServeMultiAgent(m.id) + return m +} + +func (m *TestMultiAgent) SendNodeWithDERP(d int32) { + m.t.Helper() + err := m.a.UpdateSelf(&proto.Node{ + Key: m.nodeKey, + Disco: m.discoKey, + PreferredDerp: d, + }) + require.NoError(m.t, err) +} + +func (m *TestMultiAgent) Close() { + m.t.Helper() + err := m.a.Close() + require.NoError(m.t, err) +} + +func (m *TestMultiAgent) RequireSubscribeAgent(id uuid.UUID) { + m.t.Helper() + err := m.a.SubscribeAgent(id) + require.NoError(m.t, err) +} + +func (m *TestMultiAgent) RequireUnsubscribeAgent(id uuid.UUID) { + m.t.Helper() + err := m.a.UnsubscribeAgent(id) + require.NoError(m.t, err) +} + +func (m *TestMultiAgent) RequireEventuallyHasDERPs(ctx context.Context, expected ...int) { + m.t.Helper() + for { + resp, ok := m.a.NextUpdate(ctx) + require.True(m.t, ok) + nodes, err := tailnet.OnlyNodeUpdates(resp) + require.NoError(m.t, err) + if len(nodes) != len(expected) { + m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + m.t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} + +func (m *TestMultiAgent) RequireNeverHasDERPs(ctx context.Context, expected ...int) { + m.t.Helper() + for { + resp, ok := m.a.NextUpdate(ctx) + if !ok { + return + } + nodes, err := tailnet.OnlyNodeUpdates(resp) + require.NoError(m.t, err) + if len(nodes) != len(expected) { + m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + m.t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} + +func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) { + m.t.Helper() + tkr := time.NewTicker(testutil.IntervalFast) + defer tkr.Stop() + for { + select { + case <-ctx.Done(): + m.t.Fatal("timeout") + return // unhittable + case <-tkr.C: + if m.a.IsClosed() { + return + } + } + } +}