chore: add setStatus support to nodeUpdater (#11568)

Add support for the wgengine Status callback to nodeUpdater
This commit is contained in:
Spike Curtis 2024-01-17 09:06:34 +04:00 committed by GitHub
parent f6dc707511
commit 38d9ce5267
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 208 additions and 0 deletions

View File

@@ -4,11 +4,13 @@ import (
"context"
"net/netip"
"sync"
"time"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/wgengine"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/dbtime"
@@ -32,6 +34,7 @@ type nodeUpdater struct {
derpForcedWebsockets map[int]string
endpoints []string
addresses []netip.Prefix
lastStatus time.Time
}
// updateLoop waits until the config is dirty and then calls the callback with the newest node.
@@ -146,3 +149,30 @@ func (u *nodeUpdater) setDERPForcedWebsocket(region int, reason string) {
u.Broadcast()
}
}
// setStatus handles the status callback from the wireguard engine to learn about new endpoints
// (e.g. discovered by STUN). Stale callbacks (older AsOf than the last one processed) are
// dropped, and the update loop is only woken when the endpoint set actually changes.
func (u *nodeUpdater) setStatus(s *wgengine.Status, err error) {
	u.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err))
	if err != nil {
		return
	}
	u.L.Lock()
	defer u.L.Unlock()
	if s.AsOf.Before(u.lastStatus) {
		// Stale callback: a newer status has already been processed.
		return
	}
	u.lastStatus = s.AsOf
	eps := make([]string, 0, len(s.LocalAddrs))
	for _, la := range s.LocalAddrs {
		eps = append(eps, la.Addr.String())
	}
	if slices.Equal(eps, u.endpoints) {
		// Nothing changed; don't mark the node dirty.
		return
	}
	u.endpoints = eps
	u.dirty = true
	u.Broadcast()
}

View File

@@ -1,7 +1,15 @@
package tailnet
import (
"net/netip"
"testing"
"time"
"golang.org/x/xerrors"
"golang.org/x/exp/slices"
"tailscale.com/wgengine"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
@@ -183,3 +191,173 @@ func TestNodeUpdater_setDERPForcedWebsocket_same(t *testing.T) {
}()
_ = testutil.RequireRecvCtx(ctx, t, done)
}
// TestNodeUpdater_setStatus_different verifies that a fresh engine status carrying a
// new endpoint produces a node update containing that endpoint and records AsOf as
// lastStatus.
func TestNodeUpdater_setStatus_different(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	id := tailcfg.NodeID(1)
	nodeKey := key.NewNode().Public()
	discoKey := key.NewDisco().Public()
	nodeCh := make(chan *Node)
	uut := newNodeUpdater(
		logger,
		func(n *Node) { nodeCh <- n },
		id, nodeKey, discoKey,
	)
	defer uut.close()

	// Given: preferred DERP is 1, so we'll send an update
	uut.L.Lock()
	uut.preferredDERP = 1
	uut.L.Unlock()

	// When: we set a new status
	asof := time.Date(2024, 1, 10, 8, 0, 1, 1, time.UTC)
	status := &wgengine.Status{
		LocalAddrs: []tailcfg.Endpoint{
			{Addr: netip.MustParseAddrPort("[fe80::1]:5678")},
		},
		AsOf: asof,
	}
	uut.setStatus(status, nil)

	// Then: we receive an update with the endpoint
	node := testutil.RequireRecvCtx(ctx, t, nodeCh)
	require.Equal(t, nodeKey, node.Key)
	require.Equal(t, discoKey, node.DiscoKey)
	require.Equal(t, []string{"[fe80::1]:5678"}, node.Endpoints)

	// Then: we store the AsOf time as lastStatus
	uut.L.Lock()
	require.Equal(t, asof, uut.lastStatus)
	uut.L.Unlock()

	closed := make(chan struct{})
	go func() {
		uut.close()
		close(closed)
	}()
	_ = testutil.RequireRecvCtx(ctx, t, closed)
}
// TestNodeUpdater_setStatus_same verifies that a status whose endpoints are identical
// to the ones we already hold does not trigger a node update.
func TestNodeUpdater_setStatus_same(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	id := tailcfg.NodeID(1)
	nodeKey := key.NewNode().Public()
	discoKey := key.NewDisco().Public()
	nodeCh := make(chan *Node)
	uut := newNodeUpdater(
		logger,
		func(n *Node) { nodeCh <- n },
		id, nodeKey, discoKey,
	)
	defer uut.close()

	// Then: we don't configure
	requireNeverConfigures(ctx, t, &uut.phased)

	// Given: preferred DERP is 1, so we would send an update on change &&
	// endpoints set to {"[fe80::1]:5678"}
	uut.L.Lock()
	uut.preferredDERP = 1
	uut.endpoints = []string{"[fe80::1]:5678"}
	uut.L.Unlock()

	// When: we set a status carrying the identical endpoint set
	uut.setStatus(&wgengine.Status{LocalAddrs: []tailcfg.Endpoint{
		{Addr: netip.MustParseAddrPort("[fe80::1]:5678")},
	}}, nil)

	closed := make(chan struct{})
	go func() {
		uut.close()
		close(closed)
	}()
	_ = testutil.RequireRecvCtx(ctx, t, closed)
}
// TestNodeUpdater_setStatus_error verifies that a status callback delivered with a
// non-nil error is ignored entirely: no node update is sent even though the status
// carries new endpoints.
func TestNodeUpdater_setStatus_error(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	id := tailcfg.NodeID(1)
	nodeKey := key.NewNode().Public()
	discoKey := key.NewDisco().Public()
	nodeCh := make(chan *Node)
	uut := newNodeUpdater(
		logger,
		func(n *Node) { nodeCh <- n },
		id, nodeKey, discoKey,
	)
	defer uut.close()

	// Then: we don't configure
	requireNeverConfigures(ctx, t, &uut.phased)

	// Given: preferred DERP is 1, so we would send an update on change && empty endpoints
	uut.L.Lock()
	uut.preferredDERP = 1
	uut.L.Unlock()

	// When: we set a status with endpoints {[fe80::1]:5678}, with an error
	uut.setStatus(&wgengine.Status{LocalAddrs: []tailcfg.Endpoint{
		{Addr: netip.MustParseAddrPort("[fe80::1]:5678")},
	}}, xerrors.New("test"))

	closed := make(chan struct{})
	go func() {
		uut.close()
		close(closed)
	}()
	_ = testutil.RequireRecvCtx(ctx, t, closed)
}
// TestNodeUpdater_setStatus_outdated verifies that a status callback whose AsOf
// timestamp is older than the last processed status is discarded, so reordered or
// stale callbacks cannot clobber newer endpoint information.
//
// Bug fix: the original test passed xerrors.New("test") as the error, which made
// setStatus return at the error check before ever reaching the AsOf staleness
// comparison — the test passed vacuously without exercising the outdated-status
// path. The error must be nil here.
func TestNodeUpdater_setStatus_outdated(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	id := tailcfg.NodeID(1)
	nodeKey := key.NewNode().Public()
	discoKey := key.NewDisco().Public()
	nodeCh := make(chan *Node)
	uut := newNodeUpdater(
		logger,
		func(n *Node) {
			nodeCh <- n
		},
		id, nodeKey, discoKey,
	)
	defer uut.close()

	// Then: we don't configure
	requireNeverConfigures(ctx, t, &uut.phased)

	// Given: preferred DERP is 1, so we would send an update on change && lastStatus set ahead
	ahead := time.Date(2024, 1, 10, 8, 0, 1, 0, time.UTC)
	behind := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC)
	uut.L.Lock()
	uut.preferredDERP = 1
	uut.lastStatus = ahead
	uut.L.Unlock()

	// When: we set a status with endpoints {[fe80::1]:5678}, with AsOf set behind.
	// err is nil so that setStatus reaches the staleness check rather than
	// returning early on the error path.
	uut.setStatus(&wgengine.Status{
		LocalAddrs: []tailcfg.Endpoint{{Addr: netip.MustParseAddrPort("[fe80::1]:5678")}},
		AsOf:       behind,
	}, nil)

	// Then: the outdated status is ignored and lastStatus is unchanged.
	uut.L.Lock()
	require.Equal(t, ahead, uut.lastStatus)
	uut.L.Unlock()

	done := make(chan struct{})
	go func() {
		defer close(done)
		uut.close()
	}()
	_ = testutil.RequireRecvCtx(ctx, t, done)
}