mirror of https://github.com/coder/coder.git
fix: stop holding Pubsub mutex while calling pq.Listener (#12518)
Fixes #11950. https://github.com/coder/coder/issues/11950#issuecomment-1987756088 explains the bug. We were also calling into `Unlisten()` and `Close()` while holding the mutex. I don't believe that `Close()` depends on the notification loop being unblocked, but it's hard to be sure, and the safest thing to do is assume it could block. So, I added a unit test that fakes out `pq.Listener` and sends a bunch of notifications every time we call into it, to hopefully prevent regressions where we hold the mutex while calling into these functions. This commit also removes the use of a `context.Context` to stop the PubSub -- it must be explicitly closed via `Close()`. This simplifies a bunch of the logic, and is how we use the pubsub anyway.
This commit is contained in:
parent
6f00ccfa64
commit
51707446d0
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -164,16 +165,36 @@ func (q *msgQueue) dropped() {
|
|||
q.cond.Broadcast()
|
||||
}
|
||||
|
||||
// pqListener is an interface that represents a *pq.Listener for testing
|
||||
type pqListener interface {
|
||||
io.Closer
|
||||
Listen(string) error
|
||||
Unlisten(string) error
|
||||
NotifyChan() <-chan *pq.Notification
|
||||
}
|
||||
|
||||
type pqListenerShim struct {
|
||||
*pq.Listener
|
||||
}
|
||||
|
||||
func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
|
||||
return l.Notify
|
||||
}
|
||||
|
||||
// PGPubsub is a pubsub implementation using PostgreSQL.
|
||||
type PGPubsub struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger slog.Logger
|
||||
listenDone chan struct{}
|
||||
pgListener *pq.Listener
|
||||
db *sql.DB
|
||||
mut sync.Mutex
|
||||
queues map[string]map[uuid.UUID]*msgQueue
|
||||
logger slog.Logger
|
||||
listenDone chan struct{}
|
||||
pgListener pqListener
|
||||
db *sql.DB
|
||||
|
||||
qMu sync.Mutex
|
||||
queues map[string]map[uuid.UUID]*msgQueue
|
||||
|
||||
// making the close state its own mutex domain simplifies closing logic so
|
||||
// that we don't have to hold the qMu --- which could block processing
|
||||
// notifications while the pqListener is closing.
|
||||
closeMu sync.Mutex
|
||||
closedListener bool
|
||||
closeListenerErr error
|
||||
|
||||
|
@ -192,16 +213,14 @@ const BufferSize = 2048
|
|||
|
||||
// Subscribe calls the listener when an event matching the name is received.
|
||||
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil))
|
||||
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
|
||||
}
|
||||
|
||||
func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
|
||||
return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener))
|
||||
return p.subscribeQueue(event, newMsgQueue(context.Background(), nil, listener))
|
||||
}
|
||||
|
||||
func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) {
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
// if we hit an error, we need to close the queue so we don't
|
||||
|
@ -213,9 +232,13 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
|
|||
}
|
||||
}()
|
||||
|
||||
// The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches
|
||||
// notifies. We need to avoid holding the mutex while this happens, since holding the mutex
|
||||
// blocks reading notifications and can deadlock the pgListener.
|
||||
// c.f. https://github.com/coder/coder/issues/11950
|
||||
err = p.pgListener.Listen(event)
|
||||
if err == nil {
|
||||
p.logger.Debug(p.ctx, "started listening to event channel", slog.F("event", event))
|
||||
p.logger.Debug(context.Background(), "started listening to event channel", slog.F("event", event))
|
||||
}
|
||||
if errors.Is(err, pq.ErrChannelAlreadyOpen) {
|
||||
// It's ok if it's already open!
|
||||
|
@ -224,6 +247,8 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen: %w", err)
|
||||
}
|
||||
p.qMu.Lock()
|
||||
defer p.qMu.Unlock()
|
||||
|
||||
var eventQs map[uuid.UUID]*msgQueue
|
||||
var ok bool
|
||||
|
@ -234,30 +259,36 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
|
|||
id := uuid.New()
|
||||
eventQs[id] = newQ
|
||||
return func() {
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
p.qMu.Lock()
|
||||
listeners := p.queues[event]
|
||||
q := listeners[id]
|
||||
q.close()
|
||||
delete(listeners, id)
|
||||
if len(listeners) == 0 {
|
||||
delete(p.queues, event)
|
||||
}
|
||||
p.qMu.Unlock()
|
||||
// as above, we must not hold the lock while calling into pgListener
|
||||
|
||||
if len(listeners) == 0 {
|
||||
uErr := p.pgListener.Unlisten(event)
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
if uErr != nil && !p.closedListener {
|
||||
p.logger.Warn(p.ctx, "failed to unlisten", slog.Error(uErr), slog.F("event", event))
|
||||
p.logger.Warn(context.Background(), "failed to unlisten", slog.Error(uErr), slog.F("event", event))
|
||||
} else {
|
||||
p.logger.Debug(p.ctx, "stopped listening to event channel", slog.F("event", event))
|
||||
p.logger.Debug(context.Background(), "stopped listening to event channel", slog.F("event", event))
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PGPubsub) Publish(event string, message []byte) error {
|
||||
p.logger.Debug(p.ctx, "publish", slog.F("event", event), slog.F("message_len", len(message)))
|
||||
p.logger.Debug(context.Background(), "publish", slog.F("event", event), slog.F("message_len", len(message)))
|
||||
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
|
||||
// support the first parameter being a prepared statement.
|
||||
//nolint:gosec
|
||||
_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
|
||||
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
|
||||
if err != nil {
|
||||
p.publishesTotal.WithLabelValues("false").Inc()
|
||||
return xerrors.Errorf("exec pg_notify: %w", err)
|
||||
|
@ -269,53 +300,38 @@ func (p *PGPubsub) Publish(event string, message []byte) error {
|
|||
|
||||
// Close closes the pubsub instance.
|
||||
func (p *PGPubsub) Close() error {
|
||||
p.logger.Info(p.ctx, "pubsub is closing")
|
||||
p.cancel()
|
||||
p.logger.Info(context.Background(), "pubsub is closing")
|
||||
err := p.closeListener()
|
||||
<-p.listenDone
|
||||
p.logger.Debug(p.ctx, "pubsub closed")
|
||||
p.logger.Debug(context.Background(), "pubsub closed")
|
||||
return err
|
||||
}
|
||||
|
||||
// closeListener closes the pgListener, unless it has already been closed.
|
||||
func (p *PGPubsub) closeListener() error {
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
if p.closedListener {
|
||||
return p.closeListenerErr
|
||||
}
|
||||
p.closeListenerErr = p.pgListener.Close()
|
||||
p.closedListener = true
|
||||
p.closeListenerErr = p.pgListener.Close()
|
||||
|
||||
return p.closeListenerErr
|
||||
}
|
||||
|
||||
// listen begins receiving messages on the pq listener.
|
||||
func (p *PGPubsub) listen() {
|
||||
defer func() {
|
||||
p.logger.Info(p.ctx, "pubsub listen stopped receiving notify")
|
||||
cErr := p.closeListener()
|
||||
if cErr != nil {
|
||||
p.logger.Error(p.ctx, "failed to close listener")
|
||||
}
|
||||
p.logger.Info(context.Background(), "pubsub listen stopped receiving notify")
|
||||
close(p.listenDone)
|
||||
}()
|
||||
|
||||
var (
|
||||
notif *pq.Notification
|
||||
ok bool
|
||||
)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
case notif, ok = <-p.pgListener.Notify:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
notify := p.pgListener.NotifyChan()
|
||||
for notif := range notify {
|
||||
// A nil notification can be dispatched on reconnect.
|
||||
if notif == nil {
|
||||
p.logger.Debug(p.ctx, "notifying subscribers of a reconnection")
|
||||
p.logger.Debug(context.Background(), "notifying subscribers of a reconnection")
|
||||
p.recordReconnect()
|
||||
continue
|
||||
}
|
||||
|
@ -331,8 +347,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
|
|||
p.messagesTotal.WithLabelValues(sizeLabel).Inc()
|
||||
p.receivedBytesTotal.Add(float64(len(notif.Extra)))
|
||||
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
p.qMu.Lock()
|
||||
defer p.qMu.Unlock()
|
||||
queues, ok := p.queues[notif.Channel]
|
||||
if !ok {
|
||||
return
|
||||
|
@ -344,8 +360,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
|
|||
}
|
||||
|
||||
func (p *PGPubsub) recordReconnect() {
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
p.qMu.Lock()
|
||||
defer p.qMu.Unlock()
|
||||
for _, listeners := range p.queues {
|
||||
for _, q := range listeners {
|
||||
q.dropped()
|
||||
|
@ -409,30 +425,32 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
|||
d: net.Dialer{},
|
||||
}
|
||||
)
|
||||
p.pgListener = pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
|
||||
switch t {
|
||||
case pq.ListenerEventConnected:
|
||||
p.logger.Info(ctx, "pubsub connected to postgres")
|
||||
p.connected.Set(1.0)
|
||||
case pq.ListenerEventDisconnected:
|
||||
p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err))
|
||||
p.connected.Set(0)
|
||||
case pq.ListenerEventReconnected:
|
||||
p.logger.Info(ctx, "pubsub reconnected to postgres")
|
||||
p.connected.Set(1)
|
||||
case pq.ListenerEventConnectionAttemptFailed:
|
||||
p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err))
|
||||
}
|
||||
// This callback gets events whenever the connection state changes.
|
||||
// Don't send if the errChannel has already been closed.
|
||||
select {
|
||||
case <-errCh:
|
||||
return
|
||||
default:
|
||||
errCh <- err
|
||||
close(errCh)
|
||||
}
|
||||
})
|
||||
p.pgListener = pqListenerShim{
|
||||
Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
|
||||
switch t {
|
||||
case pq.ListenerEventConnected:
|
||||
p.logger.Info(ctx, "pubsub connected to postgres")
|
||||
p.connected.Set(1.0)
|
||||
case pq.ListenerEventDisconnected:
|
||||
p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err))
|
||||
p.connected.Set(0)
|
||||
case pq.ListenerEventReconnected:
|
||||
p.logger.Info(ctx, "pubsub reconnected to postgres")
|
||||
p.connected.Set(1)
|
||||
case pq.ListenerEventConnectionAttemptFailed:
|
||||
p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err))
|
||||
}
|
||||
// This callback gets events whenever the connection state changes.
|
||||
// Don't send if the errChannel has already been closed.
|
||||
select {
|
||||
case <-errCh:
|
||||
return
|
||||
default:
|
||||
errCh <- err
|
||||
close(errCh)
|
||||
}
|
||||
}),
|
||||
}
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
|
@ -501,24 +519,31 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
|
|||
p.connected.Collect(metrics)
|
||||
|
||||
// implicit metrics
|
||||
p.mut.Lock()
|
||||
p.qMu.Lock()
|
||||
events := len(p.queues)
|
||||
subs := 0
|
||||
for _, subscriberMap := range p.queues {
|
||||
subs += len(subscriberMap)
|
||||
}
|
||||
p.mut.Unlock()
|
||||
p.qMu.Unlock()
|
||||
metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
|
||||
metrics <- prometheus.MustNewConstMetric(currentEventsDesc, prometheus.GaugeValue, float64(events))
|
||||
}
|
||||
|
||||
// New creates a new Pubsub implementation using a PostgreSQL connection.
|
||||
func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) {
|
||||
// Start a new context that will be canceled when the pubsub is closed.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p := &PGPubsub{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
p := newWithoutListener(logger, database)
|
||||
if err := p.startListener(startCtx, connectURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go p.listen()
|
||||
logger.Info(startCtx, "pubsub has started")
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// newWithoutListener creates a new PGPubsub without creating the pqListener.
|
||||
func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub {
|
||||
return &PGPubsub{
|
||||
logger: logger,
|
||||
listenDone: make(chan struct{}),
|
||||
db: database,
|
||||
|
@ -567,10 +592,4 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect
|
|||
Help: "Whether we are connected (1) or not connected (0) to postgres",
|
||||
}),
|
||||
}
|
||||
if err := p.startListener(startCtx, connectURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go p.listen()
|
||||
logger.Info(ctx, "pubsub has started")
|
||||
return p, nil
|
||||
}
|
||||
|
|
|
@ -3,10 +3,15 @@ package pubsub
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
|
@ -138,3 +143,115 @@ func Test_msgQueue_Full(t *testing.T) {
|
|||
// for the error, so we read 2 less than we sent.
|
||||
require.Equal(t, BufferSize, n)
|
||||
}
|
||||
|
||||
func TestPubSub_DoesntBlockNotify(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
||||
uut := newWithoutListener(logger, nil)
|
||||
fListener := newFakePqListener()
|
||||
uut.pgListener = fListener
|
||||
go uut.listen()
|
||||
|
||||
cancels := make(chan func())
|
||||
go func() {
|
||||
subCancel, err := uut.Subscribe("bagels", func(ctx context.Context, message []byte) {
|
||||
t.Logf("got message: %s", string(message))
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
cancels <- subCancel
|
||||
}()
|
||||
subCancel := testutil.RequireRecvCtx(ctx, t, cancels)
|
||||
cancelDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(cancelDone)
|
||||
subCancel()
|
||||
}()
|
||||
testutil.RequireRecvCtx(ctx, t, cancelDone)
|
||||
|
||||
closeErrs := make(chan error)
|
||||
go func() {
|
||||
closeErrs <- uut.Close()
|
||||
}()
|
||||
err := testutil.RequireRecvCtx(ctx, t, closeErrs)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
const (
|
||||
numNotifications = 5
|
||||
testMessage = "birds of a feather"
|
||||
)
|
||||
|
||||
// fakePqListener is a fake version of pq.Listener. This test code tests for regressions of
|
||||
// https://github.com/coder/coder/issues/11950 where pq.Listener deadlocked because we blocked the
|
||||
// PGPubsub.listen() goroutine while calling other pq.Listener functions. So, all function calls
|
||||
// into the fakePqListener will send 5 notifications before returning to ensure the listen()
|
||||
// goroutine is unblocked.
|
||||
type fakePqListener struct {
|
||||
mu sync.Mutex
|
||||
channels map[string]struct{}
|
||||
notify chan *pq.Notification
|
||||
}
|
||||
|
||||
func (f *fakePqListener) Close() error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
ch := f.getTestChanLocked()
|
||||
for i := 0; i < numNotifications; i++ {
|
||||
f.notify <- &pq.Notification{Channel: ch, Extra: testMessage}
|
||||
}
|
||||
// note that the realPqListener must only be closed once, so go ahead and
|
||||
// close the notify unprotected here. If it panics, we have a bug.
|
||||
close(f.notify)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakePqListener) Listen(s string) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
ch := f.getTestChanLocked()
|
||||
for i := 0; i < numNotifications; i++ {
|
||||
f.notify <- &pq.Notification{Channel: ch, Extra: testMessage}
|
||||
}
|
||||
if _, ok := f.channels[s]; ok {
|
||||
return pq.ErrChannelAlreadyOpen
|
||||
}
|
||||
f.channels[s] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakePqListener) Unlisten(s string) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
ch := f.getTestChanLocked()
|
||||
for i := 0; i < numNotifications; i++ {
|
||||
f.notify <- &pq.Notification{Channel: ch, Extra: testMessage}
|
||||
}
|
||||
if _, ok := f.channels[s]; ok {
|
||||
delete(f.channels, s)
|
||||
return nil
|
||||
}
|
||||
return pq.ErrChannelNotOpen
|
||||
}
|
||||
|
||||
func (f *fakePqListener) NotifyChan() <-chan *pq.Notification {
|
||||
return f.notify
|
||||
}
|
||||
|
||||
// getTestChanLocked returns the name of a channel we are currently listening for, if there is one.
|
||||
// Otherwise, it just returns "test". We prefer to send test notifications for channels that appear
|
||||
// in the tests, but if there are none, just return anything.
|
||||
func (f *fakePqListener) getTestChanLocked() string {
|
||||
for c := range f.channels {
|
||||
return c
|
||||
}
|
||||
return "test"
|
||||
}
|
||||
|
||||
func newFakePqListener() *fakePqListener {
|
||||
return &fakePqListener{
|
||||
channels: make(map[string]struct{}),
|
||||
notify: make(chan *pq.Notification),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -109,60 +109,6 @@ func TestPubsub(t *testing.T) {
|
|||
message := <-messageChannel
|
||||
assert.Equal(t, string(message), data)
|
||||
})
|
||||
|
||||
t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
connectionURL, closePg, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
defer closePg()
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := pubsub.New(ctx, logger, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
|
||||
event := "test"
|
||||
done := make(chan struct{})
|
||||
called := make(chan struct{})
|
||||
unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) {
|
||||
defer close(done)
|
||||
select {
|
||||
case <-subCtx.Done():
|
||||
assert.Fail(t, "context should not be canceled")
|
||||
default:
|
||||
}
|
||||
close(called)
|
||||
select {
|
||||
case <-subCtx.Done():
|
||||
case <-ctx.Done():
|
||||
assert.Fail(t, "timeout waiting for sub context to be canceled")
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer unsub()
|
||||
|
||||
go func() {
|
||||
err := pubsub.Publish(event, nil)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-called:
|
||||
case <-ctx.Done():
|
||||
require.Fail(t, "timeout waiting for handler to be called")
|
||||
}
|
||||
err = pubsub.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
require.Fail(t, "timeout waiting for handler to finish")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPubsub_ordering(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue