coder/tailnet/conn_test.go

package tailnet_test

import (
	"context"
	"net/netip"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
	"github.com/coder/coder/v2/tailnet/tailnettest"
	"github.com/coder/coder/v2/testutil"
)
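
// TestMain runs the package's tests and fails if any goroutines are leaked
// once they complete.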
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}
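
// TestTailnet exercises basic Conn behavior between two peers: immediate
// close, TCP dialing, DERP forced over WebSockets, and pings over direct and
// DERP-only paths.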
func TestTailnet(t *testing.T) {
	t.Parallel()
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
t.Run("InstantClose", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w1"),
DERPMap: derpMap,
})
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
})
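
	// Connect dials a TCP listener on w1 from w2 and verifies the peers
	// connected over raw DERP rather than WebSocket-forced DERP.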
t.Run("Connect", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitLong)
w1IP := tailnet.IP()
w1, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
Logger: logger.Named("w1"),
DERPMap: derpMap,
})
require.NoError(t, err)
w2, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w2"),
DERPMap: derpMap,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = w1.Close()
_ = w2.Close()
})
stitch(t, w2, w1)
stitch(t, w1, w2)
require.True(t, w2.AwaitReachable(context.Background(), w1IP))
conn := make(chan struct{}, 1)
go func() {
listener, err := w1.Listen("tcp", ":35565")
assert.NoError(t, err)
defer listener.Close()
nc, err := listener.Accept()
if !assert.NoError(t, err) {
return
}
_ = nc.Close()
conn <- struct{}{}
}()
nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
require.NoError(t, err)
_ = nc.Close()
<-conn
nodes := make(chan *tailnet.Node, 1)
w2.SetNodeCallback(func(node *tailnet.Node) {
select {
case nodes <- node:
default:
}
})
node := testutil.RequireRecvCtx(ctx, t, nodes)
// Ensure this connected over raw (not websocket) DERP!
require.Len(t, node.DERPForcedWebsocket, 0)
w1.Close()
w2.Close()
})
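
	// ForcesWebSockets runs a DERP server that only speaks WebSockets and
	// verifies each peer records why raw DERP was rejected.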
t.Run("ForcesWebSockets", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
w1IP := tailnet.IP()
derpMap := tailnettest.RunDERPOnlyWebSockets(t)
w1, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
Logger: logger.Named("w1"),
DERPMap: derpMap,
BlockEndpoints: true,
})
require.NoError(t, err)
w2, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w2"),
DERPMap: derpMap,
BlockEndpoints: true,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = w1.Close()
_ = w2.Close()
})
stitch(t, w2, w1)
stitch(t, w1, w2)
require.True(t, w2.AwaitReachable(ctx, w1IP))
conn := make(chan struct{}, 1)
go func() {
listener, err := w1.Listen("tcp", ":35565")
assert.NoError(t, err)
defer listener.Close()
nc, err := listener.Accept()
if !assert.NoError(t, err) {
return
}
_ = nc.Close()
conn <- struct{}{}
}()
nc, err := w2.DialContextTCP(ctx, netip.AddrPortFrom(w1IP, 35565))
require.NoError(t, err)
_ = nc.Close()
<-conn
nodes := make(chan *tailnet.Node, 1)
w2.SetNodeCallback(func(node *tailnet.Node) {
select {
case nodes <- node:
default:
}
})
node := <-nodes
require.Len(t, node.DERPForcedWebsocket, 1)
// Ensure the reason is valid!
require.Equal(t, `GET failed with status code 400 (a proxy could be disallowing the use of 'Upgrade: derp'): Invalid "Upgrade" header: DERP`, node.DERPForcedWebsocket[derpMap.RegionIDs()[0]])
w1.Close()
w2.Close()
})
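
	// PingDirect verifies that two peers eventually establish a direct
	// (non-DERP) path.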
t.Run("PingDirect", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitLong)
w1IP := tailnet.IP()
w1, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
Logger: logger.Named("w1"),
DERPMap: derpMap,
})
require.NoError(t, err)
w2, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w2"),
DERPMap: derpMap,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = w1.Close()
_ = w2.Close()
})
stitch(t, w2, w1)
stitch(t, w1, w2)
require.True(t, w2.AwaitReachable(context.Background(), w1IP))
require.Eventually(t, func() bool {
_, direct, pong, err := w2.Ping(ctx, w1IP)
if err != nil {
t.Logf("ping error: %s", err.Error())
return false
}
if !direct {
t.Logf("got pong: %+v", pong)
return false
}
return true
}, testutil.WaitShort, testutil.IntervalFast)
w1.Close()
w2.Close()
})
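
	// PingDERPOnly blocks direct endpoints on both peers and verifies pings
	// are relayed through the DERP region.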
t.Run("PingDERPOnly", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitLong)
w1IP := tailnet.IP()
w1, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
Logger: logger.Named("w1"),
DERPMap: derpMap,
BlockEndpoints: true,
})
require.NoError(t, err)
w2, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w2"),
DERPMap: derpMap,
BlockEndpoints: true,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = w1.Close()
_ = w2.Close()
})
stitch(t, w2, w1)
stitch(t, w1, w2)
require.True(t, w2.AwaitReachable(context.Background(), w1IP))
require.Eventually(t, func() bool {
_, direct, pong, err := w2.Ping(ctx, w1IP)
if err != nil {
t.Logf("ping error: %s", err.Error())
return false
}
if direct || pong.DERPRegionID != derpMap.RegionIDs()[0] {
t.Logf("got pong: %+v", pong)
return false
}
return true
}, testutil.WaitShort, testutil.IntervalFast)
w1.Close()
w2.Close()
})
}

// TestConn_PreferredDERP tests that we only trigger the NodeCallback when we
// have a preferred DERP server.
func TestConn_PreferredDERP(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		Logger:    logger.Named("w1"),
		DERPMap:   derpMap,
	})
	require.NoError(t, err)
	defer func() {
		err := conn.Close()
		require.NoError(t, err)
	}()
	// buffer channel so callback doesn't block
	nodes := make(chan *tailnet.Node, 50)
	conn.SetNodeCallback(func(node *tailnet.Node) {
		nodes <- node
	})
	select {
	case node := <-nodes:
		require.Equal(t, 1, node.PreferredDERP)
	case <-ctx.Done():
		t.Fatal("timed out waiting for node")
	}
}

// TestConn_UpdateDERP tests that when we update the DERP map, we pick a new
// preferred DERP server and new connections can be made from clients.
func TestConn_UpdateDERP(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	derpMap1, _ := tailnettest.RunDERPAndSTUN(t)
	ip := tailnet.IP()
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses:      []netip.Prefix{netip.PrefixFrom(ip, 128)},
		Logger:         logger.Named("w1"),
		DERPMap:        derpMap1,
		BlockEndpoints: true,
	})
	require.NoError(t, err)
	defer func() {
		err := conn.Close()
		assert.NoError(t, err)
	}()

	// Buffer channel so callback doesn't block
	nodes := make(chan *tailnet.Node, 50)
	conn.SetNodeCallback(func(node *tailnet.Node) {
		nodes <- node
	})
	ctx1, cancel1 := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel1()
	select {
	case node := <-nodes:
		require.Equal(t, 1, node.PreferredDERP)
	case <-ctx1.Done():
		t.Fatal("timed out waiting for node")
	}

	// Connect from a different client.
	client1, err := tailnet.NewConn(&tailnet.Options{
		Addresses:      []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		Logger:         logger.Named("client1"),
		DERPMap:        derpMap1,
		BlockEndpoints: true,
	})
	require.NoError(t, err)
	defer func() {
		err := client1.Close()
		assert.NoError(t, err)
	}()
	stitch(t, conn, client1)
	pn, err := tailnet.NodeToProto(conn.Node())
	require.NoError(t, err)
	connID := uuid.New()
	err = client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
		Id:   connID[:],
		Node: pn,
		Kind: proto.CoordinateResponse_PeerUpdate_NODE,
	}})
	require.NoError(t, err)
	awaitReachableCtx1, awaitReachableCancel1 := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer awaitReachableCancel1()
	require.True(t, client1.AwaitReachable(awaitReachableCtx1, ip))

	// Update the DERP map and wait for the preferred DERP server to change.
	derpMap2, _ := tailnettest.RunDERPAndSTUN(t)
	// Change the region ID.
	derpMap2.Regions[2] = derpMap2.Regions[1]
	delete(derpMap2.Regions, 1)
	derpMap2.Regions[2].RegionID = 2
	for _, node := range derpMap2.Regions[2].Nodes {
		node.RegionID = 2
	}
	conn.SetDERPMap(derpMap2)
	ctx2, cancel2 := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel2()
parentLoop:
	for {
		select {
		case node := <-nodes:
			if node.PreferredDERP != 2 {
				t.Logf("waiting for preferred DERP server to change, got %v", node.PreferredDERP)
				continue
			}
			t.Log("preferred DERP server changed!")
			break parentLoop
		case <-ctx2.Done():
			t.Fatal("timed out waiting for preferred DERP server to change")
		}
	}

	// Client1 should be dropped...
	awaitReachableCtx2, awaitReachableCancel2 := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer awaitReachableCancel2()
	require.False(t, client1.AwaitReachable(awaitReachableCtx2, ip))
	// ... unless the client updates its DERP map and nodes.
	client1.SetDERPMap(derpMap2)
	pn, err = tailnet.NodeToProto(conn.Node())
	require.NoError(t, err)
	err = client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
		Id:   connID[:],
		Node: pn,
		Kind: proto.CoordinateResponse_PeerUpdate_NODE,
	}})
	require.NoError(t, err)
	awaitReachableCtx3, awaitReachableCancel3 := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer awaitReachableCancel3()
	require.True(t, client1.AwaitReachable(awaitReachableCtx3, ip))

	// Connect from a different client with an up-to-date DERP map and nodes.
	client2, err := tailnet.NewConn(&tailnet.Options{
		Addresses:      []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		Logger:         logger.Named("client2"),
		DERPMap:        derpMap2,
		BlockEndpoints: true,
	})
	require.NoError(t, err)
	defer func() {
		err := client2.Close()
		assert.NoError(t, err)
	}()
	stitch(t, conn, client2)
	pn, err = tailnet.NodeToProto(conn.Node())
	require.NoError(t, err)
	err = client2.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
		Id:   connID[:],
		Node: pn,
		Kind: proto.CoordinateResponse_PeerUpdate_NODE,
	}})
	require.NoError(t, err)
	awaitReachableCtx4, awaitReachableCancel4 := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer awaitReachableCancel4()
	require.True(t, client2.AwaitReachable(awaitReachableCtx4, ip))
}
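
// TestConn_BlockEndpoints verifies that peers created with BlockEndpoints
// never exchange direct endpoints, so all traffic stays on DERP.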
func TestConn_BlockEndpoints(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)

	// Setup conn 1.
	ip1 := tailnet.IP()
	conn1, err := tailnet.NewConn(&tailnet.Options{
		Addresses:      []netip.Prefix{netip.PrefixFrom(ip1, 128)},
		Logger:         logger.Named("w1"),
		DERPMap:        derpMap,
		BlockEndpoints: true,
	})
	require.NoError(t, err)
	defer func() {
		err := conn1.Close()
		assert.NoError(t, err)
	}()

	// Setup conn 2.
	ip2 := tailnet.IP()
	conn2, err := tailnet.NewConn(&tailnet.Options{
		Addresses:      []netip.Prefix{netip.PrefixFrom(ip2, 128)},
		Logger:         logger.Named("w2"),
		DERPMap:        derpMap,
		BlockEndpoints: true,
	})
	require.NoError(t, err)
	defer func() {
		err := conn2.Close()
		assert.NoError(t, err)
	}()

	// Connect them together and wait for them to be reachable.
	stitch(t, conn2, conn1)
	stitch(t, conn1, conn2)
	awaitReachableCtx, awaitReachableCancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer awaitReachableCancel()
	require.True(t, conn1.AwaitReachable(awaitReachableCtx, ip2))

	// Wait 10s for endpoints to potentially be sent over Disco. There's no way
	// to force Disco to send endpoints immediately.
	time.Sleep(10 * time.Second)

	// Double check that both peers don't have endpoints for the other peer
	// according to magicsock.
	conn1Status, ok := conn1.Status().Peer[conn2.Node().Key]
	require.True(t, ok)
	require.Empty(t, conn1Status.Addrs)
	require.Empty(t, conn1Status.CurAddr)
	conn2Status, ok := conn2.Status().Peer[conn1.Node().Key]
	require.True(t, ok)
	require.Empty(t, conn2Status.Addrs)
	require.Empty(t, conn2Status.CurAddr)
}

// stitch sends node updates from src Conn as peer updates to dst Conn. Sort of
// like the Coordinator would, but without actually needing a Coordinator.
func stitch(t *testing.T, dst, src *tailnet.Conn) {
	srcID := uuid.New()
	src.SetNodeCallback(func(node *tailnet.Node) {
		pn, err := tailnet.NodeToProto(node)
		if !assert.NoError(t, err) {
			return
		}
		err = dst.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
			Id:   srcID[:],
			Node: pn,
			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
		}})
		assert.NoError(t, err)
	})
}