From f6dc70751152c34cb1b555f0af11904a724fd55f Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 17 Jan 2024 08:55:45 +0400 Subject: [PATCH] chore: add DERPForcedWebsocket to nodeUpdater (#11567) Add support for DERPForcedWebsocket to nodeUpdater --- tailnet/node.go | 26 +++++++++--- tailnet/node_internal_test.go | 79 ++++++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 8 deletions(-) diff --git a/tailnet/node.go b/tailnet/node.go index a9912154d6..8eb3774de3 100644 --- a/tailnet/node.go +++ b/tailnet/node.go @@ -86,12 +86,13 @@ func newNodeUpdater( id tailcfg.NodeID, np key.NodePublic, dp key.DiscoPublic, ) *nodeUpdater { u := &nodeUpdater{ - phased: phased{Cond: *(sync.NewCond(&sync.Mutex{}))}, - logger: logger, - id: id, - key: np, - discoKey: dp, - callback: callback, + phased: phased{Cond: *(sync.NewCond(&sync.Mutex{}))}, + logger: logger, + id: id, + key: np, + discoKey: dp, + derpForcedWebsockets: make(map[int]string), + callback: callback, } go u.updateLoop() return u @@ -132,3 +133,16 @@ func (u *nodeUpdater) setNetInfo(ni *tailcfg.NetInfo) { u.Broadcast() } } + +// setDERPForcedWebsocket handles callbacks from the magicConn about DERP regions that are forced to +// use websockets (instead of Upgrade: derp). This information is for debugging only. +func (u *nodeUpdater) setDERPForcedWebsocket(region int, reason string) { + u.L.Lock() + defer u.L.Unlock() + dirty := u.derpForcedWebsockets[region] != reason + u.derpForcedWebsockets[region] = reason + if dirty { + u.dirty = true + u.Broadcast() + } +} diff --git a/tailnet/node_internal_test.go b/tailnet/node_internal_test.go index 27dc5609d1..da26c1d7ba 100644 --- a/tailnet/node_internal_test.go +++ b/tailnet/node_internal_test.go @@ -74,12 +74,10 @@ func TestNodeUpdater_setNetInfo_same(t *testing.T) { nodeKey := key.NewNode().Public() discoKey := key.NewDisco().Public() nodeCh := make(chan *Node) - goCh := make(chan struct{}) uut := newNodeUpdater( logger, func(n *Node) { nodeCh <- n - <-goCh }, id, nodeKey, discoKey, ) @@ -108,3 +106,80 @@ func TestNodeUpdater_setNetInfo_same(t *testing.T) { }() _ = testutil.RequireRecvCtx(ctx, t, done) } + +func TestNodeUpdater_setDERPForcedWebsocket_different(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + id := tailcfg.NodeID(1) + nodeKey := key.NewNode().Public() + discoKey := key.NewDisco().Public() + nodeCh := make(chan *Node) + uut := newNodeUpdater( + logger, + func(n *Node) { + nodeCh <- n + }, + id, nodeKey, discoKey, + ) + defer uut.close() + + // Given: preferred DERP is 1, so we'll send an update + uut.L.Lock() + uut.preferredDERP = 1 + uut.L.Unlock() + + // When: we set a new forced websocket reason + uut.setDERPForcedWebsocket(1, "test") + + // Then: we receive an update with the reason set + node := testutil.RequireRecvCtx(ctx, t, nodeCh) + require.Equal(t, nodeKey, node.Key) + require.Equal(t, discoKey, node.DiscoKey) + require.True(t, maps.Equal(map[int]string{1: "test"}, node.DERPForcedWebsocket)) + + done := make(chan struct{}) + go func() { + defer close(done) + uut.close() + }() + _ = testutil.RequireRecvCtx(ctx, t, done) +} + +func TestNodeUpdater_setDERPForcedWebsocket_same(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + id := tailcfg.NodeID(1) + nodeKey := key.NewNode().Public() + discoKey := key.NewDisco().Public() + nodeCh := make(chan *Node) + uut := newNodeUpdater( + logger, + func(n *Node) { + nodeCh <- n + }, + id, nodeKey, discoKey, + ) + defer uut.close() + + // Then: we don't configure + requireNeverConfigures(ctx, t, &uut.phased) + + // Given: preferred DERP is 1, so we would send an update on change && + // reason for region 1 is set to "test" + uut.L.Lock() + uut.preferredDERP = 1 + uut.derpForcedWebsockets[1] = "test" + uut.L.Unlock() + + // When: we set region 1 to "test + uut.setDERPForcedWebsocket(1, "test") + + done := make(chan struct{}) + go func() { + defer close(done) + uut.close() + }() + _ = testutil.RequireRecvCtx(ctx, t, done) +}