fix: stop holding Pubsub mutex while calling pq.Listener (#12518)

fixes #11950

https://github.com/coder/coder/issues/11950#issuecomment-1987756088 explains the bug

We were also calling into `Unlisten()` and `Close()` while holding the mutex.  I don't believe that `Close()` depends on the notification loop being unblocked, but it's hard to be sure, and the safest thing to do is assume it could block.
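
To make the failure mode concrete, here is a minimal, self-contained sketch (all names invented -- this is neither coder nor lib/pq code). `pq.Listener` serves `Listen()` requests and notification delivery from the same mainloop, so a `LISTEN` acknowledgement queued behind an undelivered notification cannot be processed until someone reads from `Notify`. If the goroutine that reads `Notify` needs a mutex held by the caller of `Listen()`, nothing can make progress:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// wireMsg models a message from the server: either a notification or
// the acknowledgement of a LISTEN command.
type wireMsg struct{ isListenAck bool }

type fakeListener struct {
	wire      chan wireMsg
	Notify    chan string
	listenAck chan struct{}
}

// mainloop dispatches wire messages strictly in order, like pq.Listener's
// connection goroutine: an ack behind a notification is not processed
// until that notification has been delivered on Notify.
func (l *fakeListener) mainloop() {
	for msg := range l.wire {
		if msg.isListenAck {
			l.listenAck <- struct{}{}
		} else {
			l.Notify <- "event" // blocks until someone reads Notify
		}
	}
}

// Listen issues a LISTEN and waits for its acknowledgement.
func (l *fakeListener) Listen() {
	l.wire <- wireMsg{isListenAck: true}
	<-l.listenAck
}

func main() {
	l := &fakeListener{
		wire:      make(chan wireMsg, 16),
		Notify:    make(chan string),
		listenAck: make(chan struct{}),
	}
	go l.mainloop()

	var mu sync.Mutex

	// The notification consumer, like PGPubsub.listen(): it takes mu
	// to dispatch every message.
	go func() {
		for range l.Notify {
			mu.Lock()
			mu.Unlock()
		}
	}()

	// Like the old subscribeQueue: take the mutex, then call Listen().
	mu.Lock()
	l.wire <- wireMsg{} // notification 1: consumer gets it, then blocks on mu
	l.wire <- wireMsg{} // notification 2: mainloop blocks delivering it
	time.Sleep(100 * time.Millisecond)
	l.Listen() // ack is queued behind notification 2 -> blocks forever
	mu.Unlock()
	fmt.Println("unreachable; the Go runtime reports the deadlock instead")
}
```

The fix is the same in the sketch as in this PR: call `Listen()` (and `Unlisten()`/`Close()`) without holding the mutex that the notification consumer needs.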

So, I added a unit test that fakes out `pq.Listener` and sends a bunch of notifies every time we call into it, to guard against a regression where we hold the mutex while calling into these functions.

It also removes the use of a `context.Context` to stop the pubsub -- it must be explicitly stopped with `Close()`. This simplifies a bunch of the logic, and matches how we use the pubsub anyway.
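
For reference, the caller-side contract after this change looks like the following sketch (hypothetical caller code; the `pubsub` import path is assumed):

```go
package example

import (
	"context"
	"database/sql"

	"cdr.dev/slog"

	"github.com/coder/coder/v2/coderd/database/pubsub"
)

// runPubsub sketches a caller: startCtx only governs startup. After New
// returns, canceling it has no effect; Close() is the only way to stop
// the pubsub.
func runPubsub(startCtx context.Context, logger slog.Logger, db *sql.DB, connectURL string) error {
	ps, err := pubsub.New(startCtx, logger, db, connectURL)
	if err != nil {
		return err
	}
	defer ps.Close()

	cancel, err := ps.Subscribe("example-event", func(_ context.Context, msg []byte) {
		logger.Info(context.Background(), "received", slog.F("len", len(msg)))
	})
	if err != nil {
		return err
	}
	defer cancel()
	// ... run until done; the subscription ends at cancel() or Close() ...
	return nil
}
```
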
Spike Curtis 2024-03-12 09:44:12 +04:00, committed by GitHub
parent 6f00ccfa64
commit 51707446d0
3 changed files with 221 additions and 139 deletions

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"io"
"net"
"sync"
"time"
@@ -164,16 +165,36 @@ func (q *msgQueue) dropped() {
q.cond.Broadcast()
}
// pqListener is an interface that represents a *pq.Listener for testing
type pqListener interface {
io.Closer
Listen(string) error
Unlisten(string) error
NotifyChan() <-chan *pq.Notification
}
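// pqListenerShim adapts *pq.Listener to pqListener: pq.Listener exposes
// Notify as a struct field rather than a method, so it cannot satisfy
// the interface on its own.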
type pqListenerShim struct {
*pq.Listener
}
func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
return l.Notify
}
// PGPubsub is a pubsub implementation using PostgreSQL.
type PGPubsub struct {
ctx context.Context
cancel context.CancelFunc
logger slog.Logger
listenDone chan struct{}
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
queues map[string]map[uuid.UUID]*msgQueue
logger slog.Logger
listenDone chan struct{}
pgListener pqListener
db *sql.DB
qMu sync.Mutex
queues map[string]map[uuid.UUID]*msgQueue
// making the close state its own mutex domain simplifies closing logic so
// that we don't have to hold the qMu --- which could block processing
// notifications while the pqListener is closing.
closeMu sync.Mutex
closedListener bool
closeListenerErr error
@@ -192,16 +213,14 @@ const BufferSize = 2048
// Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil))
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
}
func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener))
return p.subscribeQueue(event, newMsgQueue(context.Background(), nil, listener))
}
func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) {
p.mut.Lock()
defer p.mut.Unlock()
defer func() {
if err != nil {
// if we hit an error, we need to close the queue so we don't
@@ -213,9 +232,13 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
}
}()
// The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches
// notifies. We need to avoid holding the mutex while this happens, since holding the mutex
// blocks reading notifications and can deadlock the pgListener.
// c.f. https://github.com/coder/coder/issues/11950
err = p.pgListener.Listen(event)
if err == nil {
p.logger.Debug(p.ctx, "started listening to event channel", slog.F("event", event))
p.logger.Debug(context.Background(), "started listening to event channel", slog.F("event", event))
}
if errors.Is(err, pq.ErrChannelAlreadyOpen) {
// It's ok if it's already open!
@@ -224,6 +247,8 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
if err != nil {
return nil, xerrors.Errorf("listen: %w", err)
}
p.qMu.Lock()
defer p.qMu.Unlock()
var eventQs map[uuid.UUID]*msgQueue
var ok bool
@@ -234,30 +259,36 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
id := uuid.New()
eventQs[id] = newQ
return func() {
p.mut.Lock()
defer p.mut.Unlock()
p.qMu.Lock()
listeners := p.queues[event]
q := listeners[id]
q.close()
delete(listeners, id)
if len(listeners) == 0 {
delete(p.queues, event)
}
p.qMu.Unlock()
// as above, we must not hold the lock while calling into pgListener
if len(listeners) == 0 {
uErr := p.pgListener.Unlisten(event)
p.closeMu.Lock()
defer p.closeMu.Unlock()
if uErr != nil && !p.closedListener {
p.logger.Warn(p.ctx, "failed to unlisten", slog.Error(uErr), slog.F("event", event))
p.logger.Warn(context.Background(), "failed to unlisten", slog.Error(uErr), slog.F("event", event))
} else {
p.logger.Debug(p.ctx, "stopped listening to event channel", slog.F("event", event))
p.logger.Debug(context.Background(), "stopped listening to event channel", slog.F("event", event))
}
}
}, nil
}
func (p *PGPubsub) Publish(event string, message []byte) error {
p.logger.Debug(p.ctx, "publish", slog.F("event", event), slog.F("message_len", len(message)))
p.logger.Debug(context.Background(), "publish", slog.F("event", event), slog.F("message_len", len(message)))
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
// support the first parameter being a prepared statement.
//nolint:gosec
_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
if err != nil {
p.publishesTotal.WithLabelValues("false").Inc()
return xerrors.Errorf("exec pg_notify: %w", err)
@@ -269,53 +300,38 @@ func (p *PGPubsub) Publish(event string, message []byte) error {
// Close closes the pubsub instance.
func (p *PGPubsub) Close() error {
p.logger.Info(p.ctx, "pubsub is closing")
p.cancel()
p.logger.Info(context.Background(), "pubsub is closing")
err := p.closeListener()
<-p.listenDone
p.logger.Debug(p.ctx, "pubsub closed")
p.logger.Debug(context.Background(), "pubsub closed")
return err
}
// closeListener closes the pgListener, unless it has already been closed.
func (p *PGPubsub) closeListener() error {
p.mut.Lock()
defer p.mut.Unlock()
p.closeMu.Lock()
defer p.closeMu.Unlock()
if p.closedListener {
return p.closeListenerErr
}
p.closeListenerErr = p.pgListener.Close()
p.closedListener = true
p.closeListenerErr = p.pgListener.Close()
return p.closeListenerErr
}
// listen begins receiving messages on the pq listener.
func (p *PGPubsub) listen() {
defer func() {
p.logger.Info(p.ctx, "pubsub listen stopped receiving notify")
cErr := p.closeListener()
if cErr != nil {
p.logger.Error(p.ctx, "failed to close listener")
}
p.logger.Info(context.Background(), "pubsub listen stopped receiving notify")
close(p.listenDone)
}()
var (
notif *pq.Notification
ok bool
)
for {
select {
case <-p.ctx.Done():
return
case notif, ok = <-p.pgListener.Notify:
if !ok {
return
}
}
notify := p.pgListener.NotifyChan()
for notif := range notify {
// A nil notification can be dispatched on reconnect.
if notif == nil {
p.logger.Debug(p.ctx, "notifying subscribers of a reconnection")
p.logger.Debug(context.Background(), "notifying subscribers of a reconnection")
p.recordReconnect()
continue
}
@@ -331,8 +347,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
p.messagesTotal.WithLabelValues(sizeLabel).Inc()
p.receivedBytesTotal.Add(float64(len(notif.Extra)))
p.mut.Lock()
defer p.mut.Unlock()
p.qMu.Lock()
defer p.qMu.Unlock()
queues, ok := p.queues[notif.Channel]
if !ok {
return
@@ -344,8 +360,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
}
func (p *PGPubsub) recordReconnect() {
p.mut.Lock()
defer p.mut.Unlock()
p.qMu.Lock()
defer p.qMu.Unlock()
for _, listeners := range p.queues {
for _, q := range listeners {
q.dropped()
@@ -409,30 +425,32 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
d: net.Dialer{},
}
)
p.pgListener = pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
switch t {
case pq.ListenerEventConnected:
p.logger.Info(ctx, "pubsub connected to postgres")
p.connected.Set(1.0)
case pq.ListenerEventDisconnected:
p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err))
p.connected.Set(0)
case pq.ListenerEventReconnected:
p.logger.Info(ctx, "pubsub reconnected to postgres")
p.connected.Set(1)
case pq.ListenerEventConnectionAttemptFailed:
p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err))
}
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
case <-errCh:
return
default:
errCh <- err
close(errCh)
}
})
p.pgListener = pqListenerShim{
Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
switch t {
case pq.ListenerEventConnected:
p.logger.Info(ctx, "pubsub connected to postgres")
p.connected.Set(1.0)
case pq.ListenerEventDisconnected:
p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err))
p.connected.Set(0)
case pq.ListenerEventReconnected:
p.logger.Info(ctx, "pubsub reconnected to postgres")
p.connected.Set(1)
case pq.ListenerEventConnectionAttemptFailed:
p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err))
}
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
case <-errCh:
return
default:
errCh <- err
close(errCh)
}
}),
}
select {
case err := <-errCh:
if err != nil {
@@ -501,24 +519,31 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
p.connected.Collect(metrics)
// implicit metrics
p.mut.Lock()
p.qMu.Lock()
events := len(p.queues)
subs := 0
for _, subscriberMap := range p.queues {
subs += len(subscriberMap)
}
p.mut.Unlock()
p.qMu.Unlock()
metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
metrics <- prometheus.MustNewConstMetric(currentEventsDesc, prometheus.GaugeValue, float64(events))
}
// New creates a new Pubsub implementation using a PostgreSQL connection.
func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) {
// Start a new context that will be canceled when the pubsub is closed.
ctx, cancel := context.WithCancel(context.Background())
p := &PGPubsub{
ctx: ctx,
cancel: cancel,
p := newWithoutListener(logger, database)
if err := p.startListener(startCtx, connectURL); err != nil {
return nil, err
}
go p.listen()
logger.Info(startCtx, "pubsub has started")
return p, nil
}
// newWithoutListener creates a new PGPubsub without creating the pqListener.
func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub {
return &PGPubsub{
logger: logger,
listenDone: make(chan struct{}),
db: database,
@@ -567,10 +592,4 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect
Help: "Whether we are connected (1) or not connected (0) to postgres",
}),
}
if err := p.startListener(startCtx, connectURL); err != nil {
return nil, err
}
go p.listen()
logger.Info(ctx, "pubsub has started")
return p, nil
}

@@ -3,10 +3,15 @@ package pubsub
import (
"context"
"fmt"
"sync"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"
)
@@ -138,3 +143,115 @@ func Test_msgQueue_Full(t *testing.T) {
// for the error, so we read 2 less than we sent.
require.Equal(t, BufferSize, n)
}
func TestPubSub_DoesntBlockNotify(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut := newWithoutListener(logger, nil)
fListener := newFakePqListener()
uut.pgListener = fListener
go uut.listen()
cancels := make(chan func())
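// Subscribe round-trips through the fake pqListener, which floods
// notifications on every call, so call it on a separate goroutine: if
// listen() were blocked, the test fails on the ctx timeout below
// instead of hanging.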
go func() {
subCancel, err := uut.Subscribe("bagels", func(ctx context.Context, message []byte) {
t.Logf("got message: %s", string(message))
})
assert.NoError(t, err)
cancels <- subCancel
}()
subCancel := testutil.RequireRecvCtx(ctx, t, cancels)
cancelDone := make(chan struct{})
go func() {
defer close(cancelDone)
subCancel()
}()
testutil.RequireRecvCtx(ctx, t, cancelDone)
closeErrs := make(chan error)
go func() {
closeErrs <- uut.Close()
}()
err := testutil.RequireRecvCtx(ctx, t, closeErrs)
require.NoError(t, err)
}
const (
numNotifications = 5
testMessage = "birds of a feather"
)
// fakePqListener is a fake version of pq.Listener. This test code tests for regressions of
// https://github.com/coder/coder/issues/11950 where pq.Listener deadlocked because we blocked the
// PGPubsub.listen() goroutine while calling other pq.Listener functions. So, all function calls
// into the fakePqListener will send 5 notifications before returning to ensure the listen()
// goroutine is unblocked.
type fakePqListener struct {
mu sync.Mutex
channels map[string]struct{}
notify chan *pq.Notification
}
func (f *fakePqListener) Close() error {
f.mu.Lock()
defer f.mu.Unlock()
ch := f.getTestChanLocked()
for i := 0; i < numNotifications; i++ {
f.notify <- &pq.Notification{Channel: ch, Extra: testMessage}
}
// note that the real pq.Listener must only be closed once, so go ahead and
// close the notify unprotected here. If it panics, we have a bug.
close(f.notify)
return nil
}
func (f *fakePqListener) Listen(s string) error {
f.mu.Lock()
defer f.mu.Unlock()
ch := f.getTestChanLocked()
for i := 0; i < numNotifications; i++ {
f.notify <- &pq.Notification{Channel: ch, Extra: testMessage}
}
if _, ok := f.channels[s]; ok {
return pq.ErrChannelAlreadyOpen
}
f.channels[s] = struct{}{}
return nil
}
func (f *fakePqListener) Unlisten(s string) error {
f.mu.Lock()
defer f.mu.Unlock()
ch := f.getTestChanLocked()
for i := 0; i < numNotifications; i++ {
f.notify <- &pq.Notification{Channel: ch, Extra: testMessage}
}
if _, ok := f.channels[s]; ok {
delete(f.channels, s)
return nil
}
return pq.ErrChannelNotOpen
}
func (f *fakePqListener) NotifyChan() <-chan *pq.Notification {
return f.notify
}
// getTestChanLocked returns the name of a channel we are currently listening for, if there is one.
// Otherwise, it just returns "test". We prefer to send test notifications for channels that appear
// in the tests, but if there are none, just return anything.
func (f *fakePqListener) getTestChanLocked() string {
for c := range f.channels {
return c
}
return "test"
}
func newFakePqListener() *fakePqListener {
return &fakePqListener{
channels: make(map[string]struct{}),
notify: make(chan *pq.Notification),
}
}
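
A note on the fake's construction: `notify` is unbuffered, so every notification the fake sends blocks until `PGPubsub.listen()` receives it. If any code path still called into the listener while holding `qMu`, the flood should deadlock it immediately, and the test's `testutil.RequireRecvCtx` calls would fail on timeout rather than pass by luck.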

@@ -109,60 +109,6 @@ func TestPubsub(t *testing.T) {
message := <-messageChannel
assert.Equal(t, string(message), data)
})
t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
connectionURL, closePg, err := postgres.Open()
require.NoError(t, err)
defer closePg()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := pubsub.New(ctx, logger, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
event := "test"
done := make(chan struct{})
called := make(chan struct{})
unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) {
defer close(done)
select {
case <-subCtx.Done():
assert.Fail(t, "context should not be canceled")
default:
}
close(called)
select {
case <-subCtx.Done():
case <-ctx.Done():
assert.Fail(t, "timeout waiting for sub context to be canceled")
}
})
require.NoError(t, err)
defer unsub()
go func() {
err := pubsub.Publish(event, nil)
assert.NoError(t, err)
}()
select {
case <-called:
case <-ctx.Done():
require.Fail(t, "timeout waiting for handler to be called")
}
err = pubsub.Close()
require.NoError(t, err)
select {
case <-done:
case <-ctx.Done():
require.Fail(t, "timeout waiting for handler to finish")
}
})
}
func TestPubsub_ordering(t *testing.T) {