fix: use TSMP ping for reachability, not latency (#11749)

Use TSMP ping for reachability checks (AwaitReachable), but keep Disco ping for direct calls to Ping(), since we often use Ping() to determine whether we have a direct connection.

Also adds unit tests to verify that Ping() correctly reports whether the connection is direct or relayed over DERP.
This commit is contained in:
Spike Curtis 2024-01-22 17:37:15 +04:00 committed by GitHub
parent 66f119bde8
commit 5388a1b6d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 101 additions and 3 deletions

View File

@ -374,9 +374,13 @@ func (c *Conn) Status() *ipnstate.Status {
// Ping sends a ping to the Wireguard engine.
// It delegates to pingWithType with tailcfg.PingDisco; Disco (rather than
// TSMP) is used here because callers often rely on Ping to determine whether
// the connection to ip is direct.
// The bool returned is true if the ping was performed P2P.
func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, *ipnstate.PingResult, error) {
return c.pingWithType(ctx, ip, tailcfg.PingDisco)
}
func (c *Conn) pingWithType(ctx context.Context, ip netip.Addr, pt tailcfg.PingType) (time.Duration, bool, *ipnstate.PingResult, error) {
errCh := make(chan error, 1)
prChan := make(chan *ipnstate.PingResult, 1)
go c.wireguardEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
go c.wireguardEngine.Ping(ip, pt, func(pr *ipnstate.PingResult) {
if pr.Err != "" {
errCh <- xerrors.New(pr.Err)
return
@ -418,7 +422,13 @@ func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
_, _, _, err := c.Ping(ctx, ip)
// For reachability, we use TSMP ping, which pings at the IP layer, and
// therefore requires that wireguard and the netstack are up. If we
// don't wait for wireguard to be up, we could miss a handshake, and it
// might take 5 seconds for the handshake to be retried. A 5s initial
// round trip can set us up for poor TCP performance, since the initial
// round-trip-time sets the initial retransmit timeout.
_, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP)
if err == nil {
completed()
}

View File

@ -88,7 +88,7 @@ func TestTailnet(t *testing.T) {
}
})
node := testutil.RequireRecvCtx(ctx, t, nodes)
// Ensure this connected over DERP!
// Ensure this connected over raw (not websocket) DERP!
require.Len(t, node.DERPForcedWebsocket, 0)
w1.Close()
@ -157,6 +157,94 @@ func TestTailnet(t *testing.T) {
w1.Close()
w2.Close()
})
t.Run("PingDirect", func(t *testing.T) {
	// Verifies that Ping reports a direct (P2P) path between two conns.
	// Neither conn sets BlockEndpoints, so a direct path is presumably
	// available once the two are stitched together.
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	ctx := testutil.Context(t, testutil.WaitLong)
	w1IP := tailnet.IP()
	w1, err := tailnet.NewConn(&tailnet.Options{
		Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
		Logger:    logger.Named("w1"),
		DERPMap:   derpMap,
	})
	require.NoError(t, err)
	w2, err := tailnet.NewConn(&tailnet.Options{
		Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		Logger:    logger.Named("w2"),
		DERPMap:   derpMap,
	})
	require.NoError(t, err)
	// Safety net: a failed require below aborts the test before the explicit
	// Close calls at the end are reached.
	t.Cleanup(func() {
		_ = w1.Close()
		_ = w2.Close()
	})
	stitch(t, w2, w1)
	stitch(t, w1, w2)
	// Pass the test context so the wait is bounded by the test's own deadline
	// rather than only by AwaitReachable's internal timeout.
	require.True(t, w2.AwaitReachable(ctx, w1IP))
	// The first pings may be relayed, so poll until Ping reports a direct path.
	require.Eventually(t, func() bool {
		_, direct, pong, err := w2.Ping(ctx, w1IP)
		if err != nil {
			t.Logf("ping error: %s", err.Error())
			return false
		}
		if !direct {
			t.Logf("got pong: %+v", pong)
			return false
		}
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
	// Close eagerly on the success path; errors are intentionally ignored,
	// matching the Cleanup above.
	_ = w1.Close()
	_ = w2.Close()
})
t.Run("PingDERPOnly", func(t *testing.T) {
	// Verifies that Ping reports a relayed (non-direct) path through the
	// expected DERP region when both conns set BlockEndpoints.
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	ctx := testutil.Context(t, testutil.WaitLong)
	w1IP := tailnet.IP()
	w1, err := tailnet.NewConn(&tailnet.Options{
		Addresses:      []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
		Logger:         logger.Named("w1"),
		DERPMap:        derpMap,
		BlockEndpoints: true,
	})
	require.NoError(t, err)
	w2, err := tailnet.NewConn(&tailnet.Options{
		Addresses:      []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		Logger:         logger.Named("w2"),
		DERPMap:        derpMap,
		BlockEndpoints: true,
	})
	require.NoError(t, err)
	// Safety net: a failed require below aborts the test before the explicit
	// Close calls at the end are reached.
	t.Cleanup(func() {
		_ = w1.Close()
		_ = w2.Close()
	})
	stitch(t, w2, w1)
	stitch(t, w1, w2)
	// Pass the test context so the wait is bounded by the test's own deadline
	// rather than only by AwaitReachable's internal timeout.
	require.True(t, w2.AwaitReachable(ctx, w1IP))
	// Poll until a ping succeeds, is not direct, and names the first region
	// of derpMap as the relay.
	require.Eventually(t, func() bool {
		_, direct, pong, err := w2.Ping(ctx, w1IP)
		if err != nil {
			t.Logf("ping error: %s", err.Error())
			return false
		}
		if direct || pong.DERPRegionID != derpMap.RegionIDs()[0] {
			t.Logf("got pong: %+v", pong)
			return false
		}
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
	// Close eagerly on the success path; errors are intentionally ignored,
	// matching the Cleanup above.
	_ = w1.Close()
	_ = w2.Close()
})
}
// TestConn_PreferredDERP tests that we only trigger the NodeCallback when we have a preferred DERP server.