coder/tailnet/conn_test.go

169 lines
4.2 KiB
Go

package tailnet_test
import (
"context"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
"github.com/coder/coder/testutil"
)
// TestMain is the entry point for this package's tests. It hands the run to
// goleak.VerifyTestMain, which (per goleak's documentation) executes the tests
// and then fails the run if goroutines are still alive once they finish.
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
// TestTailnet exercises tailnet.Conn against the DERP maps provided by
// tailnettest:
//
//   - InstantClose: a freshly created Conn can be closed without error.
//   - Connect: two Conns exchange nodes, become reachable, complete a TCP
//     dial/accept, and report no region forced onto WebSockets.
//   - ForcesWebSockets: same flow against RunDERPOnlyWebSockets with endpoints
//     blocked, and the node reports the expected forced-WebSocket reason.
//
// Every blocking step in the connecting subtests is bounded by the context
// from testutil.Context so a networking failure fails the test instead of
// hanging it.
func TestTailnet(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	derpMap := tailnettest.RunDERPAndSTUN(t)

	// tcpRoundTrip listens on server port 35565, dials it from client, and
	// waits until the server side has accepted (and closed) the connection.
	// It takes the subtest's own t so failures are attributed correctly.
	tcpRoundTrip := func(ctx context.Context, t *testing.T, server, client *tailnet.Conn, serverIP netip.Addr) {
		t.Helper()
		// Listen synchronously, before dialing, so the dial cannot arrive
		// before the listener exists and so a Listen error stops the test
		// here instead of leading to a nil listener being used.
		listener, err := server.Listen("tcp", ":35565")
		require.NoError(t, err)

		accepted := make(chan struct{})
		go func() {
			// Closing (rather than sending) signals completion on both the
			// success and the failure path, so the waiter never blocks
			// forever on a failed Accept.
			defer close(accepted)
			nc, err := listener.Accept()
			if !assert.NoError(t, err) {
				return
			}
			_ = nc.Close()
		}()
		// Runs even when a require below aborts the test: closing the
		// listener unblocks a pending Accept, and waiting guarantees the
		// goroutine is done with t before the test returns.
		defer func() {
			_ = listener.Close()
			<-accepted
		}()

		nc, err := client.DialContextTCP(ctx, netip.AddrPortFrom(serverIP, 35565))
		require.NoError(t, err)
		_ = nc.Close()

		select {
		case <-accepted:
		case <-ctx.Done():
			t.Fatal("timed out waiting for the listener to accept the connection")
		}
	}

	// awaitNode replaces conn's node callback with one that captures the
	// first node it is given and returns that node, or fails the test when
	// ctx expires first.
	awaitNode := func(ctx context.Context, t *testing.T, conn *tailnet.Conn) *tailnet.Node {
		t.Helper()
		nodes := make(chan *tailnet.Node, 1)
		conn.SetNodeCallback(func(node *tailnet.Node) {
			// Non-blocking send: only the first node is needed, later
			// callbacks must not block the Conn.
			select {
			case nodes <- node:
			default:
			}
		})
		select {
		case node := <-nodes:
			return node
		case <-ctx.Done():
			t.Fatal("timed out waiting for a node update")
			return nil
		}
	}

	t.Run("InstantClose", func(t *testing.T) {
		t.Parallel()
		conn, err := tailnet.NewConn(&tailnet.Options{
			Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
			Logger:    logger.Named("w1"),
			DERPMap:   derpMap,
		})
		require.NoError(t, err)
		err = conn.Close()
		require.NoError(t, err)
	})

	t.Run("Connect", func(t *testing.T) {
		t.Parallel()
		ctx, cancelFunc := testutil.Context(t)
		defer cancelFunc()

		w1IP := tailnet.IP()
		w1, err := tailnet.NewConn(&tailnet.Options{
			Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
			Logger:    logger.Named("w1"),
			DERPMap:   derpMap,
		})
		require.NoError(t, err)
		// Register cleanup per Conn so w1 is still closed if creating w2 fails.
		t.Cleanup(func() {
			_ = w1.Close()
		})
		w2, err := tailnet.NewConn(&tailnet.Options{
			Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
			Logger:    logger.Named("w2"),
			DERPMap:   derpMap,
		})
		require.NoError(t, err)
		t.Cleanup(func() {
			_ = w2.Close()
		})

		// Feed each Conn's node updates to its peer.
		w1.SetNodeCallback(func(node *tailnet.Node) {
			err := w2.UpdateNodes([]*tailnet.Node{node}, false)
			assert.NoError(t, err)
		})
		w2.SetNodeCallback(func(node *tailnet.Node) {
			err := w1.UpdateNodes([]*tailnet.Node{node}, false)
			assert.NoError(t, err)
		})
		require.True(t, w2.AwaitReachable(ctx, w1IP))

		tcpRoundTrip(ctx, t, w1, w2, w1IP)

		node := awaitNode(ctx, t, w2)
		// No region should have been forced onto WebSockets here.
		require.Len(t, node.DERPForcedWebsocket, 0)

		// Close eagerly; the cleanups above remain as a backstop.
		_ = w1.Close()
		_ = w2.Close()
	})

	t.Run("ForcesWebSockets", func(t *testing.T) {
		t.Parallel()
		ctx, cancelFunc := testutil.Context(t)
		defer cancelFunc()

		w1IP := tailnet.IP()
		derpMap := tailnettest.RunDERPOnlyWebSockets(t)
		w1, err := tailnet.NewConn(&tailnet.Options{
			Addresses:      []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
			Logger:         logger.Named("w1"),
			DERPMap:        derpMap,
			BlockEndpoints: true,
		})
		require.NoError(t, err)
		// Register cleanup per Conn so w1 is still closed if creating w2 fails.
		t.Cleanup(func() {
			_ = w1.Close()
		})
		w2, err := tailnet.NewConn(&tailnet.Options{
			Addresses:      []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
			Logger:         logger.Named("w2"),
			DERPMap:        derpMap,
			BlockEndpoints: true,
		})
		require.NoError(t, err)
		t.Cleanup(func() {
			_ = w2.Close()
		})

		// Feed each Conn's node updates to its peer.
		w1.SetNodeCallback(func(node *tailnet.Node) {
			err := w2.UpdateNodes([]*tailnet.Node{node}, false)
			assert.NoError(t, err)
		})
		w2.SetNodeCallback(func(node *tailnet.Node) {
			err := w1.UpdateNodes([]*tailnet.Node{node}, false)
			assert.NoError(t, err)
		})
		require.True(t, w2.AwaitReachable(ctx, w1IP))

		tcpRoundTrip(ctx, t, w1, w2, w1IP)

		node := awaitNode(ctx, t, w2)
		require.Len(t, node.DERPForcedWebsocket, 1)
		// Ensure the reason is valid!
		require.Equal(t, "GET failed with status code 400: Invalid \"Upgrade\" header: DERP", node.DERPForcedWebsocket[derpMap.RegionIDs()[0]])

		// Close eagerly; the cleanups above remain as a backstop.
		_ = w1.Close()
		_ = w2.Close()
	})
}