diff --git a/tailnet/node.go b/tailnet/node.go index 4178aec021..0365a3c28e 100644 --- a/tailnet/node.go +++ b/tailnet/node.go @@ -55,14 +55,20 @@ func (u *nodeUpdater) updateLoop() { u.logger.Debug(context.Background(), "closing nodeUpdater updateLoop") return } - node := u.nodeLocked() u.dirty = false u.phase = configuring u.Broadcast() + callback := u.callback + if callback == nil { + u.logger.Debug(context.Background(), "skipped sending node; no node callback") + continue + } + // We cannot reach nodes without DERP for discovery. Therefore, there is no point in sending // the node without this, and we can save ourselves from churn in the tailscale/wireguard // layer. + node := u.nodeLocked() if node.PreferredDERP == 0 { u.logger.Debug(context.Background(), "skipped sending node; no PreferredDERP", slog.F("node", node)) continue @@ -70,7 +76,7 @@ func (u *nodeUpdater) updateLoop() { u.L.Unlock() u.logger.Debug(context.Background(), "calling nodeUpdater callback", slog.F("node", node)) - u.callback(node) + callback(node) u.L.Lock() } } @@ -155,7 +161,7 @@ func (u *nodeUpdater) setDERPForcedWebsocket(region int, reason string) { } // setStatus handles the status callback from the wireguard engine to learn about new endpoints -// (e.g. discovered by STUN) +// (e.g. discovered by STUN). u.L MUST NOT be held func (u *nodeUpdater) setStatus(s *wgengine.Status, err error) { u.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err)) if err != nil { @@ -181,6 +187,7 @@ func (u *nodeUpdater) setStatus(s *wgengine.Status, err error) { u.Broadcast() } +// setAddresses sets the local addresses for the node. u.L MUST NOT be held. func (u *nodeUpdater) setAddresses(ips []netip.Prefix) { u.L.Lock() defer u.L.Unlock() @@ -192,3 +199,13 @@ func (u *nodeUpdater) setAddresses(ips []netip.Prefix) { u.dirty = true u.Broadcast() } + +// setCallback sets the callback for node changes. It also triggers a call +// for the current node immediately. u.L MUST NOT be held. +func (u *nodeUpdater) setCallback(callback func(node *Node)) { + u.L.Lock() + defer u.L.Unlock() + u.callback = callback + u.dirty = true + u.Broadcast() +} diff --git a/tailnet/node_internal_test.go b/tailnet/node_internal_test.go index f014977eda..7f2d1bd190 100644 --- a/tailnet/node_internal_test.go +++ b/tailnet/node_internal_test.go @@ -441,3 +441,44 @@ func TestNodeUpdater_setAddresses_same(t *testing.T) { }() _ = testutil.RequireRecvCtx(ctx, t, done) } + +func TestNodeUpdater_setCallback(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + id := tailcfg.NodeID(1) + nodeKey := key.NewNode().Public() + discoKey := key.NewDisco().Public() + uut := newNodeUpdater( + logger, + nil, + id, nodeKey, discoKey, + ) + defer uut.close() + + // Given: preferred DERP is 1 + addrs := []netip.Prefix{netip.MustParsePrefix("192.168.0.200/32")} + uut.L.Lock() + uut.preferredDERP = 1 + uut.addresses = slices.Clone(addrs) + uut.L.Unlock() + + // When: we set callback + nodeCh := make(chan *Node) + uut.setCallback(func(n *Node) { + nodeCh <- n + }) + + // Then: we get a node update + node := testutil.RequireRecvCtx(ctx, t, nodeCh) + require.Equal(t, nodeKey, node.Key) + require.Equal(t, discoKey, node.DiscoKey) + require.Equal(t, 1, node.PreferredDERP) + + done := make(chan struct{}) + go func() { + defer close(done) + uut.close() + }() + _ = testutil.RequireRecvCtx(ctx, t, done) +}