
565 lines
18 KiB

package tailnet_test
import (
// TestCoordinator drives a coordinator built by tailnet.NewCoordinator through
// its ServeClient and ServeAgent entry points in four subtests.
// NOTE(review): this copy of the file has lost its indentation and appears to
// be missing lines (closing braces and probably some short statements);
// confirm against the repository before changing any code here.
func TestCoordinator(t *testing.T) {
// ClientWithoutAgent: a client connects and names a freshly generated agent ID
// that nothing in this subtest serves.
t.Run("ClientWithoutAgent", func(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
// Close the coordinator when the subtest ends and fail on a close error.
defer func() {
err := coordinator.Close()
require.NoError(t, err)
// In-memory pipe: `client` stays with the test, `server` goes to the coordinator.
client, server := net.Pipe()
// The callback discards whatever node lists are delivered to this client.
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
id := uuid.New()
closeChan := make(chan struct{})
// Serve the coordinator side of the pipe on its own goroutine.
// NOTE(review): no close(closeChan) and no use of sendNode is visible in this
// copy, yet closeChan is received from below — presumably dropped lines; confirm.
go func() {
err := coordinator.ServeClient(server, id, uuid.New())
assert.NoError(t, err)
// Poll until the coordinator reports a node for the client ID.
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
// Tear down both pipe ends, then wait on the error and close channels.
require.NoError(t, client.Close())
require.NoError(t, server.Close())
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
// AgentWithoutClients: an agent connects while no client is tunnelled to it.
t.Run("AgentWithoutClients", func(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
// Close the coordinator when the subtest ends and fail on a close error.
defer func() {
err := coordinator.Close()
require.NoError(t, err)
client, server := net.Pipe()
// The callback discards whatever node lists are delivered to this agent.
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
id := uuid.New()
closeChan := make(chan struct{})
// Serve the agent on its own goroutine; the agent name is left empty.
// NOTE(review): as above, close(closeChan) and any sendNode call are not
// visible in this copy; confirm.
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
// Poll until the coordinator reports a node for the agent ID.
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
// Only the test's end of the pipe is closed here before waiting on the channels.
err := client.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
// AgentWithClient: an agent and a client exchange nodes, then the agent
// disconnects and a replacement connection is made under the same agent ID.
t.Run("AgentWithClient", func(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
// Close the coordinator when the subtest ends and fail on a close error.
defer func() {
err := coordinator.Close()
require.NoError(t, err)
// in this test we use real websockets to test use of deadlines
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
agentWS, agentServerWS := websocketConn(ctx, t)
defer agentWS.Close()
// Unbuffered: the callback blocks until the test receives the node list.
agentNodeChan := make(chan []*tailnet.Node)
sendAgentNode, agentErrChan := tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error {
agentNodeChan <- nodes
return nil
agentID := uuid.New()
closeAgentChan := make(chan struct{})
// NOTE(review): close(closeAgentChan) is not visible in this copy; confirm.
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID, "")
assert.NoError(t, err)
// Publish the agent's node and wait for the coordinator to register it.
sendAgentNode(&tailnet.Node{PreferredDERP: 1})
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
clientWS, clientServerWS := websocketConn(ctx, t)
defer clientWS.Close()
defer clientServerWS.Close()
// Unbuffered: the callback blocks until the test receives the node list.
clientNodeChan := make(chan []*tailnet.Node)
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
clientNodeChan <- nodes
return nil
clientID := uuid.New()
closeClientChan := make(chan struct{})
// NOTE(review): close(closeClientChan) is not visible in this copy; confirm.
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
// The client should be handed exactly one node: the agent's.
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// The client's own node should in turn be delivered to the agent.
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
// wait longer than the internal wait timeout (1.5x tailnet.WriteTimeout).
// this guards against a regression; the issue reference that followed
// this comment is missing from this copy.
time.Sleep(tailnet.WriteTimeout * 3 / 2)
// Ensure an update to the agent node reaches the client!
sendAgentNode(&tailnet.Node{PreferredDERP: 3})
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// Close the agent WebSocket so a new one can connect.
err := agentWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
// Create a new agent connection. This is to simulate a reconnect!
// Note the replacement connection is a net.Pipe, not a websocket.
agentWS, agentServerWS = net.Pipe()
defer agentWS.Close()
agentNodeChan = make(chan []*tailnet.Node)
// The send function is discarded: the reconnected agent never publishes a node.
_, agentErrChan = tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error {
agentNodeChan <- nodes
return nil
closeAgentChan = make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID, "")
assert.NoError(t, err)
// Ensure the existing listening client sends its node immediately!
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
// Shut down the agent first, then the client, waiting on each pair of channels.
err = agentWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
err = clientWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
// AgentDoubleConnect: a second connection arrives for the same agent ID while
// the first is still open.
t.Run("AgentDoubleConnect", func(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitLong)
agentWS1, agentServerWS1 := net.Pipe()
defer agentWS1.Close()
// Unbuffered: the callback blocks until the test receives the node list.
agentNodeChan1 := make(chan []*tailnet.Node)
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
t.Logf("agent1 got node update: %v", nodes)
agentNodeChan1 <- nodes
return nil
agentID := uuid.New()
closeAgentChan1 := make(chan struct{})
// NOTE(review): close(closeAgentChan1) is not visible in this copy; confirm.
go func() {
err := coordinator.ServeAgent(agentServerWS1, agentID, "")
assert.NoError(t, err)
// Publish the first agent's node and wait for the coordinator to register it.
sendAgentNode1(&tailnet.Node{PreferredDERP: 1})
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
clientWS, clientServerWS := net.Pipe()
defer clientWS.Close()
defer clientServerWS.Close()
clientNodeChan := make(chan []*tailnet.Node)
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
t.Logf("client got node update: %v", nodes)
clientNodeChan <- nodes
return nil
clientID := uuid.New()
closeClientChan := make(chan struct{})
// NOTE(review): close(closeClientChan) is not visible in this copy; confirm.
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
// The client should be handed exactly one node: the agent's.
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// The client's own node should in turn be delivered to the first agent connection.
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan1)
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode1(&tailnet.Node{PreferredDERP: 3})
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// Create a new agent connection without disconnecting the old one.
agentWS2, agentServerWS2 := net.Pipe()
defer agentWS2.Close()
agentNodeChan2 := make(chan []*tailnet.Node)
// The send function is discarded: the second connection never publishes a node.
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
t.Logf("agent2 got node update: %v", nodes)
agentNodeChan2 <- nodes
return nil
closeAgentChan2 := make(chan struct{})
// NOTE(review): close(closeAgentChan2) is not visible in this copy; confirm.
go func() {
err := coordinator.ServeAgent(agentServerWS2, agentID, "")
assert.NoError(t, err)
// Ensure the existing listening client sends its node immediately!
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan2)
require.Len(t, clientNodes, 1)
// This original agent websocket should've been closed forcefully.
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan1)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan1)
// Shut down the second agent connection, then the client.
err := agentWS2.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan2)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan2)
err = clientWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
// TestCoordinator_AgentUpdateWhileClientConnects is a regression test: an agent
// update sent while the coordinator is still part-way through writing the
// initial node to a newly connected client must still reach that client.
// It talks to the coordinator by writing and reading raw JSON on net.Pipe
// connections. The issue reference that originally ended this comment is
// missing from this copy.
func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
agentWS, agentServerWS := net.Pipe()
defer agentWS.Close()
agentID := uuid.New()
// Serve the agent on its own goroutine; the agent name is left empty.
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID, "")
assert.NoError(t, err)
// send an agent update before the client connects so that there is
// node data available to send right away.
aNode := tailnet.Node{PreferredDERP: 0}
aData, err := json.Marshal(&aNode)
require.NoError(t, err)
// Bound the write so a stalled coordinator fails the test instead of hanging it.
err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
_, err = agentWS.Write(aData)
require.NoError(t, err)
// Poll until the coordinator reports a node for the agent ID.
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
// Connect from the client
clientWS, clientServerWS := net.Pipe()
defer clientWS.Close()
clientID := uuid.New()
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
// peek one byte from the node update, so we know the coordinator is
// trying to write to the client.
// buffer needs to be 2 characters longer because return value is a list
// so, it needs [ and ]
buf := make([]byte, len(aData)+2)
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
n, err := clientWS.Read(buf[:1])
require.NoError(t, err)
require.Equal(t, 1, n)
// send a second update
aNode.PreferredDERP = 1
// NOTE(review): err has not been reassigned since the check after the Read
// above, so this assertion looks redundant.
require.NoError(t, err)
aData, err = json.Marshal(&aNode)
require.NoError(t, err)
err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
_, err = agentWS.Write(aData)
require.NoError(t, err)
// read the rest of the update from the client, should be initial node.
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
n, err = clientWS.Read(buf[1:])
require.NoError(t, err)
var cNodes []*tailnet.Node
// n counts only the bytes read into buf[1:], so add back the peeked byte.
err = json.Unmarshal(buf[:n+1], &cNodes)
require.NoError(t, err)
require.Len(t, cNodes, 1)
require.Equal(t, 0, cNodes[0].PreferredDERP)
// read second update
// without the fix this test guards (its reference is missing from this
// copy), the read would time out here.
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
n, err = clientWS.Read(buf)
require.NoError(t, err)
err = json.Unmarshal(buf[:n], &cNodes)
require.NoError(t, err)
require.Len(t, cNodes, 1)
require.Equal(t, 1, cNodes[0].PreferredDERP)
// TestCoordinator_BidirectionalTunnels runs the shared
// test.BidirectionalTunnels scenario against a coordinator built by
// tailnet.NewCoordinator, bounded by a testutil.WaitShort context.
func TestCoordinator_BidirectionalTunnels(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.BidirectionalTunnels(ctx, t, coordinator)
// TestCoordinator_GracefulDisconnect runs the shared
// test.GracefulDisconnectTest scenario against a coordinator built by
// tailnet.NewCoordinator, bounded by a testutil.WaitShort context.
func TestCoordinator_GracefulDisconnect(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.GracefulDisconnectTest(ctx, t, coordinator)
// TestCoordinator_Lost runs the shared test.LostTest scenario against a
// coordinator built by tailnet.NewCoordinator, bounded by a testutil.WaitShort
// context.
func TestCoordinator_Lost(t *testing.T) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.LostTest(ctx, t, coordinator)
// websocketConn starts an httptest server that accepts one websocket, dials it,
// and returns both ends wrapped as net.Conn values carrying binary messages.
// The client end is tied to ctx; the server end is tied to the HTTP request's
// context.
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
// Buffered by one so the handler can hand over the server conn without blocking.
sc := make(chan net.Conn, 1)
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
wss, err := websocket.Accept(rw, r, nil)
require.NoError(t, err)
conn := websocket.NetConn(r.Context(), wss, websocket.MessageBinary)
sc <- conn
close(sc) // there can be only one
// hold open until context canceled
// NOTE(review): the statement that does the holding is not visible in this
// copy; confirm against the repository.
// nolint: bodyclose
wsc, _, err := websocket.Dial(ctx, s.URL, nil)
require.NoError(t, err)
client = websocket.NetConn(ctx, wsc, websocket.MessageBinary)
// ok is false only if the handler closed sc without sending a conn first.
server, ok := <-sc
require.True(t, ok)
return client, server
// TestInMemoryCoordination runs coordinationTest against
// tailnet.NewInMemoryCoordination, backed by a mock coordinator whose
// Coordinate call returns the test's own request and response channels.
func TestInMemoryCoordination(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
// Buffered channels so the test and the coordination under test can exchange
// messages without having to rendezvous on every send.
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
// Expect exactly one Coordinate call for clientID, authorised as a client
// tunnel to agentID.
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
Times(1).Return(reqs, resps)
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
defer uut.Close()
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
// Any error reported by the coordination must be nil.
// NOTE(review): the "OK!" comment suggests a non-blocking arm (e.g. default)
// that is not visible in this copy; confirm.
select {
case err := <-uut.Error():
require.NoError(t, err)
// OK!
// TestRemoteCoordination runs coordinationTest against
// tailnet.NewRemoteCoordination. The coordination talks over a net.Pipe, via a
// DRPC client, to a tailnet client service in front of a mock coordinator.
func TestRemoteCoordination(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
// Expect exactly one Coordinate call for clientID, authorised as a client
// tunnel to agentID.
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
Times(1).Return(reqs, resps)
var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
// NOTE(review): nothing visible in this copy stores coord into coordPtr
// before the service is built; presumably a dropped line — confirm.
svc, err := tailnet.NewClientService(
logger.Named("svc"), &coordPtr,
// The DERP map getter panics if called; this test does not expect it to be used.
func() *tailcfg.DERPMap { panic("not implemented") },
require.NoError(t, err)
// sC is served by the client service; cC is used by the DRPC client below.
sC, cC := net.Pipe()
// Buffered by one so the serving goroutine can report its result and exit
// even if the result is never read.
serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
serveErr <- err
client, err := tailnet.NewDRPCClient(cC)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
defer uut.Close()
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
// Unlike the in-memory case, an error is accepted here as long as it is the
// stream-closed error produced by closing the coordination.
// NOTE(review): the "OK!" comment suggests a non-blocking arm (e.g. default)
// that is not visible in this copy; confirm.
select {
case err := <-uut.Error():
require.ErrorContains(t, err, "stream terminated by sending close")
// OK!
// coordinationTest tests that a coordination behaves correctly.
//
// uut is the coordination under test and fConn is the fake coordinatee it was
// built with. reqs carries the requests the coordination sends to the
// coordinator, and resps is where the test injects coordinator responses.
// agentID is the peer the coordination is expected to tunnel to.
func coordinationTest(
ctx context.Context, t *testing.T,
uut tailnet.Coordination, fConn *fakeCoordinatee,
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
agentID uuid.UUID,
) {
// It should add the tunnel, since we configured as a client
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
// when we call the callback, it should send a node update
require.NotNil(t, fConn.callback)
fConn.callback(&tailnet.Node{PreferredDERP: 1})
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
// When we send a peer update, it should update the coordinatee
// Generate fresh node and disco public keys to fill in the peer's proto.Node.
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
// NOTE(review): the braces around the single PeerUpdate element, and the
// closing braces of this literal, are not visible in this copy.
updates := []*proto.CoordinateResponse_PeerUpdate{
Id: agentID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{
Id: 2,
Key: nk,
Disco: string(dk),
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
// Poll until the fake coordinatee has recorded at least one batch of updates.
// NOTE(review): no fConn.Lock() is visible before this deferred Unlock;
// presumably a dropped line — confirm.
require.Eventually(t, func() bool {
defer fConn.Unlock()
return len(fConn.updates) > 0
}, testutil.WaitShort, testutil.IntervalFast)
// The first recorded batch should hold exactly the one update for agentID.
require.Len(t, fConn.updates[0], 1)
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
err = uut.Close()
require.NoError(t, err)
// When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
// It should set all peers lost on the coordinatee
require.Equal(t, 1, fConn.setAllPeersLostCalls)
// fakeCoordinatee is a test double for the coordinatee handed to the
// coordinations under test. It records what it is given so tests can assert
// on it.
// NOTE(review): the methods call f.Unlock(), so a mutex is presumably embedded
// here, but no such field is visible in this copy; confirm.
type fakeCoordinatee struct {
// callback is the node callback stored by SetNodeCallback.
callback func(*tailnet.Node)
// updates holds one entry per UpdatePeers call, in call order.
updates [][]*proto.CoordinateResponse_PeerUpdate
// setAllPeersLostCalls is asserted on by coordinationTest.
setAllPeersLostCalls int
// UpdatePeers records the given batch of peer updates as a new entry in
// f.updates and always returns nil.
// NOTE(review): no f.Lock() is visible before the deferred Unlock; presumably
// a dropped line — confirm.
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
defer f.Unlock()
f.updates = append(f.updates, updates)
return nil
// SetAllPeersLost is the fake's hook for the coordination marking every peer
// lost.
// NOTE(review): coordinationTest asserts setAllPeersLostCalls == 1, but neither
// an increment nor an f.Lock() is visible in this copy; presumably dropped
// lines — confirm.
func (f *fakeCoordinatee) SetAllPeersLost() {
defer f.Unlock()
// SetNodeCallback stores callback in f.callback so the test can later invoke
// it directly.
// NOTE(review): no f.Lock() is visible before the deferred Unlock; presumably
// a dropped line — confirm.
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
defer f.Unlock()
f.callback = callback