mirror of https://github.com/coder/coder.git
fix: make PGCoordinator close connections when unhealthy (#9125)
Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent
c217a0d819
commit
c7a6d626b4
|
@ -211,6 +211,7 @@ issues:
|
|||
run:
|
||||
skip-dirs:
|
||||
- node_modules
|
||||
- .git
|
||||
skip-files:
|
||||
- scripts/rules.go
|
||||
timeout: 10m
|
||||
|
|
|
@ -586,10 +586,12 @@ type querier struct {
|
|||
|
||||
workQ *workQ[mKey]
|
||||
heartbeats *heartbeats
|
||||
updates <-chan struct{}
|
||||
updates <-chan hbUpdate
|
||||
|
||||
mu sync.Mutex
|
||||
mappers map[mKey]*countedMapper
|
||||
conns map[*connIO]struct{}
|
||||
healthy bool
|
||||
}
|
||||
|
||||
type countedMapper struct {
|
||||
|
@ -604,7 +606,7 @@ func newQuerier(
|
|||
self uuid.UUID, newConnections chan *connIO, numWorkers int,
|
||||
firstHeartbeat chan<- struct{},
|
||||
) *querier {
|
||||
updates := make(chan struct{})
|
||||
updates := make(chan hbUpdate)
|
||||
q := &querier{
|
||||
ctx: ctx,
|
||||
logger: logger.Named("querier"),
|
||||
|
@ -614,7 +616,9 @@ func newQuerier(
|
|||
workQ: newWorkQ[mKey](ctx),
|
||||
heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat),
|
||||
mappers: make(map[mKey]*countedMapper),
|
||||
conns: make(map[*connIO]struct{}),
|
||||
updates: updates,
|
||||
healthy: true, // assume we start healthy
|
||||
}
|
||||
go q.subscribe()
|
||||
go q.handleConnIO()
|
||||
|
@ -639,6 +643,15 @@ func (q *querier) handleConnIO() {
|
|||
func (q *querier) newConn(c *connIO) {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
if !q.healthy {
|
||||
err := c.updates.Close()
|
||||
q.logger.Info(q.ctx, "closed incoming connection while unhealthy",
|
||||
slog.Error(err),
|
||||
slog.F("agent_id", c.agent),
|
||||
slog.F("client_id", c.client),
|
||||
)
|
||||
return
|
||||
}
|
||||
mk := mKey{
|
||||
agent: c.agent,
|
||||
// if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself
|
||||
|
@ -661,6 +674,7 @@ func (q *querier) newConn(c *connIO) {
|
|||
return
|
||||
}
|
||||
cm.count++
|
||||
q.conns[c] = struct{}{}
|
||||
go q.cleanupConn(c)
|
||||
}
|
||||
|
||||
|
@ -668,6 +682,7 @@ func (q *querier) cleanupConn(c *connIO) {
|
|||
<-c.ctx.Done()
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
delete(q.conns, c)
|
||||
mk := mKey{
|
||||
agent: c.agent,
|
||||
// if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself
|
||||
|
@ -911,8 +926,18 @@ func (q *querier) handleUpdates() {
|
|||
select {
|
||||
case <-q.ctx.Done():
|
||||
return
|
||||
case <-q.updates:
|
||||
q.updateAll()
|
||||
case u := <-q.updates:
|
||||
if u.filter == filterUpdateUpdated {
|
||||
q.updateAll()
|
||||
}
|
||||
if u.health == healthUpdateUnhealthy {
|
||||
q.unhealthyCloseAll()
|
||||
continue
|
||||
}
|
||||
if u.health == healthUpdateHealthy {
|
||||
q.setHealthy()
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -932,6 +957,30 @@ func (q *querier) updateAll() {
|
|||
}
|
||||
}
|
||||
|
||||
// unhealthyCloseAll marks the coordinator unhealthy and closes all connections. We do this so that clients and agents
|
||||
// are forced to reconnect to the coordinator, and will hopefully land on a healthy coordinator.
|
||||
func (q *querier) unhealthyCloseAll() {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
q.healthy = false
|
||||
for c := range q.conns {
|
||||
// close connections async so that we don't block the querier routine that responds to updates
|
||||
go func(c *connIO) {
|
||||
err := c.updates.Close()
|
||||
if err != nil {
|
||||
q.logger.Debug(q.ctx, "error closing conn while unhealthy", slog.Error(err))
|
||||
}
|
||||
}(c)
|
||||
// NOTE: we don't need to remove the connection from the map, as that will happen async in q.cleanupConn()
|
||||
}
|
||||
}
|
||||
|
||||
func (q *querier) setHealthy() {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
q.healthy = true
|
||||
}
|
||||
|
||||
func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAgent, map[uuid.UUID][]database.TailnetClient, error) {
|
||||
agents, err := q.store.GetAllTailnetAgents(ctx)
|
||||
if err != nil {
|
||||
|
@ -1078,6 +1127,28 @@ func (q *workQ[K]) done(key K) {
|
|||
q.cond.Signal()
|
||||
}
|
||||
|
||||
// filterUpdate is the part of an hbUpdate that says whether mappings need to
// be re-filtered.
type filterUpdate int

const (
	// filterUpdateNone is the zero value and means "no filter update".
	filterUpdateNone filterUpdate = 0
	// filterUpdateUpdated asks the querier to re-run all of its mappers.
	filterUpdateUpdated filterUpdate = 1
)
|
||||
|
||||
// healthUpdate is the part of an hbUpdate that reports a change in this
// coordinator's own health.
type healthUpdate int

const (
	// healthUpdateNone is the zero value and means "no health change".
	healthUpdateNone healthUpdate = 0
	// healthUpdateHealthy reports that the coordinator is healthy again.
	healthUpdateHealthy healthUpdate = 1
	// healthUpdateUnhealthy reports that the coordinator has become unhealthy.
	healthUpdateUnhealthy healthUpdate = 2
)
|
||||
|
||||
// hbUpdate is an update sent from the heartbeats to the querier. Zero values of the fields mean no update of that
|
||||
// kind.
|
||||
type hbUpdate struct {
|
||||
filter filterUpdate
|
||||
health healthUpdate
|
||||
}
|
||||
|
||||
// heartbeats sends heartbeats for this coordinator on a timer, and monitors heartbeats from other coordinators. If a
|
||||
// coordinator misses their heartbeat, we remove it from our map of "valid" coordinators, such that we will filter out
|
||||
// any mappings for it when filter() is called, and we send a signal on the update channel, which triggers all mappers
|
||||
|
@ -1089,8 +1160,9 @@ type heartbeats struct {
|
|||
store database.Store
|
||||
self uuid.UUID
|
||||
|
||||
update chan<- struct{}
|
||||
firstHeartbeat chan<- struct{}
|
||||
update chan<- hbUpdate
|
||||
firstHeartbeat chan<- struct{}
|
||||
failedHeartbeats int
|
||||
|
||||
lock sync.RWMutex
|
||||
coordinators map[uuid.UUID]time.Time
|
||||
|
@ -1103,7 +1175,7 @@ type heartbeats struct {
|
|||
func newHeartbeats(
|
||||
ctx context.Context, logger slog.Logger,
|
||||
ps pubsub.Pubsub, store database.Store,
|
||||
self uuid.UUID, update chan<- struct{},
|
||||
self uuid.UUID, update chan<- hbUpdate,
|
||||
firstHeartbeat chan<- struct{},
|
||||
) *heartbeats {
|
||||
h := &heartbeats{
|
||||
|
@ -1194,7 +1266,7 @@ func (h *heartbeats) recvBeat(id uuid.UUID) {
|
|||
h.logger.Info(h.ctx, "heartbeats (re)started", slog.F("other_coordinator_id", id))
|
||||
// send on a separate goroutine to avoid holding lock. Triggering update can be async
|
||||
go func() {
|
||||
_ = sendCtx(h.ctx, h.update, struct{}{})
|
||||
_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
|
||||
}()
|
||||
}
|
||||
h.coordinators[id] = time.Now()
|
||||
|
@ -1241,7 +1313,7 @@ func (h *heartbeats) checkExpiry() {
|
|||
if expired {
|
||||
// send on a separate goroutine to avoid holding lock. Triggering update can be async
|
||||
go func() {
|
||||
_ = sendCtx(h.ctx, h.update, struct{}{})
|
||||
_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
|
||||
}()
|
||||
}
|
||||
// we need to reset the timer for when the next oldest coordinator will expire, if any.
|
||||
|
@ -1269,11 +1341,20 @@ func (h *heartbeats) sendBeats() {
|
|||
func (h *heartbeats) sendBeat() {
|
||||
_, err := h.store.UpsertTailnetCoordinator(h.ctx, h.self)
|
||||
if err != nil {
|
||||
// just log errors, heartbeats are rescheduled on a timer
|
||||
h.logger.Error(h.ctx, "failed to send heartbeat", slog.Error(err))
|
||||
h.failedHeartbeats++
|
||||
if h.failedHeartbeats == 3 {
|
||||
h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy")
|
||||
_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
|
||||
}
|
||||
return
|
||||
}
|
||||
h.logger.Debug(h.ctx, "sent heartbeat")
|
||||
if h.failedHeartbeats >= 3 {
|
||||
h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy")
|
||||
_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
|
||||
}
|
||||
h.failedHeartbeats = 0
|
||||
}
|
||||
|
||||
func (h *heartbeats) sendDelete() {
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -21,7 +22,9 @@ import (
|
|||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbmock"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/enterprise/tailnet"
|
||||
agpl "github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/testutil"
|
||||
|
@ -36,11 +39,11 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
|
|||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
|
@ -75,11 +78,11 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
|
|||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
|
@ -112,11 +115,11 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
|
|||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
|
@ -189,11 +192,11 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
|
|||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
|
@ -276,14 +279,14 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
|
|||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
||||
mu := sync.Mutex{}
|
||||
heartbeats := []time.Time{}
|
||||
unsub, err := pubsub.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
|
||||
unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
|
||||
assert.NoError(t, err)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
@ -293,7 +296,7 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
|
|||
defer unsub()
|
||||
|
||||
start := time.Now()
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
|
@ -326,14 +329,14 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
|
|||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord2.Close()
|
||||
|
||||
|
@ -453,17 +456,17 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
|
|||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, pubsub := dbtestutil.NewDB(t)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord2.Close()
|
||||
coord3, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
|
||||
coord3, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord3.Close()
|
||||
|
||||
|
@ -516,6 +519,76 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
|
|||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinator_Unhealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mStore := dbmock.NewMockStore(ctrl)
|
||||
ps := pubsub.NewInMemory()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
calls := make(chan struct{})
|
||||
threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
|
||||
Times(3).
|
||||
Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
|
||||
Return(database.TailnetCoordinator{}, xerrors.New("test disconnect"))
|
||||
mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
|
||||
MinTimes(1).
|
||||
After(threeMissed).
|
||||
Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
|
||||
Return(database.TailnetCoordinator{}, nil)
|
||||
// extra calls we don't particularly care about for this test
|
||||
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
|
||||
mStore.EXPECT().GetTailnetClientsForAgent(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
|
||||
mStore.EXPECT().DeleteTailnetAgent(gomock.Any(), gomock.Any()).
|
||||
AnyTimes().Return(database.DeleteTailnetAgentRow{}, nil)
|
||||
mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
|
||||
|
||||
uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := uut.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
agent1 := newTestAgent(t, uut)
|
||||
defer agent1.close()
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
case calls <- struct{}{}:
|
||||
// OK
|
||||
}
|
||||
}
|
||||
// connected agent should be disconnected
|
||||
agent1.waitForClose(ctx, t)
|
||||
|
||||
// new agent should immediately disconnect
|
||||
agent2 := newTestAgent(t, uut)
|
||||
defer agent2.close()
|
||||
agent2.waitForClose(ctx, t)
|
||||
|
||||
// next heartbeats succeed, so we are healthy
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
case calls <- struct{}{}:
|
||||
// OK
|
||||
}
|
||||
}
|
||||
agent3 := newTestAgent(t, uut)
|
||||
defer agent3.close()
|
||||
select {
|
||||
case <-agent3.closeChan:
|
||||
t.Fatal("agent conn closed after we are healthy")
|
||||
case <-time.After(time.Second):
|
||||
// OK
|
||||
}
|
||||
}
|
||||
|
||||
type testConn struct {
|
||||
ws, serverWS net.Conn
|
||||
nodeChan chan []*agpl.Node
|
||||
|
|
Loading…
Reference in New Issue