mirror of https://github.com/coder/coder.git
fix: fix race in PGCoord at startup (#9144)
Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent
c0a78533bf
commit
2f46f2315c
|
@ -403,11 +403,15 @@ func (b *binder) writeOne(bnd binding) error {
|
|||
CoordinatorID: b.coordinatorID,
|
||||
Node: nodeRaw,
|
||||
})
|
||||
b.logger.Debug(b.ctx, "upserted agent binding",
|
||||
slog.F("agent_id", bnd.agent), slog.F("node", nodeRaw), slog.Error(err))
|
||||
case bnd.isAgent() && len(nodeRaw) == 0:
|
||||
_, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{
|
||||
ID: bnd.agent,
|
||||
CoordinatorID: b.coordinatorID,
|
||||
})
|
||||
b.logger.Debug(b.ctx, "deleted agent binding",
|
||||
slog.F("agent_id", bnd.agent), slog.Error(err))
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
// treat deletes as idempotent
|
||||
err = nil
|
||||
|
@ -419,11 +423,16 @@ func (b *binder) writeOne(bnd binding) error {
|
|||
AgentID: bnd.agent,
|
||||
Node: nodeRaw,
|
||||
})
|
||||
b.logger.Debug(b.ctx, "upserted client binding",
|
||||
slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client),
|
||||
slog.F("node", nodeRaw), slog.Error(err))
|
||||
case bnd.isClient() && len(nodeRaw) == 0:
|
||||
_, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{
|
||||
ID: bnd.client,
|
||||
CoordinatorID: b.coordinatorID,
|
||||
})
|
||||
b.logger.Debug(b.ctx, "deleted client binding",
|
||||
slog.F("agent_id", bnd.agent), slog.F("client_id", bnd.client), slog.Error(err))
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
// treat deletes as idempotent
|
||||
err = nil
|
||||
|
@ -620,7 +629,7 @@ func newQuerier(
|
|||
updates: updates,
|
||||
healthy: true, // assume we start healthy
|
||||
}
|
||||
go q.subscribe()
|
||||
q.subscribe()
|
||||
go q.handleConnIO()
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
go q.worker()
|
||||
|
@ -748,6 +757,8 @@ func (q *querier) query(mk mKey) error {
|
|||
|
||||
func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) {
|
||||
clients, err := q.store.GetTailnetClientsForAgent(q.ctx, agent)
|
||||
q.logger.Debug(q.ctx, "queried clients of agent",
|
||||
slog.F("agent_id", agent), slog.F("num_clients", len(clients)), slog.Error(err))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -772,6 +783,8 @@ func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) {
|
|||
|
||||
func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) {
|
||||
agents, err := q.store.GetTailnetAgents(q.ctx, agentID)
|
||||
q.logger.Debug(q.ctx, "queried agents",
|
||||
slog.F("agent_id", agentID), slog.F("num_agents", len(agents)), slog.Error(err))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -793,50 +806,62 @@ func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) {
|
|||
return mappings, nil
|
||||
}
|
||||
|
||||
// subscribe starts our subscriptions to client and agent updates in a new goroutine, and returns once we are subscribed
|
||||
// or the querier context is canceled.
|
||||
func (q *querier) subscribe() {
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0 // retry indefinitely
|
||||
eb.MaxInterval = dbMaxBackoff
|
||||
bkoff := backoff.WithContext(eb, q.ctx)
|
||||
var cancelClient context.CancelFunc
|
||||
err := backoff.Retry(func() error {
|
||||
cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient)
|
||||
subscribed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(subscribed)
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0 // retry indefinitely
|
||||
eb.MaxInterval = dbMaxBackoff
|
||||
bkoff := backoff.WithContext(eb, q.ctx)
|
||||
var cancelClient context.CancelFunc
|
||||
err := backoff.Retry(func() error {
|
||||
cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient)
|
||||
if err != nil {
|
||||
q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
cancelClient = cancelFn
|
||||
return nil
|
||||
}, bkoff)
|
||||
if err != nil {
|
||||
q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err))
|
||||
return err
|
||||
if q.ctx.Err() == nil {
|
||||
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
cancelClient = cancelFn
|
||||
return nil
|
||||
}, bkoff)
|
||||
if err != nil {
|
||||
if q.ctx.Err() == nil {
|
||||
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
defer cancelClient()
|
||||
bkoff.Reset()
|
||||
defer cancelClient()
|
||||
bkoff.Reset()
|
||||
q.logger.Debug(q.ctx, "subscribed to client updates")
|
||||
|
||||
var cancelAgent context.CancelFunc
|
||||
err = backoff.Retry(func() error {
|
||||
cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent)
|
||||
var cancelAgent context.CancelFunc
|
||||
err = backoff.Retry(func() error {
|
||||
cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent)
|
||||
if err != nil {
|
||||
q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
cancelAgent = cancelFn
|
||||
return nil
|
||||
}, bkoff)
|
||||
if err != nil {
|
||||
q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err))
|
||||
return err
|
||||
if q.ctx.Err() == nil {
|
||||
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
cancelAgent = cancelFn
|
||||
return nil
|
||||
}, bkoff)
|
||||
if err != nil {
|
||||
if q.ctx.Err() == nil {
|
||||
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
defer cancelAgent()
|
||||
defer cancelAgent()
|
||||
q.logger.Debug(q.ctx, "subscribed to agent updates")
|
||||
|
||||
// hold subscriptions open until context is canceled
|
||||
<-q.ctx.Done()
|
||||
// unblock the outer function from returning
|
||||
subscribed <- struct{}{}
|
||||
|
||||
// hold subscriptions open until context is canceled
|
||||
<-q.ctx.Done()
|
||||
}()
|
||||
<-subscribed
|
||||
}
|
||||
|
||||
func (q *querier) listenClient(_ context.Context, msg []byte, err error) {
|
||||
|
|
|
@ -86,7 +86,7 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
agent := newTestAgent(t, coordinator)
|
||||
agent := newTestAgent(t, coordinator, "agent")
|
||||
defer agent.close()
|
||||
agent.sendNode(&agpl.Node{PreferredDERP: 10})
|
||||
require.Eventually(t, func() bool {
|
||||
|
@ -123,7 +123,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
agent := newTestAgent(t, coordinator)
|
||||
agent := newTestAgent(t, coordinator, "original")
|
||||
defer agent.close()
|
||||
agent.sendNode(&agpl.Node{PreferredDERP: 10})
|
||||
|
||||
|
@ -151,7 +151,7 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
|
|||
agent.waitForClose(ctx, t)
|
||||
|
||||
// Create a new agent connection. This is to simulate a reconnect!
|
||||
agent = newTestAgent(t, coordinator, agent.id)
|
||||
agent = newTestAgent(t, coordinator, "reconnection", agent.id)
|
||||
// Ensure the existing listening connIO sends its node immediately!
|
||||
clientNodes = agent.recvNodes(ctx, t)
|
||||
require.Len(t, clientNodes, 1)
|
||||
|
@ -200,7 +200,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
agent := newTestAgent(t, coordinator)
|
||||
agent := newTestAgent(t, coordinator, "agent")
|
||||
defer agent.close()
|
||||
agent.sendNode(&agpl.Node{PreferredDERP: 10})
|
||||
|
||||
|
@ -333,16 +333,16 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord2.Close()
|
||||
|
||||
agent1 := newTestAgent(t, coord1)
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
agent2 := newTestAgent(t, coord2)
|
||||
agent2 := newTestAgent(t, coord2, "agent2")
|
||||
defer agent2.close()
|
||||
|
||||
client11 := newTestClient(t, coord1, agent1.id)
|
||||
|
@ -460,19 +460,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord2.Close()
|
||||
coord3, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord3.Close()
|
||||
|
||||
agent1 := newTestAgent(t, coord1)
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
agent2 := newTestAgent(t, coord2, agent1.id)
|
||||
agent2 := newTestAgent(t, coord2, "agent2", agent1.id)
|
||||
defer agent2.close()
|
||||
|
||||
client := newTestClient(t, coord3, agent1.id)
|
||||
|
@ -552,7 +552,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
|
|||
err := uut.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
agent1 := newTestAgent(t, uut)
|
||||
agent1 := newTestAgent(t, uut, "agent1")
|
||||
defer agent1.close()
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
|
@ -566,7 +566,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
|
|||
agent1.waitForClose(ctx, t)
|
||||
|
||||
// new agent should immediately disconnect
|
||||
agent2 := newTestAgent(t, uut)
|
||||
agent2 := newTestAgent(t, uut, "agent2")
|
||||
defer agent2.close()
|
||||
agent2.waitForClose(ctx, t)
|
||||
|
||||
|
@ -579,7 +579,7 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
|
|||
// OK
|
||||
}
|
||||
}
|
||||
agent3 := newTestAgent(t, uut)
|
||||
agent3 := newTestAgent(t, uut, "agent3")
|
||||
defer agent3.close()
|
||||
select {
|
||||
case <-agent3.closeChan:
|
||||
|
@ -618,10 +618,10 @@ func newTestConn(ids []uuid.UUID) *testConn {
|
|||
return a
|
||||
}
|
||||
|
||||
func newTestAgent(t *testing.T, coord agpl.Coordinator, id ...uuid.UUID) *testConn {
|
||||
func newTestAgent(t *testing.T, coord agpl.Coordinator, name string, id ...uuid.UUID) *testConn {
|
||||
a := newTestConn(id)
|
||||
go func() {
|
||||
err := coord.ServeAgent(a.serverWS, a.id, "")
|
||||
err := coord.ServeAgent(a.serverWS, a.id, name)
|
||||
assert.NoError(t, err)
|
||||
close(a.closeChan)
|
||||
}()
|
||||
|
@ -636,7 +636,7 @@ func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node {
|
|||
t.Helper()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout receiving nodes")
|
||||
t.Fatalf("testConn id %s: timeout receiving nodes ", c.id)
|
||||
return nil
|
||||
case nodes := <-c.nodeChan:
|
||||
return nodes
|
||||
|
|
Loading…
Reference in New Issue