mirror of https://github.com/coder/coder.git
fix: ensure agent websocket only removes its own conn (#5828)
This commit is contained in:
parent
443e2180fa
commit
c3731a1be0
|
@ -104,7 +104,7 @@ func NewCoordinator() Coordinator {
|
||||||
return &coordinator{
|
return &coordinator{
|
||||||
closed: false,
|
closed: false,
|
||||||
nodes: map[uuid.UUID]*Node{},
|
nodes: map[uuid.UUID]*Node{},
|
||||||
agentSockets: map[uuid.UUID]net.Conn{},
|
agentSockets: map[uuid.UUID]idConn{},
|
||||||
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
|
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,12 +123,19 @@ type coordinator struct {
|
||||||
// nodes maps agent and connection IDs their respective node.
|
// nodes maps agent and connection IDs their respective node.
|
||||||
nodes map[uuid.UUID]*Node
|
nodes map[uuid.UUID]*Node
|
||||||
// agentSockets maps agent IDs to their open websocket.
|
// agentSockets maps agent IDs to their open websocket.
|
||||||
agentSockets map[uuid.UUID]net.Conn
|
agentSockets map[uuid.UUID]idConn
|
||||||
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
|
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
|
||||||
// are subscribed to updates for that agent.
|
// are subscribed to updates for that agent.
|
||||||
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
|
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type idConn struct {
|
||||||
|
// id is an ephemeral UUID used to uniquely identify the owner of the
|
||||||
|
// connection.
|
||||||
|
id uuid.UUID
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
// Node returns an in-memory node by ID.
|
// Node returns an in-memory node by ID.
|
||||||
// If the node does not exist, nil is returned.
|
// If the node does not exist, nil is returned.
|
||||||
func (c *coordinator) Node(id uuid.UUID) *Node {
|
func (c *coordinator) Node(id uuid.UUID) *Node {
|
||||||
|
@ -137,6 +144,18 @@ func (c *coordinator) Node(id uuid.UUID) *Node {
|
||||||
return c.nodes[id]
|
return c.nodes[id]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *coordinator) NodeCount() int {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
return len(c.nodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *coordinator) AgentCount() int {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
return len(c.agentSockets)
|
||||||
|
}
|
||||||
|
|
||||||
// ServeClient accepts a WebSocket connection that wants to connect to an agent
|
// ServeClient accepts a WebSocket connection that wants to connect to an agent
|
||||||
// with the specified ID.
|
// with the specified ID.
|
||||||
func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
||||||
|
@ -224,9 +243,9 @@ func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json
|
||||||
return xerrors.Errorf("marshal nodes: %w", err)
|
return xerrors.Errorf("marshal nodes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agentSocket.Write(data)
|
_, err = agentSocket.conn.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, io.EOF) {
|
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return xerrors.Errorf("write json: %w", err)
|
return xerrors.Errorf("write json: %w", err)
|
||||||
|
@ -268,27 +287,41 @@ func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// If an old agent socket is connected, we close it
|
// This uniquely identifies a connection that belongs to this goroutine.
|
||||||
// to avoid any leaks. This shouldn't ever occur because
|
unique := uuid.New()
|
||||||
// we expect one agent to be running.
|
|
||||||
|
// If an old agent socket is connected, we close it to avoid any leaks. This
|
||||||
|
// shouldn't ever occur because we expect one agent to be running, but it's
|
||||||
|
// possible for a race condition to happen when an agent is disconnected and
|
||||||
|
// attempts to reconnect before the server realizes the old connection is
|
||||||
|
// dead.
|
||||||
oldAgentSocket, ok := c.agentSockets[id]
|
oldAgentSocket, ok := c.agentSockets[id]
|
||||||
if ok {
|
if ok {
|
||||||
_ = oldAgentSocket.Close()
|
_ = oldAgentSocket.conn.Close()
|
||||||
}
|
}
|
||||||
c.agentSockets[id] = conn
|
c.agentSockets[id] = idConn{
|
||||||
|
id: unique,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
|
||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
defer func() {
|
defer func() {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
delete(c.agentSockets, id)
|
|
||||||
delete(c.nodes, id)
|
// Only delete the connection if it's ours. It could have been
|
||||||
|
// overwritten.
|
||||||
|
if idConn := c.agentSockets[id]; idConn.id == unique {
|
||||||
|
delete(c.agentSockets, id)
|
||||||
|
delete(c.nodes, id)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
decoder := json.NewDecoder(conn)
|
decoder := json.NewDecoder(conn)
|
||||||
for {
|
for {
|
||||||
err := c.handleNextAgentMessage(id, decoder)
|
err := c.handleNextAgentMessage(id, decoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return xerrors.Errorf("handle next agent message: %w", err)
|
return xerrors.Errorf("handle next agent message: %w", err)
|
||||||
|
@ -349,7 +382,7 @@ func (c *coordinator) Close() error {
|
||||||
for _, socket := range c.agentSockets {
|
for _, socket := range c.agentSockets {
|
||||||
socket := socket
|
socket := socket
|
||||||
go func() {
|
go func() {
|
||||||
_ = socket.Close()
|
_ = socket.conn.Close()
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,4 +145,98 @@ func TestCoordinator(t *testing.T) {
|
||||||
<-clientErrChan
|
<-clientErrChan
|
||||||
<-closeClientChan
|
<-closeClientChan
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("AgentDoubleConnect", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
coordinator := tailnet.NewCoordinator()
|
||||||
|
|
||||||
|
agentWS1, agentServerWS1 := net.Pipe()
|
||||||
|
defer agentWS1.Close()
|
||||||
|
agentNodeChan1 := make(chan []*tailnet.Node)
|
||||||
|
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
|
||||||
|
agentNodeChan1 <- nodes
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
agentID := uuid.New()
|
||||||
|
closeAgentChan1 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
err := coordinator.ServeAgent(agentServerWS1, agentID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
close(closeAgentChan1)
|
||||||
|
}()
|
||||||
|
sendAgentNode1(&tailnet.Node{})
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return coordinator.Node(agentID) != nil
|
||||||
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
|
||||||
|
clientWS, clientServerWS := net.Pipe()
|
||||||
|
defer clientWS.Close()
|
||||||
|
defer clientServerWS.Close()
|
||||||
|
clientNodeChan := make(chan []*tailnet.Node)
|
||||||
|
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
|
||||||
|
clientNodeChan <- nodes
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
clientID := uuid.New()
|
||||||
|
closeClientChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
close(closeClientChan)
|
||||||
|
}()
|
||||||
|
agentNodes := <-clientNodeChan
|
||||||
|
require.Len(t, agentNodes, 1)
|
||||||
|
sendClientNode(&tailnet.Node{})
|
||||||
|
clientNodes := <-agentNodeChan1
|
||||||
|
require.Len(t, clientNodes, 1)
|
||||||
|
|
||||||
|
// Ensure an update to the agent node reaches the client!
|
||||||
|
sendAgentNode1(&tailnet.Node{})
|
||||||
|
agentNodes = <-clientNodeChan
|
||||||
|
require.Len(t, agentNodes, 1)
|
||||||
|
|
||||||
|
// Create a new agent connection without disconnecting the old one.
|
||||||
|
agentWS2, agentServerWS2 := net.Pipe()
|
||||||
|
defer agentWS2.Close()
|
||||||
|
agentNodeChan2 := make(chan []*tailnet.Node)
|
||||||
|
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
|
||||||
|
agentNodeChan2 <- nodes
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
closeAgentChan2 := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
err := coordinator.ServeAgent(agentServerWS2, agentID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
close(closeAgentChan2)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Ensure the existing listening client sends it's node immediately!
|
||||||
|
clientNodes = <-agentNodeChan2
|
||||||
|
require.Len(t, clientNodes, 1)
|
||||||
|
|
||||||
|
counts, ok := coordinator.(interface {
|
||||||
|
NodeCount() int
|
||||||
|
AgentCount() int
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("coordinator should have NodeCount() and AgentCount()")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 2, counts.NodeCount())
|
||||||
|
assert.Equal(t, 1, counts.AgentCount())
|
||||||
|
|
||||||
|
err := agentWS2.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
<-agentErrChan2
|
||||||
|
<-closeAgentChan2
|
||||||
|
|
||||||
|
err = clientWS.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
<-clientErrChan
|
||||||
|
<-closeClientChan
|
||||||
|
|
||||||
|
// This original agent websocket should've been closed forcefully.
|
||||||
|
<-agentErrChan1
|
||||||
|
<-closeAgentChan1
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue