coder/tailnet/tailnettest/tailnettest.go

351 lines
8.4 KiB
Go

package tailnettest
import (
"context"
"crypto/tls"
"fmt"
"html"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/net/stun/stuntest"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
tslogger "tailscale.com/types/logger"
"tailscale.com/types/nettype"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
//go:generate mockgen -destination ./coordinateemock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinatee
// derpAndSTUNCfg collects the settings that DERPAndStunOption functions
// toggle before RunDERPAndSTUN starts its servers.
type derpAndSTUNCfg struct {
// DisableSTUN, when true, makes RunDERPAndSTUN skip starting a STUN listener.
DisableSTUN bool
// DERPIsEmbedded is copied into the EmbeddedRelay field of the returned region.
DERPIsEmbedded bool
}
// DERPAndStunOption mutates the configuration used by RunDERPAndSTUN.
// Options are applied in the order they are passed.
type DERPAndStunOption func(cfg *derpAndSTUNCfg)
// DisableSTUN is a DERPAndStunOption that sets cfg.DisableSTUN, so that
// RunDERPAndSTUN does not start a STUN listener.
func DisableSTUN(cfg *derpAndSTUNCfg) {
cfg.DisableSTUN = true
}
// DERPIsEmbedded is a DERPAndStunOption that sets cfg.DERPIsEmbedded, which
// RunDERPAndSTUN copies into the region's EmbeddedRelay field.
func DERPIsEmbedded(cfg *derpAndSTUNCfg) {
cfg.DERPIsEmbedded = true
}
// RunDERPAndSTUN creates a DERP mapping for tests.
//
// It starts a TLS httptest server that serves a fresh DERP server and, unless
// the DisableSTUN option is given, a STUN listener. Both are torn down through
// t.Cleanup. The returned map contains a single region (ID 1) with one node
// ("t2") on 127.0.0.1; its STUNPort is -1 when STUN is disabled. The DERP
// server itself is returned as the second value.
func RunDERPAndSTUN(t *testing.T, opts ...DERPAndStunOption) (*tailcfg.DERPMap, *derp.Server) {
	t.Helper()

	cfg := new(derpAndSTUNCfg)
	for _, o := range opts {
		o(cfg)
	}

	logf := tailnet.Logger(slogtest.Make(t, nil))
	d := derp.NewServer(key.NewNode(), logf)
	server := httptest.NewUnstartedServer(derphttp.Handler(d))
	server.Config.ErrorLog = tslogger.StdLogger(logf)
	// A non-nil (here: empty) TLSNextProto map keeps net/http from enabling
	// HTTP/2 automatically on this server.
	server.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
	server.StartTLS()
	t.Cleanup(func() {
		server.CloseClientConnections()
		server.Close()
		d.Close()
	})

	stunPort := -1
	if !cfg.DisableSTUN {
		stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{})
		t.Cleanup(func() {
			stunCleanup()
		})
		stunPort = stunAddr.Port
	}

	tcpAddr, ok := server.Listener.Addr().(*net.TCPAddr)
	if !ok {
		// Fail with a diagnostic instead of a bare FailNow so the cause is
		// visible in the test output.
		t.Fatalf("DERP test server listener address is %T, want *net.TCPAddr", server.Listener.Addr())
	}

	return &tailcfg.DERPMap{
		Regions: map[int]*tailcfg.DERPRegion{
			1: {
				RegionID:   1,
				RegionCode: "test",
				RegionName: "Test",
				Nodes: []*tailcfg.DERPNode{
					{
						Name:             "t2",
						RegionID:         1,
						IPv4:             "127.0.0.1",
						IPv6:             "none",
						STUNPort:         stunPort,
						DERPPort:         tcpAddr.Port,
						InsecureForTests: true,
					},
				},
				EmbeddedRelay: cfg.DERPIsEmbedded,
			},
		},
	}, d
}
// RunDERPOnlyWebSockets creates a DERP mapping for tests that
// only allows WebSockets through it. Many proxies do not support
// upgrading DERP, so this is a good fallback.
//
// Requests to any path other than "/derp" get a 200 "hello" response.
// Requests to "/derp" whose Upgrade header is not exactly "websocket" are
// rejected with 400; everything else is passed to the DERP handler wrapped
// by tailnet.WithWebsocketSupport. The server and DERP instance are torn
// down through t.Cleanup. The returned map has a single region (ID 1) with
// one node ("t1") on 127.0.0.1 and no STUN port configured.
func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
	t.Helper()

	logf := tailnet.Logger(slogtest.Make(t, nil))
	d := derp.NewServer(key.NewNode(), logf)
	handler := derphttp.Handler(d)
	var closeFunc func()
	handler, closeFunc = tailnet.WithWebsocketSupport(d, handler)

	server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/derp" {
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("hello"))
			return
		}
		if r.Header.Get("Upgrade") != "websocket" {
			w.WriteHeader(http.StatusBadRequest)
			// The header value is client-controlled, so escape it before
			// echoing it back.
			_, _ = fmt.Fprintf(w, `Invalid "Upgrade" header: %s`, html.EscapeString(r.Header.Get("Upgrade")))
			return
		}
		handler.ServeHTTP(w, r)
	}))
	server.Config.ErrorLog = tslogger.StdLogger(logf)
	// A non-nil (here: empty) TLSNextProto map keeps net/http from enabling
	// HTTP/2 automatically on this server.
	server.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
	server.StartTLS()
	t.Cleanup(func() {
		server.CloseClientConnections()
		server.Close()
		closeFunc()
		d.Close()
	})

	tcpAddr, ok := server.Listener.Addr().(*net.TCPAddr)
	if !ok {
		// Fail with a diagnostic instead of a bare FailNow so the cause is
		// visible in the test output.
		t.Fatalf("DERP test server listener address is %T, want *net.TCPAddr", server.Listener.Addr())
	}

	return &tailcfg.DERPMap{
		Regions: map[int]*tailcfg.DERPRegion{
			1: {
				RegionID:   1,
				RegionCode: "test",
				RegionName: "Test",
				Nodes: []*tailcfg.DERPNode{
					{
						Name:             "t1",
						RegionID:         1,
						IPv4:             "127.0.0.1",
						IPv6:             "none",
						DERPPort:         tcpAddr.Port,
						InsecureForTests: true,
					},
				},
			},
		},
	}
}
// TestMultiAgent wraps a tailnet.MultiAgentConn with assertion helpers that
// fail the owning test on error. Create one with NewTestMultiAgent.
type TestMultiAgent struct {
// t is the test that assertion failures are reported against.
t testing.TB
// id is the identifier passed to Coordinator.ServeMultiAgent.
id uuid.UUID
// a is the connection returned by Coordinator.ServeMultiAgent.
a tailnet.MultiAgentConn
// nodeKey is the binary-marshaled public node key sent in SendNodeWithDERP.
nodeKey []byte
// discoKey is the text-marshaled public disco key sent in SendNodeWithDERP.
discoKey string
}
// NewTestMultiAgent generates fresh node and disco keys plus a random ID,
// then registers a multi-agent connection with coord under that ID. Any
// key-marshaling error fails t immediately.
func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent {
	nodePub, err := key.NewNode().Public().MarshalBinary()
	require.NoError(t, err)

	discoPub, err := key.NewDisco().Public().MarshalText()
	require.NoError(t, err)

	agent := &TestMultiAgent{
		t:        t,
		id:       uuid.New(),
		nodeKey:  nodePub,
		discoKey: string(discoPub),
	}
	agent.a = coord.ServeMultiAgent(agent.id)
	return agent
}
// SendNodeWithDERP publishes this agent's own node, carrying its node key,
// disco key and d as the preferred DERP, and fails the test if the update
// is rejected.
func (m *TestMultiAgent) SendNodeWithDERP(d int32) {
	m.t.Helper()
	self := &proto.Node{
		Key:           m.nodeKey,
		Disco:         m.discoKey,
		PreferredDerp: d,
	}
	require.NoError(m.t, m.a.UpdateSelf(self))
}
// Close closes the underlying multi-agent connection and fails the test if
// closing returns an error.
func (m *TestMultiAgent) Close() {
	m.t.Helper()
	require.NoError(m.t, m.a.Close())
}
// RequireSubscribeAgent subscribes the connection to the agent with the
// given id and fails the test if the subscription returns an error.
func (m *TestMultiAgent) RequireSubscribeAgent(id uuid.UUID) {
	m.t.Helper()
	require.NoError(m.t, m.a.SubscribeAgent(id))
}
// RequireUnsubscribeAgent removes the subscription to the agent with the
// given id and fails the test if unsubscribing returns an error.
func (m *TestMultiAgent) RequireUnsubscribeAgent(id uuid.UUID) {
	m.t.Helper()
	require.NoError(m.t, m.a.UnsubscribeAgent(id))
}
// RequireEventuallyHasDERPs reads updates from the connection until one
// carries exactly len(expected) nodes whose preferred DERP regions include
// every value in expected. Non-matching updates are logged and skipped.
//
// The test fails if NextUpdate reports that no further update is available
// before a match is seen, or if an update cannot be reduced to node updates.
func (m *TestMultiAgent) RequireEventuallyHasDERPs(ctx context.Context, expected ...int) {
	m.t.Helper()
	for {
		resp, ok := m.a.NextUpdate(ctx)
		require.True(m.t, ok)
		nodes, err := tailnet.OnlyNodeUpdates(resp)
		require.NoError(m.t, err)
		if len(nodes) != len(expected) {
			m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
			continue
		}

		derps := make([]int, 0, len(nodes))
		for _, n := range nodes {
			derps = append(derps, n.PreferredDERP)
		}

		// Every expected DERP must be present. Returning from inside this
		// loop would accept an update as soon as a single one matched.
		allPresent := true
		for _, e := range expected {
			if !slices.Contains(derps, e) {
				m.t.Logf("expected DERP %d to be in %v", e, derps)
				allPresent = false
			}
		}
		if allPresent {
			return
		}
	}
}
// RequireNeverHasDERPs is the negation of RequireEventuallyHasDERPs: it reads
// updates until NextUpdate reports that none is available (presumably once
// ctx is done or the connection is closed — confirm against MultiAgentConn)
// and fails the test if any update carries exactly len(expected) nodes whose
// preferred DERP regions include every value in expected.
//
// Callers should pass a ctx with a deadline so the wait is bounded.
func (m *TestMultiAgent) RequireNeverHasDERPs(ctx context.Context, expected ...int) {
	m.t.Helper()
	for {
		resp, ok := m.a.NextUpdate(ctx)
		if !ok {
			// No more updates: the forbidden set was never observed.
			return
		}
		nodes, err := tailnet.OnlyNodeUpdates(resp)
		require.NoError(m.t, err)
		if len(nodes) != len(expected) {
			m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
			continue
		}

		derps := make([]int, 0, len(nodes))
		for _, n := range nodes {
			derps = append(derps, n.PreferredDERP)
		}

		matched := true
		for _, e := range expected {
			if !slices.Contains(derps, e) {
				m.t.Logf("DERP %d is not in %v, update does not match", e, derps)
				matched = false
			}
		}
		if matched {
			// Previously this case returned silently, so the assertion could
			// never fail on the condition it is named for.
			m.t.Fatalf("received update with DERPs %v, which must never be seen (forbidden: %v)", derps, expected)
			return // unhittable
		}
	}
}
// RequireEventuallyClosed polls the connection every testutil.IntervalFast
// and returns once it reports closed. If ctx is done first, the test fails
// with "timeout".
func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) {
	m.t.Helper()

	poll := time.NewTicker(testutil.IntervalFast)
	defer poll.Stop()

	for {
		select {
		case <-poll.C:
			if !m.a.IsClosed() {
				continue
			}
			return
		case <-ctx.Done():
			m.t.Fatal("timeout")
			return // unhittable
		}
	}
}
// FakeCoordinator is a test double that records calls instead of doing any
// coordination. Only Coordinate and ServeClient are implemented; the other
// methods panic. Create one with NewFakeCoordinator so the channels are
// initialized.
type FakeCoordinator struct {
// CoordinateCalls receives one entry per call to Coordinate.
CoordinateCalls chan *FakeCoordinate
// ServeClientCalls receives one entry per call to ServeClient.
ServeClientCalls chan *FakeServeClient
}
// ServeHTTPDebug is not implemented by the fake; calling it panics.
func (*FakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) {
panic("unimplemented")
}
// Node is not implemented by the fake; calling it panics.
func (*FakeCoordinator) Node(uuid.UUID) *tailnet.Node {
panic("unimplemented")
}
// ServeClient records the call on ServeClientCalls and then blocks until the
// test sends a value on the recorded call's ErrCh; that value is returned.
func (f *FakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
	result := make(chan error)
	call := &FakeServeClient{
		Conn:  conn,
		ID:    id,
		Agent: agent,
		ErrCh: result,
	}
	f.ServeClientCalls <- call

	err := <-result
	return err
}
// ServeAgent is not implemented by the fake; calling it panics.
func (*FakeCoordinator) ServeAgent(net.Conn, uuid.UUID, string) error {
panic("unimplemented")
}
// Close is not implemented by the fake; calling it panics.
func (*FakeCoordinator) Close() error {
panic("unimplemented")
}
// ServeMultiAgent is not implemented by the fake; calling it panics.
func (*FakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn {
panic("unimplemented")
}
// Coordinate creates a buffered request channel and a buffered response
// channel (capacity 100 each), records them together with the call arguments
// on CoordinateCalls, and hands the channels back to the caller. The test
// drives the exchange through the recorded FakeCoordinate.
func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
	const streamBuffer = 100
	requests := make(chan *proto.CoordinateRequest, streamBuffer)
	responses := make(chan *proto.CoordinateResponse, streamBuffer)

	call := &FakeCoordinate{
		Ctx:   ctx,
		ID:    id,
		Name:  name,
		Auth:  a,
		Reqs:  requests,
		Resps: responses,
	}
	f.CoordinateCalls <- call

	return requests, responses
}
// NewFakeCoordinator returns a FakeCoordinator whose call-recording channels
// each buffer up to 100 entries.
func NewFakeCoordinator() *FakeCoordinator {
	const callBuffer = 100

	fake := new(FakeCoordinator)
	fake.CoordinateCalls = make(chan *FakeCoordinate, callBuffer)
	fake.ServeClientCalls = make(chan *FakeServeClient, callBuffer)
	return fake
}
// FakeCoordinate records one call to FakeCoordinator.Coordinate.
type FakeCoordinate struct {
// Ctx, ID, Name and Auth are the arguments the caller passed to Coordinate.
Ctx context.Context
ID uuid.UUID
Name string
Auth tailnet.CoordinateeAuth
// Reqs is the channel returned to the caller as its send-only request channel.
Reqs chan *proto.CoordinateRequest
// Resps is the channel returned to the caller as its receive-only response channel.
Resps chan *proto.CoordinateResponse
}
// FakeServeClient records one call to FakeCoordinator.ServeClient.
type FakeServeClient struct {
// Conn, ID and Agent are the arguments the caller passed to ServeClient.
Conn net.Conn
ID uuid.UUID
Agent uuid.UUID
// ErrCh is unbuffered; ServeClient blocks until a value is sent on it and
// returns that value to its caller.
ErrCh chan error
}