
284 lines
8.3 KiB
Raw Normal View History

package peerbroker
import (
protobuf ""
var (
// Each NegotiateConnection() function call spawns a new stream.
streamIDLength = len(uuid.NewString())
// We shouldn't PubSub anything larger than this!
maxPayloadSizeBytes = 8192
// ProxyOptions provides values to configure a proxy.
type ProxyOptions struct {
ChannelID string
Logger slog.Logger
Pubsub database.Pubsub
// ProxyDial writes client negotiation streams over PubSub.
//
// PubSub is used to geodistribute WebRTC handshakes. All negotiation
// messages are small in size (<=8KB), and we don't require delivery
// guarantees because connections can always be renegotiated.
//
//	                          ┌────────────────────┐   ┌─────────────────────────────┐
//	                          │       coderd       │   │            coderd           │
//	┌─────────────────────┐   │/<agent-id>/connect │   │     /<agent-id>/listen      │
//	│        client       │   │                    │   │                             │   ┌─────┐
//	│                     ├──►│Creates a stream ID │◄─►│Subscribe() to the <agent-id>│◄──┤agent│
//	│NegotiateConnection()│   │and Publish() to the│   │channel. Parse the stream ID │   └─────┘
//	└─────────────────────┘   │<agent-id> channel: │   │from payloads to create new  │
//	                          │                    │   │NegotiateConnection() streams│
//	                          │<stream-id><payload>│   │or write to existing ones.   │
//	                          └────────────────────┘   └─────────────────────────────┘
//
// ProxyDial subscribes to the "<ChannelID>-out" channel. For each stream ID
// found in a payload, it opens a NegotiateConnection() stream on client and
// publishes that stream's responses on "<ChannelID>-in". The returned
// io.Closer cancels the subscription. If subscribing fails, the returned
// io.Closer is nil.
func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) {
	dialer := &proxyDial{
		channelID:  options.ChannelID,
		logger:     options.Logger,
		pubsub:     options.Pubsub,
		connection: client,
		streams:    make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient),
	}
	err := dialer.listen()
	if err != nil {
		// Don't hand back a closer whose subscription was never established.
		return nil, xerrors.Errorf("listen: %w", err)
	}
	return dialer, nil
}
// ProxyListen serves the DRPC PeerBroker service on connListener. Every
// NegotiateConnection() stream it accepts is bridged onto pubsub (see
// proxyListen.NegotiateConnection). It returns when server.Serve returns.
// A yamux.ErrSessionShutdown from Serve is treated as a clean shutdown and
// yields nil.
func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error {
	handler := &proxyListen{
		channelID: options.ChannelID,
		pubsub:    options.Pubsub,
		logger:    options.Logger,
	}
	mux := drpcmux.New()
	if err := proto.DRPCRegisterPeerBroker(mux, handler); err != nil {
		return xerrors.Errorf("register peer broker: %w", err)
	}

	err := drpcserver.New(mux).Serve(ctx, connListener)
	switch {
	case err == nil, errors.Is(err, yamux.ErrSessionShutdown):
		return nil
	default:
		return xerrors.Errorf("serve: %w", err)
	}
}
type proxyListen struct {
channelID string
pubsub database.Pubsub
logger slog.Logger
func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
streamID := uuid.NewString()
var err error
closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) {
err := p.onServerToClientMessage(streamID, stream, message)
if err != nil {
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
if err != nil {
return xerrors.Errorf("subscribe: %w", err)
defer closeSubscribe()
for {
clientToServerMessage, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return xerrors.Errorf("recv: %w", err)
data, err := protobuf.Marshal(clientToServerMessage)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
if len(data) > maxPayloadSizeBytes {
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
data = append([]byte(streamID), data...)
err = p.pubsub.Publish(proxyOutID(p.channelID), marshal(data))
if err != nil {
return xerrors.Errorf("publish: %w", err)
return nil
func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error {
var err error
message, err = unmarshal(message)
if err != nil {
return xerrors.Errorf("decode: %w", err)
if len(message) < streamIDLength {
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
serverStreamID := string(message[0:streamIDLength])
if serverStreamID != streamID {
// It's not trying to communicate with this stream!
return nil
var msg proto.Exchange
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
if err != nil {
return xerrors.Errorf("unmarshal message: %w", err)
err = stream.Send(&msg)
if err != nil {
return xerrors.Errorf("send message: %w", err)
return nil
type proxyDial struct {
channelID string
pubsub database.Pubsub
logger slog.Logger
connection proto.DRPCPeerBrokerClient
closeSubscribe func()
streamMutex sync.Mutex
streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient
// listen subscribes to proxyOutID(p.channelID) and routes each payload
// through onClientToServerMessage, logging failures at debug level. The
// subscription's cancel function is stored in p.closeSubscribe.
func (p *proxyDial) listen() error {
	handle := func(ctx context.Context, message []byte) {
		if err := p.onClientToServerMessage(ctx, message); err != nil {
			p.logger.Debug(ctx, "failed to accept client message", slog.Error(err))
		}
	}

	var err error
	p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), handle)
	return err
}
// onClientToServerMessage handles one payload published by ProxyListen: base64
// text wrapping "<stream-id><protobuf proto.Exchange>".
//
// The first payload seen for a stream ID opens a NegotiateConnection() stream
// on p.connection and starts a goroutine that relays that stream's replies
// back over pubsub. Later payloads with the same ID are sent on the existing
// stream.
func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error {
	payload, err := unmarshal(message)
	if err != nil {
		return xerrors.Errorf("decode: %w", err)
	}
	if len(payload) < streamIDLength {
		return xerrors.Errorf("got message length %d < %d", len(payload), streamIDLength)
	}
	streamID := string(payload[:streamIDLength])

	// Decode before looking up the stream so a malformed payload can never
	// open a NegotiateConnection() stream.
	var msg proto.Exchange
	err = protobuf.Unmarshal(payload[streamIDLength:], &msg)
	if err != nil {
		return xerrors.Errorf("unmarshal message: %w", err)
	}

	// Relay goroutines remove entries from p.streams while pubsub callbacks
	// add to and read from it, so the map (and Send) is only used while
	// holding streamMutex.
	p.streamMutex.Lock()
	defer p.streamMutex.Unlock()

	stream, ok := p.streams[streamID]
	if !ok {
		stream, err = p.connection.NegotiateConnection(ctx)
		if err != nil {
			return xerrors.Errorf("negotiate connection: %w", err)
		}
		p.streams[streamID] = stream
		go p.relayServerMessages(ctx, streamID, stream)
	}

	err = stream.Send(&msg)
	if err != nil {
		return xerrors.Errorf("write message: %w", err)
	}
	return nil
}

// relayServerMessages forwards stream's replies to pubsub until the stream
// ends. It then closes the stream and removes it from p.streams, so a later
// payload with the same ID opens a fresh stream.
func (p *proxyDial) relayServerMessages(ctx context.Context, streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) {
	err := p.onServerToClientMessage(streamID, stream)
	if err != nil {
		p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
	}

	// Best effort: the stream has already finished or failed. Close it before
	// taking the lock so that a holder of streamMutex is never waiting on us.
	_ = stream.Close()

	p.streamMutex.Lock()
	delete(p.streams, streamID)
	p.streamMutex.Unlock()
}
// onServerToClientMessage relays every message received on stream to pubsub
// on proxyInID(p.channelID), prefixed with streamID so that only the matching
// ProxyListen stream forwards it. It returns nil when Recv reports io.EOF or
// context.Canceled. Any other error stops the relay and is returned.
func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error {
	inChannel := proxyInID(p.channelID)
	for {
		response, err := stream.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
				return nil
			}
			return xerrors.Errorf("recv: %w", err)
		}

		encoded, err := protobuf.Marshal(response)
		if err != nil {
			return xerrors.Errorf("marshal: %w", err)
		}
		if len(encoded) > maxPayloadSizeBytes {
			return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
		}
		payload := append([]byte(streamID), encoded...)
		if err := p.pubsub.Publish(inChannel, marshal(payload)); err != nil {
			return xerrors.Errorf("publish: %w", err)
		}
	}
}
func (p *proxyDial) Close() error {
defer p.streamMutex.Unlock()
return nil
// marshal base64-encodes data (standard alphabet, padded) so that the pubsub
// payload is plain ASCII text. PostgreSQL pubsub cannot carry non-UTF-8
// payloads, and the protobuf bytes published here are arbitrary binary.
func marshal(data []byte) []byte {
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
	base64.StdEncoding.Encode(encoded, data)
	return encoded
}
// unmarshal reverses marshal, decoding standard padded base64. On malformed
// input it returns the bytes decoded before the error along with the error,
// exactly as base64.Encoding.DecodeString does.
func unmarshal(data []byte) ([]byte, error) {
	decoded := make([]byte, base64.StdEncoding.DecodedLen(len(data)))
	n, err := base64.StdEncoding.Decode(decoded, data)
	return decoded[:n], err
}
// proxyOutID names the channel that carries client-to-server messages
// (published by ProxyListen, consumed by ProxyDial).
func proxyOutID(channelID string) string {
	return channelID + "-out"
}
// proxyInID names the channel that carries server-to-client messages
// (published by ProxyDial, consumed by ProxyListen).
func proxyInID(channelID string) string {
	return channelID + "-in"
}