fix: fix race in PGCoord at startup (#9144)

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis 2023-08-18 09:53:03 +04:00 committed by GitHub
parent c0a78533bf
commit 2f46f2315c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 57 deletions

View File

@ -403,11 +403,15 @@ func (b *binder) writeOne(bnd binding) error {
CoordinatorID: b.coordinatorID,
Node: nodeRaw,
})
b.logger.Debug(b.ctx, "upserted agent binding",
slog.F("agent_id", bnd.agent), slog.F("node", nodeRaw), slog.Error(err))
case bnd.isAgent() && len(nodeRaw) == 0:
_, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{
ID: bnd.agent,
CoordinatorID: b.coordinatorID,
})
b.logger.Debug(b.ctx, "deleted agent binding",
slog.F("agent_id", bnd.agent), slog.Error(err))
if xerrors.Is(err, sql.ErrNoRows) {
// treat deletes as idempotent
err = nil
@ -419,11 +423,16 @@ func (b *binder) writeOne(bnd binding) error {
AgentID: bnd.agent,
Node: nodeRaw,
})
b.logger.Debug(b.ctx, "upserted client binding",
slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client),
slog.F("node", nodeRaw), slog.Error(err))
case bnd.isClient() && len(nodeRaw) == 0:
_, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{
ID: bnd.client,
CoordinatorID: b.coordinatorID,
})
b.logger.Debug(b.ctx, "deleted client binding",
slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), slog.Error(err))
if xerrors.Is(err, sql.ErrNoRows) {
// treat deletes as idempotent
err = nil
@ -620,7 +629,7 @@ func newQuerier(
updates: updates,
healthy: true, // assume we start healthy
}
go q.subscribe()
q.subscribe()
go q.handleConnIO()
for i := 0; i < numWorkers; i++ {
go q.worker()
@ -748,6 +757,8 @@ func (q *querier) query(mk mKey) error {
func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) {
clients, err := q.store.GetTailnetClientsForAgent(q.ctx, agent)
q.logger.Debug(q.ctx, "queried clients of agent",
slog.F("agent_id", agent), slog.F("num_clients", len(clients)), slog.Error(err))
if err != nil {
return nil, err
}
@ -772,6 +783,8 @@ func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) {
func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) {
agents, err := q.store.GetTailnetAgents(q.ctx, agentID)
q.logger.Debug(q.ctx, "queried agents",
slog.F("agent_id", agentID), slog.F("num_agents", len(agents)), slog.Error(err))
if err != nil {
return nil, err
}
@ -793,50 +806,62 @@ func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) {
return mappings, nil
}
// subscribe starts our subscriptions to client and agent updates in a new goroutine, and returns once we are subscribed
// or the querier context is canceled.
func (q *querier) subscribe() {
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0 // retry indefinitely
eb.MaxInterval = dbMaxBackoff
bkoff := backoff.WithContext(eb, q.ctx)
var cancelClient context.CancelFunc
err := backoff.Retry(func() error {
cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient)
subscribed := make(chan struct{})
go func() {
defer close(subscribed)
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0 // retry indefinitely
eb.MaxInterval = dbMaxBackoff
bkoff := backoff.WithContext(eb, q.ctx)
var cancelClient context.CancelFunc
err := backoff.Retry(func() error {
cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient)
if err != nil {
q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err))
return err
}
cancelClient = cancelFn
return nil
}, bkoff)
if err != nil {
q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err))
return err
if q.ctx.Err() == nil {
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
}
return
}
cancelClient = cancelFn
return nil
}, bkoff)
if err != nil {
if q.ctx.Err() == nil {
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
}
return
}
defer cancelClient()
bkoff.Reset()
defer cancelClient()
bkoff.Reset()
q.logger.Debug(q.ctx, "subscribed to client updates")
var cancelAgent context.CancelFunc
err = backoff.Retry(func() error {
cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent)
var cancelAgent context.CancelFunc
err = backoff.Retry(func() error {
cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent)
if err != nil {
q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err))
return err
}
cancelAgent = cancelFn
return nil
}, bkoff)
if err != nil {
q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err))
return err
if q.ctx.Err() == nil {
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
}
return
}
cancelAgent = cancelFn
return nil
}, bkoff)
if err != nil {
if q.ctx.Err() == nil {
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
}
return
}
defer cancelAgent()
defer cancelAgent()
q.logger.Debug(q.ctx, "subscribed to agent updates")
// hold subscriptions open until context is canceled
<-q.ctx.Done()
// unblock the outer function from returning
subscribed <- struct{}{}
// hold subscriptions open until context is canceled
<-q.ctx.Done()
}()
<-subscribed
}
func (q *querier) listenClient(_ context.Context, msg []byte, err error) {

View File

@ -86,7 +86,7 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator)
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{PreferredDERP: 10})
require.Eventually(t, func() bool {
@ -123,7 +123,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator)
agent := newTestAgent(t, coordinator, "original")
defer agent.close()
agent.sendNode(&agpl.Node{PreferredDERP: 10})
@ -151,7 +151,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
agent.waitForClose(ctx, t)
// Create a new agent connection. This is to simulate a reconnect!
agent = newTestAgent(t, coordinator, agent.id)
agent = newTestAgent(t, coordinator, "reconnection", agent.id)
// Ensure the existing listening connIO sends its node immediately!
clientNodes = agent.recvNodes(ctx, t)
require.Len(t, clientNodes, 1)
@ -200,7 +200,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator)
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{PreferredDERP: 10})
@ -333,16 +333,16 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
require.NoError(t, err)
defer coord2.Close()
agent1 := newTestAgent(t, coord1)
agent1 := newTestAgent(t, coord1, "agent1")
defer agent1.close()
agent2 := newTestAgent(t, coord2)
agent2 := newTestAgent(t, coord2, "agent2")
defer agent2.close()
client11 := newTestClient(t, coord1, agent1.id)
@ -460,19 +460,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
require.NoError(t, err)
defer coord2.Close()
coord3, err := tailnet.NewPGCoord(ctx, logger, ps, store)
coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store)
require.NoError(t, err)
defer coord3.Close()
agent1 := newTestAgent(t, coord1)
agent1 := newTestAgent(t, coord1, "agent1")
defer agent1.close()
agent2 := newTestAgent(t, coord2, agent1.id)
agent2 := newTestAgent(t, coord2, "agent2", agent1.id)
defer agent2.close()
client := newTestClient(t, coord3, agent1.id)
@ -552,7 +552,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
err := uut.Close()
require.NoError(t, err)
}()
agent1 := newTestAgent(t, uut)
agent1 := newTestAgent(t, uut, "agent1")
defer agent1.close()
for i := 0; i < 3; i++ {
select {
@ -566,7 +566,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
agent1.waitForClose(ctx, t)
// new agent should immediately disconnect
agent2 := newTestAgent(t, uut)
agent2 := newTestAgent(t, uut, "agent2")
defer agent2.close()
agent2.waitForClose(ctx, t)
@ -579,7 +579,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
// OK
}
}
agent3 := newTestAgent(t, uut)
agent3 := newTestAgent(t, uut, "agent3")
defer agent3.close()
select {
case <-agent3.closeChan:
@ -618,10 +618,10 @@ func newTestConn(ids []uuid.UUID) *testConn {
return a
}
func newTestAgent(t *testing.T, coord agpl.Coordinator, id ...uuid.UUID) *testConn {
func newTestAgent(t *testing.T, coord agpl.Coordinator, name string, id ...uuid.UUID) *testConn {
a := newTestConn(id)
go func() {
err := coord.ServeAgent(a.serverWS, a.id, "")
err := coord.ServeAgent(a.serverWS, a.id, name)
assert.NoError(t, err)
close(a.closeChan)
}()
@ -636,7 +636,7 @@ func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node {
t.Helper()
select {
case <-ctx.Done():
t.Fatal("timeout receiving nodes")
t.Fatalf("testConn id %s: timeout receiving nodes ", c.id)
return nil
case nodes := <-c.nodeChan:
return nodes