fix: make PGCoordinator close connections when unhealthy (#9125)

Signed-off-by: Spike Curtis <spike@coder.com>
Spike Curtis 2023-08-17 09:36:47 +04:00 committed by GitHub
parent c217a0d819
commit c7a6d626b4
3 changed files with 183 additions and 28 deletions
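
In outline, the commit does two things: the heartbeats goroutine now counts consecutive failed UpsertTailnetCoordinator calls and, after three misses, sends the querier a health signal on the new hbUpdate channel; the querier reacts by marking itself unhealthy, closing every tracked connection, and rejecting new ones until a later heartbeat succeeds. The Go program below is a condensed, self-contained sketch of the querier-side behavior only, not the actual pgcoord.go code: Conn, health, and handleHealth are simplified stand-ins for connIO, hbUpdate, and handleUpdates.

package main

import (
	"fmt"
	"sync"
)

// health mirrors the healthUpdate values used in the diff below;
// Conn is a hypothetical stand-in for connIO.
type health int

const (
	healthHealthy health = iota + 1
	healthUnhealthy
)

type Conn struct{ id string }

func (c *Conn) Close() error {
	fmt.Println("closed", c.id)
	return nil
}

type querier struct {
	mu      sync.Mutex
	healthy bool
	conns   map[*Conn]struct{}
}

// newConn rejects connections outright while unhealthy, so callers
// immediately fall back to (hopefully) a healthy coordinator.
func (q *querier) newConn(c *Conn) {
	q.mu.Lock()
	defer q.mu.Unlock()
	if !q.healthy {
		_ = c.Close()
		return
	}
	q.conns[c] = struct{}{}
}

// handleHealth reacts to signals from the heartbeat loop. The real change
// closes each connection on its own goroutine so the update loop never
// blocks, and leaves map cleanup to cleanupConn; this sketch closes and
// deletes inline to keep the output deterministic.
func (q *querier) handleHealth(h health) {
	q.mu.Lock()
	defer q.mu.Unlock()
	switch h {
	case healthUnhealthy:
		q.healthy = false
		for c := range q.conns {
			_ = c.Close()
			delete(q.conns, c)
		}
	case healthHealthy:
		q.healthy = true
	}
}

func main() {
	q := &querier{healthy: true, conns: map[*Conn]struct{}{}}
	q.newConn(&Conn{id: "agent-1"})
	q.handleHealth(healthUnhealthy) // closes agent-1, stops accepting
	q.newConn(&Conn{id: "agent-2"}) // closed immediately
	q.handleHealth(healthHealthy)   // accepting again
	q.newConn(&Conn{id: "agent-3"}) // tracked normally
}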

.golangci.yaml

@ -211,6 +211,7 @@ issues:
run:
skip-dirs:
- node_modules
- .git
skip-files:
- scripts/rules.go
timeout: 10m

enterprise/tailnet/pgcoord.go

@ -586,10 +586,12 @@ type querier struct {
workQ *workQ[mKey]
heartbeats *heartbeats
updates <-chan struct{}
updates <-chan hbUpdate
mu sync.Mutex
mappers map[mKey]*countedMapper
conns map[*connIO]struct{}
healthy bool
}
type countedMapper struct {
@ -604,7 +606,7 @@ func newQuerier(
self uuid.UUID, newConnections chan *connIO, numWorkers int,
firstHeartbeat chan<- struct{},
) *querier {
updates := make(chan struct{})
updates := make(chan hbUpdate)
q := &querier{
ctx: ctx,
logger: logger.Named("querier"),
@ -614,7 +616,9 @@ func newQuerier(
workQ: newWorkQ[mKey](ctx),
heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat),
mappers: make(map[mKey]*countedMapper),
conns: make(map[*connIO]struct{}),
updates: updates,
healthy: true, // assume we start healthy
}
go q.subscribe()
go q.handleConnIO()
@ -639,6 +643,15 @@ func (q *querier) handleConnIO() {
func (q *querier) newConn(c *connIO) {
q.mu.Lock()
defer q.mu.Unlock()
if !q.healthy {
err := c.updates.Close()
q.logger.Info(q.ctx, "closed incoming connection while unhealthy",
slog.Error(err),
slog.F("agent_id", c.agent),
slog.F("client_id", c.client),
)
return
}
mk := mKey{
agent: c.agent,
// if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself
@ -661,6 +674,7 @@ func (q *querier) newConn(c *connIO) {
return
}
cm.count++
q.conns[c] = struct{}{}
go q.cleanupConn(c)
}
@ -668,6 +682,7 @@ func (q *querier) cleanupConn(c *connIO) {
<-c.ctx.Done()
q.mu.Lock()
defer q.mu.Unlock()
delete(q.conns, c)
mk := mKey{
agent: c.agent,
// if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself
@ -911,8 +926,18 @@ func (q *querier) handleUpdates() {
select {
case <-q.ctx.Done():
return
case <-q.updates:
q.updateAll()
case u := <-q.updates:
if u.filter == filterUpdateUpdated {
q.updateAll()
}
if u.health == healthUpdateUnhealthy {
q.unhealthyCloseAll()
continue
}
if u.health == healthUpdateHealthy {
q.setHealthy()
continue
}
}
}
}
@ -932,6 +957,30 @@ func (q *querier) updateAll() {
}
}
// unhealthyCloseAll marks the coordinator unhealthy and closes all connections. We do this so that clients and agents
// are forced to reconnect to the coordinator, and will hopefully land on a healthy coordinator.
func (q *querier) unhealthyCloseAll() {
q.mu.Lock()
defer q.mu.Unlock()
q.healthy = false
for c := range q.conns {
// close connections async so that we don't block the querier routine that responds to updates
go func(c *connIO) {
err := c.updates.Close()
if err != nil {
q.logger.Debug(q.ctx, "error closing conn while unhealthy", slog.Error(err))
}
}(c)
// NOTE: we don't need to remove the connection from the map, as that will happen async in q.cleanupConn()
}
}
func (q *querier) setHealthy() {
q.mu.Lock()
defer q.mu.Unlock()
q.healthy = true
}
func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAgent, map[uuid.UUID][]database.TailnetClient, error) {
agents, err := q.store.GetAllTailnetAgents(ctx)
if err != nil {
@ -1078,6 +1127,28 @@ func (q *workQ[K]) done(key K) {
q.cond.Signal()
}
type filterUpdate int
const (
filterUpdateNone filterUpdate = iota
filterUpdateUpdated
)
type healthUpdate int
const (
healthUpdateNone healthUpdate = iota
healthUpdateHealthy
healthUpdateUnhealthy
)
// hbUpdate is an update sent from the heartbeats to the querier. Zero values of the fields mean no update of that
// kind.
type hbUpdate struct {
filter filterUpdate
health healthUpdate
}
// heartbeats sends heartbeats for this coordinator on a timer, and monitors heartbeats from other coordinators. If a
// coordinator misses their heartbeat, we remove it from our map of "valid" coordinators, such that we will filter out
// any mappings for it when filter() is called, and we send a signal on the update channel, which triggers all mappers
@ -1089,8 +1160,9 @@ type heartbeats struct {
store database.Store
self uuid.UUID
update chan<- struct{}
firstHeartbeat chan<- struct{}
update chan<- hbUpdate
firstHeartbeat chan<- struct{}
failedHeartbeats int
lock sync.RWMutex
coordinators map[uuid.UUID]time.Time
@ -1103,7 +1175,7 @@ type heartbeats struct {
func newHeartbeats(
ctx context.Context, logger slog.Logger,
ps pubsub.Pubsub, store database.Store,
self uuid.UUID, update chan<- struct{},
self uuid.UUID, update chan<- hbUpdate,
firstHeartbeat chan<- struct{},
) *heartbeats {
h := &heartbeats{
@ -1194,7 +1266,7 @@ func (h *heartbeats) recvBeat(id uuid.UUID) {
h.logger.Info(h.ctx, "heartbeats (re)started", slog.F("other_coordinator_id", id))
// send on a separate goroutine to avoid holding lock. Triggering update can be async
go func() {
_ = sendCtx(h.ctx, h.update, struct{}{})
_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
}()
}
h.coordinators[id] = time.Now()
@ -1241,7 +1313,7 @@ func (h *heartbeats) checkExpiry() {
if expired {
// send on a separate goroutine to avoid holding lock. Triggering update can be async
go func() {
_ = sendCtx(h.ctx, h.update, struct{}{})
_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
}()
}
// we need to reset the timer for when the next oldest coordinator will expire, if any.
@ -1269,11 +1341,20 @@ func (h *heartbeats) sendBeats() {
func (h *heartbeats) sendBeat() {
_, err := h.store.UpsertTailnetCoordinator(h.ctx, h.self)
if err != nil {
// just log errors, heartbeats are rescheduled on a timer
h.logger.Error(h.ctx, "failed to send heartbeat", slog.Error(err))
h.failedHeartbeats++
if h.failedHeartbeats == 3 {
h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy")
_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
}
return
}
h.logger.Debug(h.ctx, "sent heartbeat")
if h.failedHeartbeats >= 3 {
h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy")
_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
}
h.failedHeartbeats = 0
}
func (h *heartbeats) sendDelete() {
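
The heartbeat side of the change is a consecutive-failure counter: every failed upsert increments it, the third consecutive failure emits an unhealthy signal exactly once, and the first success after three or more failures emits a healthy signal before the counter resets. Below is a minimal sketch of that pattern under assumed names (beater, beat, and signal are illustrative stand-ins for heartbeats, the store upsert, and sendCtx on the update channel):

package main

import (
	"errors"
	"fmt"
)

const failureThreshold = 3 // mirrors the hard-coded 3 in the diff above

type beater struct {
	failed int
	signal func(healthy bool) // stand-in for sendCtx(h.ctx, h.update, hbUpdate{...})
}

func (b *beater) sendBeat(beat func() error) {
	if err := beat(); err != nil {
		b.failed++
		if b.failed == failureThreshold {
			// only signal on the transition to unhealthy, not on every miss
			b.signal(false)
		}
		return
	}
	if b.failed >= failureThreshold {
		// we were unhealthy; tell the querier it can accept connections again
		b.signal(true)
	}
	b.failed = 0
}

func main() {
	b := &beater{signal: func(h bool) { fmt.Println("healthy:", h) }}
	fail := func() error { return errors.New("db down") }
	ok := func() error { return nil }
	for i := 0; i < 4; i++ {
		b.sendBeat(fail) // prints "healthy: false" on the 3rd failure only
	}
	b.sendBeat(ok) // prints "healthy: true" and resets the counter
}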

enterprise/tailnet/pgcoord_test.go

@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -21,7 +22,9 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbmock"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/enterprise/tailnet"
agpl "github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
@ -36,11 +39,11 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, pubsub := dbtestutil.NewDB(t)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
@ -75,11 +78,11 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, pubsub := dbtestutil.NewDB(t)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
@ -112,11 +115,11 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, pubsub := dbtestutil.NewDB(t)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
@ -189,11 +192,11 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, pubsub := dbtestutil.NewDB(t)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
@ -276,14 +279,14 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, pubsub := dbtestutil.NewDB(t)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
mu := sync.Mutex{}
heartbeats := []time.Time{}
unsub, err := pubsub.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
assert.NoError(t, err)
mu.Lock()
defer mu.Unlock()
@ -293,7 +296,7 @@ func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
defer unsub()
start := time.Now()
coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
@ -326,14 +329,14 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, pubsub := dbtestutil.NewDB(t)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coord1.Close()
coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coord2.Close()
@ -453,17 +456,17 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, pubsub := dbtestutil.NewDB(t)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coord1.Close()
coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coord2.Close()
coord3, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
coord3, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coord3.Close()
@ -516,6 +519,76 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
assertEventuallyNoAgents(ctx, t, store, agent1.id)
}
func TestPGCoordinator_Unhealthy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
ctrl := gomock.NewController(t)
mStore := dbmock.NewMockStore(ctrl)
ps := pubsub.NewInMemory()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
calls := make(chan struct{})
threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
Times(3).
Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
Return(database.TailnetCoordinator{}, xerrors.New("test disconnect"))
mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
MinTimes(1).
After(threeMissed).
Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
Return(database.TailnetCoordinator{}, nil)
// extra calls we don't particularly care about for this test
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().GetTailnetClientsForAgent(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
mStore.EXPECT().DeleteTailnetAgent(gomock.Any(), gomock.Any()).
AnyTimes().Return(database.DeleteTailnetAgentRow{}, nil)
mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
require.NoError(t, err)
defer func() {
err := uut.Close()
require.NoError(t, err)
}()
agent1 := newTestAgent(t, uut)
defer agent1.close()
for i := 0; i < 3; i++ {
select {
case <-ctx.Done():
t.Fatal("timeout")
case calls <- struct{}{}:
// OK
}
}
// connected agent should be disconnected
agent1.waitForClose(ctx, t)
// new agent should immediately disconnect
agent2 := newTestAgent(t, uut)
defer agent2.close()
agent2.waitForClose(ctx, t)
// next heartbeats succeed, so we are healthy
for i := 0; i < 2; i++ {
select {
case <-ctx.Done():
t.Fatal("timeout")
case calls <- struct{}{}:
// OK
}
}
agent3 := newTestAgent(t, uut)
defer agent3.close()
select {
case <-agent3.closeChan:
t.Fatal("agent conn closed after we are healthy")
case <-time.After(time.Second):
// OK
}
}
type testConn struct {
ws, serverWS net.Conn
nodeChan chan []*agpl.Node
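
TestPGCoordinator_Unhealthy keeps the heartbeat loop deterministic by having every mocked UpsertTailnetCoordinator call block on a receive from the calls channel, so the test decides exactly when the three failing beats and the later successful ones run. The same gating idea shown in isolation, with a hypothetical worker standing in for the coordinator's heartbeat loop:

package main

import (
	"fmt"
	"time"
)

// worker blocks on each attempt until the test releases it, mimicking the
// Do(func(...) { <-calls }) hook on the mocked store in the test above.
func worker(calls <-chan struct{}, results chan<- int) {
	for i := 0; ; i++ {
		<-calls // each attempt waits for the test to step it forward
		results <- i
	}
}

func main() {
	calls := make(chan struct{})
	results := make(chan int)
	go worker(calls, results)

	// release exactly three attempts, then observe them in order
	for i := 0; i < 3; i++ {
		calls <- struct{}{}
		fmt.Println("attempt", <-results)
	}

	// nothing else runs until the test says so
	select {
	case r := <-results:
		fmt.Println("unexpected attempt", r)
	case <-time.After(100 * time.Millisecond):
		fmt.Println("no further attempts, as expected")
	}
}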