mirror of https://github.com/coder/coder.git
fix: ensure agent websocket only removes its own conn (#5828)
This commit is contained in:
parent
443e2180fa
commit
c3731a1be0
|
@ -104,7 +104,7 @@ func NewCoordinator() Coordinator {
|
|||
return &coordinator{
|
||||
closed: false,
|
||||
nodes: map[uuid.UUID]*Node{},
|
||||
agentSockets: map[uuid.UUID]net.Conn{},
|
||||
agentSockets: map[uuid.UUID]idConn{},
|
||||
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
|
||||
}
|
||||
}
|
||||
|
@ -123,12 +123,19 @@ type coordinator struct {
|
|||
// nodes maps agent and connection IDs to their respective node.
|
||||
nodes map[uuid.UUID]*Node
|
||||
// agentSockets maps agent IDs to their open websocket.
|
||||
agentSockets map[uuid.UUID]net.Conn
|
||||
agentSockets map[uuid.UUID]idConn
|
||||
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
|
||||
// are subscribed to updates for that agent.
|
||||
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
|
||||
}
|
||||
|
||||
// idConn pairs a connection with an identifier for the owner that registered
// it, so that an owner can tell whether a stored connection is still its own.
type idConn struct {
|
||||
// id is an ephemeral UUID used to uniquely identify the owner of the
|
||||
// connection.
|
||||
id uuid.UUID
|
||||
// conn is the connection owned by the holder of id.
conn net.Conn
|
||||
}
|
||||
|
||||
// Node returns an in-memory node by ID.
|
||||
// If the node does not exist, nil is returned.
|
||||
func (c *coordinator) Node(id uuid.UUID) *Node {
|
||||
|
@ -137,6 +144,18 @@ func (c *coordinator) Node(id uuid.UUID) *Node {
|
|||
return c.nodes[id]
|
||||
}
|
||||
|
||||
// NodeCount returns the number of entries currently in the nodes map.
// The count is read while holding c.mutex.
func (c *coordinator) NodeCount() int {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return len(c.nodes)
|
||||
}
|
||||
|
||||
// AgentCount returns the number of entries currently in the agentSockets map.
// The count is read while holding c.mutex.
func (c *coordinator) AgentCount() int {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return len(c.agentSockets)
|
||||
}
|
||||
|
||||
// ServeClient accepts a WebSocket connection that wants to connect to an agent
|
||||
// with the specified ID.
|
||||
func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
||||
|
@ -224,9 +243,9 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json
|
|||
return xerrors.Errorf("marshal nodes: %w", err)
|
||||
}
|
||||
|
||||
_, err = agentSocket.Write(data)
|
||||
_, err = agentSocket.conn.Write(data)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("write json: %w", err)
|
||||
|
@ -268,27 +287,41 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
|
|||
c.mutex.Lock()
|
||||
}
|
||||
|
||||
// If an old agent socket is connected, we close it
|
||||
// to avoid any leaks. This shouldn't ever occur because
|
||||
// we expect one agent to be running.
|
||||
// This uniquely identifies a connection that belongs to this goroutine.
|
||||
unique := uuid.New()
|
||||
|
||||
// If an old agent socket is connected, we close it to avoid any leaks. This
|
||||
// shouldn't ever occur because we expect one agent to be running, but it's
|
||||
// possible for a race condition to happen when an agent is disconnected and
|
||||
// attempts to reconnect before the server realizes the old connection is
|
||||
// dead.
|
||||
oldAgentSocket, ok := c.agentSockets[id]
|
||||
if ok {
|
||||
_ = oldAgentSocket.Close()
|
||||
_ = oldAgentSocket.conn.Close()
|
||||
}
|
||||
c.agentSockets[id] = conn
|
||||
c.agentSockets[id] = idConn{
|
||||
id: unique,
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
c.mutex.Unlock()
|
||||
defer func() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.agentSockets, id)
|
||||
delete(c.nodes, id)
|
||||
|
||||
// Only delete the connection if it's ours. It could have been
|
||||
// overwritten.
|
||||
if idConn := c.agentSockets[id]; idConn.id == unique {
|
||||
delete(c.agentSockets, id)
|
||||
delete(c.nodes, id)
|
||||
}
|
||||
}()
|
||||
|
||||
decoder := json.NewDecoder(conn)
|
||||
for {
|
||||
err := c.handleNextAgentMessage(id, decoder)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("handle next agent message: %w", err)
|
||||
|
@ -349,7 +382,7 @@ func (c *coordinator) Close() error {
|
|||
for _, socket := range c.agentSockets {
|
||||
socket := socket
|
||||
go func() {
|
||||
_ = socket.Close()
|
||||
_ = socket.conn.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -145,4 +145,98 @@ func TestCoordinator(t *testing.T) {
|
|||
<-clientErrChan
|
||||
<-closeClientChan
|
||||
})
|
||||
|
||||
t.Run("AgentDoubleConnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
coordinator := tailnet.NewCoordinator()
|
||||
|
||||
agentWS1, agentServerWS1 := net.Pipe()
|
||||
defer agentWS1.Close()
|
||||
agentNodeChan1 := make(chan []*tailnet.Node)
|
||||
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
|
||||
agentNodeChan1 <- nodes
|
||||
return nil
|
||||
})
|
||||
agentID := uuid.New()
|
||||
closeAgentChan1 := make(chan struct{})
|
||||
go func() {
|
||||
err := coordinator.ServeAgent(agentServerWS1, agentID)
|
||||
assert.NoError(t, err)
|
||||
close(closeAgentChan1)
|
||||
}()
|
||||
sendAgentNode1(&tailnet.Node{})
|
||||
require.Eventually(t, func() bool {
|
||||
return coordinator.Node(agentID) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
clientWS, clientServerWS := net.Pipe()
|
||||
defer clientWS.Close()
|
||||
defer clientServerWS.Close()
|
||||
clientNodeChan := make(chan []*tailnet.Node)
|
||||
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
|
||||
clientNodeChan <- nodes
|
||||
return nil
|
||||
})
|
||||
clientID := uuid.New()
|
||||
closeClientChan := make(chan struct{})
|
||||
go func() {
|
||||
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
|
||||
assert.NoError(t, err)
|
||||
close(closeClientChan)
|
||||
}()
|
||||
agentNodes := <-clientNodeChan
|
||||
require.Len(t, agentNodes, 1)
|
||||
sendClientNode(&tailnet.Node{})
|
||||
clientNodes := <-agentNodeChan1
|
||||
require.Len(t, clientNodes, 1)
|
||||
|
||||
// Ensure an update to the agent node reaches the client!
|
||||
sendAgentNode1(&tailnet.Node{})
|
||||
agentNodes = <-clientNodeChan
|
||||
require.Len(t, agentNodes, 1)
|
||||
|
||||
// Create a new agent connection without disconnecting the old one.
|
||||
agentWS2, agentServerWS2 := net.Pipe()
|
||||
defer agentWS2.Close()
|
||||
agentNodeChan2 := make(chan []*tailnet.Node)
|
||||
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
|
||||
agentNodeChan2 <- nodes
|
||||
return nil
|
||||
})
|
||||
closeAgentChan2 := make(chan struct{})
|
||||
go func() {
|
||||
err := coordinator.ServeAgent(agentServerWS2, agentID)
|
||||
assert.NoError(t, err)
|
||||
close(closeAgentChan2)
|
||||
}()
|
||||
|
||||
// Ensure the existing listening client sends its node immediately!
|
||||
clientNodes = <-agentNodeChan2
|
||||
require.Len(t, clientNodes, 1)
|
||||
|
||||
counts, ok := coordinator.(interface {
|
||||
NodeCount() int
|
||||
AgentCount() int
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("coordinator should have NodeCount() and AgentCount()")
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, counts.NodeCount())
|
||||
assert.Equal(t, 1, counts.AgentCount())
|
||||
|
||||
err := agentWS2.Close()
|
||||
require.NoError(t, err)
|
||||
<-agentErrChan2
|
||||
<-closeAgentChan2
|
||||
|
||||
err = clientWS.Close()
|
||||
require.NoError(t, err)
|
||||
<-clientErrChan
|
||||
<-closeClientChan
|
||||
|
||||
// This original agent websocket should've been closed forcefully.
|
||||
<-agentErrChan1
|
||||
<-closeAgentChan1
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue