diff --git a/tailnet/node.go b/tailnet/node.go index 0365a3c28e..e7e83b6690 100644 --- a/tailnet/node.go +++ b/tailnet/node.go @@ -35,6 +35,7 @@ type nodeUpdater struct { endpoints []string addresses []netip.Prefix lastStatus time.Time + blockEndpoints bool } // updateLoop waits until the config is dirty and then calls the callback with the newest node. @@ -111,6 +112,10 @@ func newNodeUpdater( // nodeLocked returns the current best node information. u.L must be held. func (u *nodeUpdater) nodeLocked() *Node { + var endpoints []string + if !u.blockEndpoints { + endpoints = slices.Clone(u.endpoints) + } return &Node{ ID: u.id, AsOf: dbtime.Now(), @@ -118,7 +123,7 @@ func (u *nodeUpdater) nodeLocked() *Node { Addresses: slices.Clone(u.addresses), AllowedIPs: slices.Clone(u.addresses), DiscoKey: u.discoKey, - Endpoints: slices.Clone(u.endpoints), + Endpoints: endpoints, PreferredDERP: u.preferredDERP, DERPLatency: maps.Clone(u.derpLatency), DERPForcedWebsocket: maps.Clone(u.derpForcedWebsockets), @@ -209,3 +214,17 @@ func (u *nodeUpdater) setCallback(callback func(node *Node)) { u.dirty = true u.Broadcast() } + +// setBlockEndpoints sets whether we block reporting Node endpoints. u.L MUST NOT +// be held. +// nolint: revive +func (u *nodeUpdater) setBlockEndpoints(blockEndpoints bool) { + u.L.Lock() + defer u.L.Unlock() + if u.blockEndpoints == blockEndpoints { + return + } + u.dirty = true + u.blockEndpoints = blockEndpoints + u.Broadcast() +} diff --git a/tailnet/node_internal_test.go b/tailnet/node_internal_test.go index 7f2d1bd190..aa933de4be 100644 --- a/tailnet/node_internal_test.go +++ b/tailnet/node_internal_test.go @@ -482,3 +482,90 @@ func TestNodeUpdater_setCallback(t *testing.T) { }() _ = testutil.RequireRecvCtx(ctx, t, done) } + +func TestNodeUpdater_setBlockEndpoints_different(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + id := tailcfg.NodeID(1) + nodeKey := key.NewNode().Public() + discoKey := key.NewDisco().Public() + nodeCh := make(chan *Node) + uut := newNodeUpdater( + logger, + func(n *Node) { + nodeCh <- n + }, + id, nodeKey, discoKey, + ) + defer uut.close() + + // Given: preferred DERP is 1, so we'll send an update && some endpoints + uut.L.Lock() + uut.preferredDERP = 1 + uut.endpoints = []string{"10.11.12.13:7890"} + uut.L.Unlock() + + // When: we setBlockEndpoints + uut.setBlockEndpoints(true) + + // Then: we receive an update without endpoints + node := testutil.RequireRecvCtx(ctx, t, nodeCh) + require.Equal(t, nodeKey, node.Key) + require.Equal(t, discoKey, node.DiscoKey) + require.Len(t, node.Endpoints, 0) + + // When: we unset BlockEndpoints + uut.setBlockEndpoints(false) + + // Then: we receive an update with endpoints + node = testutil.RequireRecvCtx(ctx, t, nodeCh) + require.Equal(t, nodeKey, node.Key) + require.Equal(t, discoKey, node.DiscoKey) + require.Len(t, node.Endpoints, 1) + + done := make(chan struct{}) + go func() { + defer close(done) + uut.close() + }() + _ = testutil.RequireRecvCtx(ctx, t, done) +} + +func TestNodeUpdater_setBlockEndpoints_same(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + id := tailcfg.NodeID(1) + nodeKey := key.NewNode().Public() + discoKey := key.NewDisco().Public() + nodeCh := make(chan *Node) + uut := newNodeUpdater( + logger, + func(n *Node) { + nodeCh <- n + }, + id, nodeKey, discoKey, + ) + defer uut.close() + + // Then: we don't configure + requireNeverConfigures(ctx, t, &uut.phased) + + // Given: preferred DERP is 1, so we would send an update on change && + // blockEndpoints already set + uut.L.Lock() + uut.preferredDERP = 1 + uut.blockEndpoints = true + uut.L.Unlock() + + // When: we set block endpoints + uut.setBlockEndpoints(true) + + done := make(chan struct{}) + go func() { + defer close(done) + uut.close() + }() + _ = testutil.RequireRecvCtx(ctx, t, done) +}