coder/coderd/database/pubsub.go

156 lines
3.5 KiB
Go

package database
import (
"context"
"database/sql"
"errors"
"sync"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"golang.org/x/xerrors"
)
type (
	// Listener is a callback that receives the payload of a pubsub message.
	// The ctx argument is supplied by the Pubsub implementation that
	// dispatches the message.
	Listener func(ctx context.Context, message []byte)

	// Pubsub broadcasts messages to, and receives messages from, named events.
	// Implementations should assume the backing service is highly available.
	Pubsub interface {
		// Subscribe registers listener for event; calling cancel removes it.
		Subscribe(event string, listener Listener) (cancel func(), err error)
		// Publish sends message to all subscribers of event.
		Publish(event string, message []byte) error
		// Close releases the resources held by the implementation.
		Close() error
	}

	// pgPubsub implements Pubsub on top of PostgreSQL LISTEN/NOTIFY.
	pgPubsub struct {
		// pgListener receives notifications for the events that have been LISTENed.
		pgListener *pq.Listener
		// db is used to issue pg_notify when publishing.
		db *sql.DB
		// mut guards listeners.
		mut sync.Mutex
		// listeners maps event name -> subscription ID -> callback.
		listeners map[string]map[string]Listener
	}
)
// Subscribe registers listener to be called whenever a notification for
// event is received. It issues a LISTEN for event on the pq listener; an
// already-open channel is not treated as an error.
//
// The returned cancel func removes the listener. When the last listener for
// an event is removed, the event's entry is dropped from p.listeners and the
// channel is UNLISTENed. Calling cancel more than once is a no-op.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
	p.mut.Lock()
	defer p.mut.Unlock()

	// NOTE(review): Listen round-trips to the database while p.mut is held,
	// and listenReceive must acquire p.mut to dispatch notifications. Confirm
	// this cannot stall when pq's Notify channel fills up.
	err = p.pgListener.Listen(event)
	if errors.Is(err, pq.ErrChannelAlreadyOpen) {
		// It's ok if it's already open!
		err = nil
	}
	if err != nil {
		return nil, xerrors.Errorf("listen: %w", err)
	}

	listeners, ok := p.listeners[event]
	if !ok {
		listeners = map[string]Listener{}
		p.listeners[event] = listeners
	}
	// Pick an ID not already in use for this event. A UUID collision is
	// practically impossible, but the loop guarantees uniqueness regardless.
	var id string
	for {
		id = uuid.New().String()
		if _, exists := listeners[id]; !exists {
			break
		}
	}
	listeners[id] = listener

	return func() {
		p.mut.Lock()
		defer p.mut.Unlock()
		// Re-read the map: the entry may have been removed and recreated
		// since Subscribe returned.
		listeners := p.listeners[event]
		if _, exists := listeners[id]; !exists {
			// Already cancelled; avoid a redundant UNLISTEN round-trip.
			return
		}
		delete(listeners, id)
		if len(listeners) == 0 {
			// Drop the empty entry so p.listeners doesn't accumulate
			// event names nobody is subscribed to anymore.
			delete(p.listeners, event)
			// Best-effort: if UNLISTEN fails, any notifications that still
			// arrive are ignored by listenReceive because the entry is gone.
			_ = p.pgListener.Unlisten(event)
		}
	}, nil
}
// Publish delivers message to every subscriber of event by calling
// pg_notify. The channel name is embedded as a quoted SQL literal (via
// pq.QuoteLiteral), and the payload is passed as a bind parameter.
func (p *pgPubsub) Publish(event string, message []byte) error {
	query := `select pg_notify(` + pq.QuoteLiteral(event) + `, $1)`
	if _, err := p.db.ExecContext(context.Background(), query, message); err != nil {
		return xerrors.Errorf("exec: %w", err)
	}
	return nil
}
// Close shuts down the underlying pq listener and returns its error, if any.
// NOTE(review): the listen goroutine also closes the pq listener when its
// context ends. Confirm callers tolerate an error from a second close.
func (p *pgPubsub) Close() error {
	listener := p.pgListener
	return listener.Close()
}
// listen forwards notifications from the pq listener to subscribers until
// ctx is cancelled or the listener's Notify channel is closed. The pq
// listener is closed when the loop exits.
func (p *pgPubsub) listen(ctx context.Context) {
	defer p.pgListener.Close()
	for {
		select {
		case <-ctx.Done():
			return
		case notif, open := <-p.pgListener.Notify:
			if !open {
				return
			}
			// A nil notification can be dispatched on reconnect; there is
			// nothing to deliver for it.
			if notif == nil {
				continue
			}
			p.listenReceive(ctx, notif)
		}
	}
}
// listenReceive sends a notification to every listener registered for its
// channel. Each listener is started in its own goroutine, so listeners run
// concurrently and a slow one does not delay the others.
// NOTE(review): every goroutine receives the same payload slice; a listener
// that mutates it would race with the others.
func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
	p.mut.Lock()
	defer p.mut.Unlock()

	subscribers := p.listeners[notif.Channel]
	if len(subscribers) == 0 {
		return
	}
	payload := []byte(notif.Extra)
	for _, fn := range subscribers {
		go fn(ctx, payload)
	}
}
// NewPubsub creates a Pubsub backed by PostgreSQL. database is used to
// publish via pg_notify; connectURL is used to open a separate pq listener
// for receiving notifications.
//
// NewPubsub waits for the listener's first connection event. It returns an
// error if that event reports a failure, or ctx.Err() if ctx ends first. In
// both failure cases the pq listener is closed before returning.
//
// ctx also bounds the background goroutine that receives notifications, and
// it is the context handed to every Listener.
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
	// errCh has room for exactly one value and is written at most once, so
	// the event callback can never block, even if NewPubsub has already
	// returned because ctx was cancelled.
	errCh := make(chan error, 1)
	var reportOnce sync.Once
	listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(_ pq.ListenerEventType, err error) {
		// This callback fires on every connection-state change. Only the
		// first event matters for construction; later (re)connect events
		// are ignored.
		reportOnce.Do(func() {
			errCh <- err
		})
	})
	select {
	case err := <-errCh:
		if err != nil {
			// Stop the listener so it doesn't keep retrying in the background.
			_ = listener.Close()
			return nil, xerrors.Errorf("create pq listener: %w", err)
		}
	case <-ctx.Done():
		_ = listener.Close()
		return nil, ctx.Err()
	}
	pgPubsub := &pgPubsub{
		db:         database,
		pgListener: listener,
		listeners:  make(map[string]map[string]Listener),
	}
	go pgPubsub.listen(ctx)
	return pgPubsub, nil
}