From c3731a1be0bdf3a8b2527186c32a04b44973fdfe Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Mon, 23 Jan 2023 17:22:34 -0600 Subject: [PATCH] fix: ensure agent websocket only removes its own conn (#5828) --- tailnet/coordinator.go | 59 ++++++++++++++++++----- tailnet/coordinator_test.go | 94 +++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 13 deletions(-) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index dbd70ead1a..7c3f48c9ea 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -104,7 +104,7 @@ func NewCoordinator() Coordinator { return &coordinator{ closed: false, nodes: map[uuid.UUID]*Node{}, - agentSockets: map[uuid.UUID]net.Conn{}, + agentSockets: map[uuid.UUID]idConn{}, agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{}, } } @@ -123,12 +123,19 @@ type coordinator struct { // nodes maps agent and connection IDs their respective node. nodes map[uuid.UUID]*Node // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]net.Conn + agentSockets map[uuid.UUID]idConn // agentToConnectionSockets maps agent IDs to connection IDs of conns that // are subscribed to updates for that agent. agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn } +type idConn struct { + // id is an ephemeral UUID used to uniquely identify the owner of the + // connection. + id uuid.UUID + conn net.Conn +} + // Node returns an in-memory node by ID. // If the node does not exist, nil is returned. func (c *coordinator) Node(id uuid.UUID) *Node { @@ -137,6 +144,18 @@ func (c *coordinator) Node(id uuid.UUID) *Node { return c.nodes[id] } +func (c *coordinator) NodeCount() int { + c.mutex.Lock() + defer c.mutex.Unlock() + return len(c.nodes) +} + +func (c *coordinator) AgentCount() int { + c.mutex.Lock() + defer c.mutex.Unlock() + return len(c.agentSockets) +} + // ServeClient accepts a WebSocket connection that wants to connect to an agent // with the specified ID. 
func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { @@ -224,9 +243,9 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json return xerrors.Errorf("marshal nodes: %w", err) } - _, err = agentSocket.Write(data) + _, err = agentSocket.conn.Write(data) if err != nil { - if errors.Is(err, io.EOF) { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { return nil } return xerrors.Errorf("write json: %w", err) @@ -268,27 +287,41 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error { c.mutex.Lock() } - // If an old agent socket is connected, we close it - // to avoid any leaks. This shouldn't ever occur because - // we expect one agent to be running. + // This uniquely identifies a connection that belongs to this goroutine. + unique := uuid.New() + + // If an old agent socket is connected, we close it to avoid any leaks. This + // shouldn't ever occur because we expect one agent to be running, but it's + // possible for a race condition to happen when an agent is disconnected and + // attempts to reconnect before the server realizes the old connection is + // dead. oldAgentSocket, ok := c.agentSockets[id] if ok { - _ = oldAgentSocket.Close() + _ = oldAgentSocket.conn.Close() } - c.agentSockets[id] = conn + c.agentSockets[id] = idConn{ + id: unique, + conn: conn, + } + c.mutex.Unlock() defer func() { c.mutex.Lock() defer c.mutex.Unlock() - delete(c.agentSockets, id) - delete(c.nodes, id) + + // Only delete the connection if it's ours. It could have been + // overwritten. 
+ if idConn := c.agentSockets[id]; idConn.id == unique { + delete(c.agentSockets, id) + delete(c.nodes, id) + } }() decoder := json.NewDecoder(conn) for { err := c.handleNextAgentMessage(id, decoder) if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { return nil } return xerrors.Errorf("handle next agent message: %w", err) @@ -349,7 +382,7 @@ func (c *coordinator) Close() error { for _, socket := range c.agentSockets { socket := socket go func() { - _ = socket.Close() + _ = socket.conn.Close() wg.Done() }() } diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index a4a020dead..60d909f715 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -145,4 +145,98 @@ func TestCoordinator(t *testing.T) { <-clientErrChan <-closeClientChan }) + + t.Run("AgentDoubleConnect", func(t *testing.T) { + t.Parallel() + coordinator := tailnet.NewCoordinator() + + agentWS1, agentServerWS1 := net.Pipe() + defer agentWS1.Close() + agentNodeChan1 := make(chan []*tailnet.Node) + sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error { + agentNodeChan1 <- nodes + return nil + }) + agentID := uuid.New() + closeAgentChan1 := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS1, agentID) + assert.NoError(t, err) + close(closeAgentChan1) + }() + sendAgentNode1(&tailnet.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + defer clientServerWS.Close() + clientNodeChan := make(chan []*tailnet.Node) + sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error { + clientNodeChan <- nodes + return nil + }) + clientID := uuid.New() + closeClientChan := make(chan 
struct{}) + go func() { + err := coordinator.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + close(closeClientChan) + }() + agentNodes := <-clientNodeChan + require.Len(t, agentNodes, 1) + sendClientNode(&tailnet.Node{}) + clientNodes := <-agentNodeChan1 + require.Len(t, clientNodes, 1) + + // Ensure an update to the agent node reaches the client! + sendAgentNode1(&tailnet.Node{}) + agentNodes = <-clientNodeChan + require.Len(t, agentNodes, 1) + + // Create a new agent connection without disconnecting the old one. + agentWS2, agentServerWS2 := net.Pipe() + defer agentWS2.Close() + agentNodeChan2 := make(chan []*tailnet.Node) + _, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error { + agentNodeChan2 <- nodes + return nil + }) + closeAgentChan2 := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS2, agentID) + assert.NoError(t, err) + close(closeAgentChan2) + }() + + // Ensure the existing listening client sends its node immediately! + clientNodes = <-agentNodeChan2 + require.Len(t, clientNodes, 1) + + counts, ok := coordinator.(interface { + NodeCount() int + AgentCount() int + }) + if !ok { + t.Fatal("coordinator should have NodeCount() and AgentCount()") + } + + assert.Equal(t, 2, counts.NodeCount()) + assert.Equal(t, 1, counts.AgentCount()) + + err := agentWS2.Close() + require.NoError(t, err) + <-agentErrChan2 + <-closeAgentChan2 + + err = clientWS.Close() + require.NoError(t, err) + <-clientErrChan + <-closeClientChan + + // This original agent websocket should've been closed forcefully. + <-agentErrChan1 + <-closeAgentChan1 + }) }