fix: close MultiAgentConn when coordinator closes (#11941)

Fixes an issue where a MultiAgentConn isn't closed properly when the coordinator it is connected to is closed.

Since servertailnet checks whether the conn is closed before reinitializing, it is important that we propagate the close; otherwise, servertailnet can get stuck if the coordinator closes (e.g., when we switch from the AGPL coordinator to the PGCoordinator after decoding a license).
This commit is contained in:
Spike Curtis 2024-01-31 00:38:19 +04:00 committed by GitHub
parent 2fd1a726aa
commit 520b12e1a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 236 additions and 167 deletions

View File

@ -2,7 +2,6 @@ package tailnet
import (
"context"
"io"
"sync"
"sync/atomic"
"time"
@ -104,19 +103,21 @@ func (c *connIO) recvLoop() {
}()
defer c.Close()
for {
req, err := agpl.RecvCtx(c.peerCtx, c.requests)
if err != nil {
if xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) ||
xerrors.Is(err, io.EOF) {
c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err))
} else {
c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err))
select {
case <-c.coordCtx.Done():
c.logger.Debug(c.coordCtx, "exiting io recvLoop; coordinator exit")
return
case <-c.peerCtx.Done():
c.logger.Debug(c.peerCtx, "exiting io recvLoop; peer context canceled")
return
case req, ok := <-c.requests:
if !ok {
c.logger.Debug(c.peerCtx, "exiting io recvLoop; requests chan closed")
return
}
if err := c.handleRequest(req); err != nil {
return
}
return
}
if err := c.handleRequest(req); err != nil {
return
}
}
}

View File

@ -4,17 +4,14 @@ import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"tailscale.com/types/key"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
@ -42,25 +39,48 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
ma1.close()
ma1.Close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_CoordClose verifies that closing a
// PGCoordinator also closes any MultiAgentConn it is serving.
func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	coord, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord.Close()

	ma := tailnettest.NewTestMultiAgent(t, coord)
	defer ma.Close()

	// Closing the coordinator must propagate to the conn.
	require.NoError(t, coord.Close())
	ma.RequireEventuallyClosed(ctx)
}
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
// a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe
// with the MultiAgent closing.
@ -86,20 +106,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
ma1.unsubscribeAgent(agent1.id)
ma1.close()
ma1.RequireUnsubscribeAgent(agent1.id)
ma1.Close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -131,35 +151,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close()
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
ma1.unsubscribeAgent(agent1.id)
ma1.RequireUnsubscribeAgent(agent1.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel()
ma1.sendNodeWithDERP(9)
ma1.SendNodeWithDERP(9)
assertNeverHasDERPs(ctx, t, agent1, 9)
}()
func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel()
agent1.sendNode(&agpl.Node{PreferredDERP: 8})
ma1.assertNeverHasDERPs(ctx, 8)
ma1.RequireNeverHasDERPs(ctx, 8)
}()
ma1.close()
ma1.Close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -196,19 +216,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
ma1 := newTestMultiAgent(t, coord2)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
defer ma1.Close()
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
ma1.close()
ma1.Close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -246,19 +266,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
ma1 := newTestMultiAgent(t, coord2)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord2)
defer ma1.Close()
ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.close()
ma1.Close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -305,129 +325,29 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
defer agent1.close()
agent2.sendNode(&agpl.Node{PreferredDERP: 6})
ma1 := newTestMultiAgent(t, coord3)
defer ma1.close()
ma1 := tailnettest.NewTestMultiAgent(t, coord3)
defer ma1.Close()
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
ma1.RequireSubscribeAgent(agent1.id)
ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
ma1.assertEventuallyHasDERPs(ctx, 1)
ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.subscribeAgent(agent2.id)
ma1.assertEventuallyHasDERPs(ctx, 6)
ma1.RequireSubscribeAgent(agent2.id)
ma1.RequireEventuallyHasDERPs(ctx, 6)
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
ma1.assertEventuallyHasDERPs(ctx, 2)
ma1.RequireEventuallyHasDERPs(ctx, 2)
ma1.sendNodeWithDERP(3)
ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
assertEventuallyHasDERPs(ctx, t, agent2, 3)
ma1.close()
ma1.Close()
require.NoError(t, agent1.close())
require.NoError(t, agent2.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
// testMultiAgent wraps an agpl.MultiAgentConn together with the identity
// it was served under, so tests can drive the conn and assert on updates.
type testMultiAgent struct {
	t  testing.TB
	id uuid.UUID
	a  agpl.MultiAgentConn
	// nodeKey is the marshaled binary node public key sent in UpdateSelf.
	nodeKey []byte
	// discoKey is the marshaled text disco public key sent in UpdateSelf.
	discoKey string
}
// newTestMultiAgent generates fresh node and disco keys, then serves a new
// multi-agent conn with a random ID from the given coordinator.
func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent {
	nodeKey, err := key.NewNode().Public().MarshalBinary()
	require.NoError(t, err)
	discoKey, err := key.NewDisco().Public().MarshalText()
	require.NoError(t, err)
	m := &testMultiAgent{
		t:        t,
		id:       uuid.New(),
		nodeKey:  nodeKey,
		discoKey: string(discoKey),
	}
	m.a = coord.ServeMultiAgent(m.id)
	return m
}
// sendNodeWithDERP pushes a self-node update whose preferred DERP region is
// derp, failing the test on error.
func (m *testMultiAgent) sendNodeWithDERP(derp int32) {
	m.t.Helper()
	node := &proto.Node{
		Key:           m.nodeKey,
		Disco:         m.discoKey,
		PreferredDerp: derp,
	}
	require.NoError(m.t, m.a.UpdateSelf(node))
}
// close tears down the multi-agent conn, failing the test on error.
func (m *testMultiAgent) close() {
	m.t.Helper()
	require.NoError(m.t, m.a.Close())
}
// subscribeAgent subscribes the conn to updates for the given agent,
// failing the test on error.
func (m *testMultiAgent) subscribeAgent(id uuid.UUID) {
	m.t.Helper()
	require.NoError(m.t, m.a.SubscribeAgent(id))
}
// unsubscribeAgent removes the conn's subscription to the given agent,
// failing the test on error.
func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) {
	m.t.Helper()
	require.NoError(m.t, m.a.UnsubscribeAgent(id))
}
// assertEventuallyHasDERPs reads node updates until it receives one with
// exactly len(expected) nodes whose preferred DERP regions include every
// value in expected. It fails the test if the conn reports closed
// (NextUpdate returns !ok) or an update cannot be decoded.
func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) {
	m.t.Helper()
	for {
		resp, ok := m.a.NextUpdate(ctx)
		require.True(m.t, ok)
		nodes, err := agpl.OnlyNodeUpdates(resp)
		require.NoError(m.t, err)
		if len(nodes) != len(expected) {
			m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
			continue
		}
		derps := make([]int, 0, len(nodes))
		for _, n := range nodes {
			derps = append(derps, n.PreferredDERP)
		}
		// Bug fix: the previous version returned as soon as the first
		// expected DERP was present; require all of them before succeeding.
		missing := false
		for _, e := range expected {
			if !slices.Contains(derps, e) {
				m.t.Logf("expected DERP %d to be in %v", e, derps)
				missing = true
				break
			}
		}
		if !missing {
			return
		}
	}
}
// assertNeverHasDERPs reads node updates until ctx expires or the conn
// closes (NextUpdate returns !ok), failing the test immediately if any
// update contains one of the expected DERP regions. Callers bound the wait
// with a context timeout.
func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) {
	m.t.Helper()
	for {
		resp, ok := m.a.NextUpdate(ctx)
		if !ok {
			// ctx expired or conn closed without seeing a forbidden DERP.
			return
		}
		nodes, err := agpl.OnlyNodeUpdates(resp)
		require.NoError(m.t, err)
		derps := make([]int, 0, len(nodes))
		for _, n := range nodes {
			derps = append(derps, n.PreferredDERP)
		}
		// Bug fix: the previous version could never fail — it returned
		// silently when a forbidden DERP arrived and skipped any update
		// whose node count differed from len(expected). Fail loudly instead.
		for _, e := range expected {
			if slices.Contains(derps, e) {
				m.t.Fatalf("received forbidden DERP %d in %v", e, derps)
			}
		}
	}
}

View File

@ -1017,7 +1017,13 @@ func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logge
}
func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
defer cancel()
defer func() {
cErr := q.Close()
if cErr != nil {
logger.Info(ctx, "error closing response Queue", slog.Error(cErr))
}
cancel()
}()
for {
resp, err := RecvCtx(ctx, resps)
if err != nil {

View File

@ -383,6 +383,24 @@ func TestCoordinator_Lost(t *testing.T) {
test.LostTest(ctx, t, coordinator)
}
// TestCoordinator_MultiAgent_CoordClose verifies that closing the in-memory
// coordinator also closes any MultiAgentConn served from it.
func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()

	coord := tailnet.NewCoordinator(logger.Named("coord1"))
	defer coord.Close()

	ma := tailnettest.NewTestMultiAgent(t, coord)
	defer ma.Close()

	// Closing the coordinator must propagate to the conn.
	require.NoError(t, coord.Close())
	ma.RequireEventuallyClosed(ctx)
}
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
t.Helper()
sc := make(chan net.Conn, 1)

View File

@ -1,6 +1,7 @@
package tailnettest
import (
"context"
"crypto/tls"
"fmt"
"html"
@ -8,7 +9,11 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/net/stun/stuntest"
@ -19,6 +24,8 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
@ -125,3 +132,120 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
},
}
}
// TestMultiAgent wraps a tailnet.MultiAgentConn together with the identity
// it was served under, so tests can drive the conn and assert on updates.
type TestMultiAgent struct {
	t  testing.TB
	id uuid.UUID
	a  tailnet.MultiAgentConn
	// nodeKey is the marshaled binary node public key sent in UpdateSelf.
	nodeKey []byte
	// discoKey is the marshaled text disco public key sent in UpdateSelf.
	discoKey string
}
// NewTestMultiAgent generates fresh node and disco keys, then serves a new
// multi-agent conn with a random ID from the given coordinator.
func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent {
	nodeKey, err := key.NewNode().Public().MarshalBinary()
	require.NoError(t, err)
	discoKey, err := key.NewDisco().Public().MarshalText()
	require.NoError(t, err)
	m := &TestMultiAgent{
		t:        t,
		id:       uuid.New(),
		nodeKey:  nodeKey,
		discoKey: string(discoKey),
	}
	m.a = coord.ServeMultiAgent(m.id)
	return m
}
// SendNodeWithDERP pushes a self-node update whose preferred DERP region is
// d, failing the test on error.
func (m *TestMultiAgent) SendNodeWithDERP(d int32) {
	m.t.Helper()
	node := &proto.Node{
		Key:           m.nodeKey,
		Disco:         m.discoKey,
		PreferredDerp: d,
	}
	require.NoError(m.t, m.a.UpdateSelf(node))
}
// Close tears down the multi-agent conn, failing the test on error.
func (m *TestMultiAgent) Close() {
	m.t.Helper()
	require.NoError(m.t, m.a.Close())
}
// RequireSubscribeAgent subscribes the conn to updates for the given agent,
// failing the test on error.
func (m *TestMultiAgent) RequireSubscribeAgent(id uuid.UUID) {
	m.t.Helper()
	require.NoError(m.t, m.a.SubscribeAgent(id))
}
// RequireUnsubscribeAgent removes the conn's subscription to the given
// agent, failing the test on error.
func (m *TestMultiAgent) RequireUnsubscribeAgent(id uuid.UUID) {
	m.t.Helper()
	require.NoError(m.t, m.a.UnsubscribeAgent(id))
}
// RequireEventuallyHasDERPs reads node updates until it receives one with
// exactly len(expected) nodes whose preferred DERP regions include every
// value in expected. It fails the test if the conn reports closed
// (NextUpdate returns !ok) or an update cannot be decoded.
func (m *TestMultiAgent) RequireEventuallyHasDERPs(ctx context.Context, expected ...int) {
	m.t.Helper()
	for {
		resp, ok := m.a.NextUpdate(ctx)
		require.True(m.t, ok)
		nodes, err := tailnet.OnlyNodeUpdates(resp)
		require.NoError(m.t, err)
		if len(nodes) != len(expected) {
			m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
			continue
		}
		derps := make([]int, 0, len(nodes))
		for _, n := range nodes {
			derps = append(derps, n.PreferredDERP)
		}
		// Bug fix: the previous version returned as soon as the first
		// expected DERP was present; require all of them before succeeding.
		missing := false
		for _, e := range expected {
			if !slices.Contains(derps, e) {
				m.t.Logf("expected DERP %d to be in %v", e, derps)
				missing = true
				break
			}
		}
		if !missing {
			return
		}
	}
}
// RequireNeverHasDERPs reads node updates until ctx expires or the conn
// closes (NextUpdate returns !ok), failing the test immediately if any
// update contains one of the expected DERP regions. Callers bound the wait
// with a context timeout.
func (m *TestMultiAgent) RequireNeverHasDERPs(ctx context.Context, expected ...int) {
	m.t.Helper()
	for {
		resp, ok := m.a.NextUpdate(ctx)
		if !ok {
			// ctx expired or conn closed without seeing a forbidden DERP.
			return
		}
		nodes, err := tailnet.OnlyNodeUpdates(resp)
		require.NoError(m.t, err)
		derps := make([]int, 0, len(nodes))
		for _, n := range nodes {
			derps = append(derps, n.PreferredDERP)
		}
		// Bug fix: the previous version could never fail — it returned
		// silently when a forbidden DERP arrived and skipped any update
		// whose node count differed from len(expected). Fail loudly instead.
		for _, e := range expected {
			if slices.Contains(derps, e) {
				m.t.Fatalf("received forbidden DERP %d in %v", e, derps)
			}
		}
	}
}
// RequireEventuallyClosed polls the conn at testutil.IntervalFast until it
// reports closed, failing the test if ctx expires first.
func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) {
	m.t.Helper()
	poll := time.NewTicker(testutil.IntervalFast)
	defer poll.Stop()
	for closed := false; !closed; {
		select {
		case <-ctx.Done():
			m.t.Fatal("timeout")
			return // unreachable: Fatal stops the test goroutine
		case <-poll.C:
			closed = m.a.IsClosed()
		}
	}
}