fix: ensure agent websocket only removes its own conn (#5828)

Colin Adler 2023-01-23 17:22:34 -06:00, committed by GitHub
parent 443e2180fa
commit c3731a1be0
2 changed files with 140 additions and 13 deletions
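
Reviewer context for the change below: previously, ServeAgent's deferred cleanup deleted the agent's entry from agentSockets unconditionally. If an agent reconnected before the server noticed the old connection was dead, the dying goroutine of the old connection tore down the registration belonging to the new one. A minimal standalone sketch of that failure mode; the names (sockets, serveOld, "agent-1") are illustrative only, not the coder/coder API, and the mutex is omitted for brevity:

package main

import "fmt"

// sockets stands in for coordinator.agentSockets; the string value stands
// in for the net.Conn owned by each serving goroutine.
var sockets = map[string]string{}

// serveOld simulates the dying goroutine of the first connection: its
// deferred cleanup deletes the map entry without checking who owns it.
func serveOld(agentID string) {
	defer func() {
		delete(sockets, agentID) // bug: removes whichever conn is registered now
	}()
	// ... the old connection errors out and the function returns ...
}

func main() {
	sockets["agent-1"] = "conn-1" // first connection registers
	sockets["agent-1"] = "conn-2" // agent reconnects, overwriting the entry
	serveOld("agent-1")           // the old goroutine finally unwinds
	fmt.Println(len(sockets))     // 0: the live connection's entry is gone
}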

tailnet/coordinator.go

@@ -104,7 +104,7 @@ func NewCoordinator() Coordinator {
 	return &coordinator{
 		closed:                   false,
 		nodes:                    map[uuid.UUID]*Node{},
-		agentSockets:             map[uuid.UUID]net.Conn{},
+		agentSockets:             map[uuid.UUID]idConn{},
 		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
 	}
 }
@@ -123,12 +123,19 @@ type coordinator struct {
 	// nodes maps agent and connection IDs their respective node.
 	nodes map[uuid.UUID]*Node
 	// agentSockets maps agent IDs to their open websocket.
-	agentSockets map[uuid.UUID]net.Conn
+	agentSockets map[uuid.UUID]idConn
 	// agentToConnectionSockets maps agent IDs to connection IDs of conns that
 	// are subscribed to updates for that agent.
 	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
 }
 
+type idConn struct {
+	// id is an ephemeral UUID used to uniquely identify the owner of the
+	// connection.
+	id uuid.UUID
+	conn net.Conn
+}
+
 // Node returns an in-memory node by ID.
 // If the node does not exist, nil is returned.
 func (c *coordinator) Node(id uuid.UUID) *Node {
@@ -137,6 +144,18 @@ func (c *coordinator) Node(id uuid.UUID) *Node {
 	return c.nodes[id]
 }
 
+func (c *coordinator) NodeCount() int {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	return len(c.nodes)
+}
+
+func (c *coordinator) AgentCount() int {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	return len(c.agentSockets)
+}
+
 // ServeClient accepts a WebSocket connection that wants to connect to an agent
 // with the specified ID.
 func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
@@ -224,9 +243,9 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
 		return xerrors.Errorf("marshal nodes: %w", err)
 	}
 
-	_, err = agentSocket.Write(data)
+	_, err = agentSocket.conn.Write(data)
 	if err != nil {
-		if errors.Is(err, io.EOF) {
+		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
 			return nil
 		}
 		return xerrors.Errorf("write json: %w", err)
@@ -268,27 +287,41 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
 		c.mutex.Lock()
 	}
 
-	// If an old agent socket is connected, we close it
-	// to avoid any leaks. This shouldn't ever occur because
-	// we expect one agent to be running.
+	// This uniquely identifies a connection that belongs to this goroutine.
+	unique := uuid.New()
+
+	// If an old agent socket is connected, we close it to avoid any leaks. This
+	// shouldn't ever occur because we expect one agent to be running, but it's
+	// possible for a race condition to happen when an agent is disconnected and
+	// attempts to reconnect before the server realizes the old connection is
+	// dead.
 	oldAgentSocket, ok := c.agentSockets[id]
 	if ok {
-		_ = oldAgentSocket.Close()
+		_ = oldAgentSocket.conn.Close()
 	}
-	c.agentSockets[id] = conn
+	c.agentSockets[id] = idConn{
+		id:   unique,
+		conn: conn,
+	}
 	c.mutex.Unlock()
+
 	defer func() {
 		c.mutex.Lock()
 		defer c.mutex.Unlock()
-		delete(c.agentSockets, id)
-		delete(c.nodes, id)
+
+		// Only delete the connection if it's ours. It could have been
+		// overwritten.
+		if idConn := c.agentSockets[id]; idConn.id == unique {
+			delete(c.agentSockets, id)
+			delete(c.nodes, id)
+		}
 	}()
 
 	decoder := json.NewDecoder(conn)
 	for {
 		err := c.handleNextAgentMessage(id, decoder)
 		if err != nil {
-			if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
 				return nil
 			}
 			return xerrors.Errorf("handle next agent message: %w", err)
@@ -349,7 +382,7 @@ func (c *coordinator) Close() error {
 	for _, socket := range c.agentSockets {
 		socket := socket
 		go func() {
-			_ = socket.Close()
+			_ = socket.conn.Close()
 			wg.Done()
 		}()
 	}
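
For contrast with the sketch above, here is the pattern this diff introduces, reduced to the same illustrative names (sockets, serve, "agent-1"; mutex again omitted, not the real coordinator): each registration is tagged with an ephemeral UUID, and the deferred cleanup deletes the entry only if it still owns it.

package main

import (
	"fmt"

	"github.com/google/uuid"
)

// entry mirrors the idConn wrapper: the registered conn plus the ephemeral
// ID of the goroutine that registered it.
type entry struct {
	id   uuid.UUID
	conn string
}

var sockets = map[string]entry{}

// serve registers conn under agentID and returns the cleanup that the real
// code runs from a defer.
func serve(agentID, conn string) (cleanup func()) {
	unique := uuid.New()
	sockets[agentID] = entry{id: unique, conn: conn}
	return func() {
		// Only delete the entry if it is still ours. It could have been
		// overwritten by a reconnecting agent.
		if e := sockets[agentID]; e.id == unique {
			delete(sockets, agentID)
		}
	}
}

func main() {
	cleanupOld := serve("agent-1", "conn-1")
	_ = serve("agent-1", "conn-2")       // reconnect before the old cleanup runs
	cleanupOld()                         // no-op: the entry is no longer ours
	fmt.Println(sockets["agent-1"].conn) // conn-2: the live conn survives
}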

tailnet/coordinator_test.go

@@ -145,4 +145,98 @@ func TestCoordinator(t *testing.T) {
 		<-clientErrChan
 		<-closeClientChan
 	})
+
+	t.Run("AgentDoubleConnect", func(t *testing.T) {
+		t.Parallel()
+		coordinator := tailnet.NewCoordinator()
+
+		agentWS1, agentServerWS1 := net.Pipe()
+		defer agentWS1.Close()
+		agentNodeChan1 := make(chan []*tailnet.Node)
+		sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
+			agentNodeChan1 <- nodes
+			return nil
+		})
+		agentID := uuid.New()
+		closeAgentChan1 := make(chan struct{})
+		go func() {
+			err := coordinator.ServeAgent(agentServerWS1, agentID)
+			assert.NoError(t, err)
+			close(closeAgentChan1)
+		}()
+
+		sendAgentNode1(&tailnet.Node{})
+		require.Eventually(t, func() bool {
+			return coordinator.Node(agentID) != nil
+		}, testutil.WaitShort, testutil.IntervalFast)
+
+		clientWS, clientServerWS := net.Pipe()
+		defer clientWS.Close()
+		defer clientServerWS.Close()
+		clientNodeChan := make(chan []*tailnet.Node)
+		sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
+			clientNodeChan <- nodes
+			return nil
+		})
+		clientID := uuid.New()
+		closeClientChan := make(chan struct{})
+		go func() {
+			err := coordinator.ServeClient(clientServerWS, clientID, agentID)
+			assert.NoError(t, err)
+			close(closeClientChan)
+		}()
+		agentNodes := <-clientNodeChan
+		require.Len(t, agentNodes, 1)
+		sendClientNode(&tailnet.Node{})
+		clientNodes := <-agentNodeChan1
+		require.Len(t, clientNodes, 1)
+
+		// Ensure an update to the agent node reaches the client!
+		sendAgentNode1(&tailnet.Node{})
+		agentNodes = <-clientNodeChan
+		require.Len(t, agentNodes, 1)
+
+		// Create a new agent connection without disconnecting the old one.
+		agentWS2, agentServerWS2 := net.Pipe()
+		defer agentWS2.Close()
+		agentNodeChan2 := make(chan []*tailnet.Node)
+		_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
+			agentNodeChan2 <- nodes
+			return nil
+		})
+		closeAgentChan2 := make(chan struct{})
+		go func() {
+			err := coordinator.ServeAgent(agentServerWS2, agentID)
+			assert.NoError(t, err)
+			close(closeAgentChan2)
+		}()
+
+		// Ensure the existing listening client sends its node immediately!
+		clientNodes = <-agentNodeChan2
+		require.Len(t, clientNodes, 1)
+
+		counts, ok := coordinator.(interface {
+			NodeCount() int
+			AgentCount() int
+		})
+		if !ok {
+			t.Fatal("coordinator should have NodeCount() and AgentCount()")
+		}
+		assert.Equal(t, 2, counts.NodeCount())
+		assert.Equal(t, 1, counts.AgentCount())
+
+		err := agentWS2.Close()
+		require.NoError(t, err)
+		<-agentErrChan2
+		<-closeAgentChan2
+
+		err = clientWS.Close()
+		require.NoError(t, err)
+		<-clientErrChan
+		<-closeClientChan
+
+		// This original agent websocket should've been closed forcefully.
+		<-agentErrChan1
+		<-closeAgentChan1
+	})
 }
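
A note on the broadened error checks (io.ErrClosedPipe alongside io.EOF and context.Canceled): when the coordinator force-closes a superseded socket, the serve loop's next operation on a net.Pipe-backed conn, which is what this test uses, fails with io.ErrClosedPipe rather than io.EOF, so it must also be treated as a clean shutdown. A quick standalone demonstration of the stdlib behavior:

package main

import (
	"errors"
	"fmt"
	"io"
	"net"
)

func main() {
	client, server := net.Pipe()

	// Simulate the coordinator force-closing a superseded agent socket.
	_ = server.Close()

	// A write on the closed end surfaces io.ErrClosedPipe, not io.EOF.
	_, err := server.Write([]byte("hello"))
	fmt.Println(errors.Is(err, io.ErrClosedPipe)) // true

	// The peer's read, by contrast, reports io.EOF.
	_, err = client.Read(make([]byte, 1))
	fmt.Println(errors.Is(err, io.EOF)) // true
}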