coder/codersdk/drpc/transport.go

130 lines
2.8 KiB
Go

package drpc
import (
"context"
"net"
"sync"
"github.com/hashicorp/yamux"
"github.com/valyala/fasthttp/fasthttputil"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"github.com/coder/coder/v2/coderd/tracing"
)
const (
// MaxMessageSize is the maximum payload size that can be
// transported without error.
MaxMessageSize = 4 << 20
)
// MultiplexedConn returns a multiplexed dRPC connection from a yamux Session.
func MultiplexedConn(session *yamux.Session) drpc.Conn {
return &multiplexedDRPC{session}
}
// Allows concurrent requests on a single dRPC connection.
// Required for calling functions concurrently.
type multiplexedDRPC struct {
session *yamux.Session
}
func (m *multiplexedDRPC) Close() error {
return m.session.Close()
}
func (m *multiplexedDRPC) Closed() <-chan struct{} {
return m.session.CloseChan()
}
func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inMessage, outMessage drpc.Message) error {
conn, err := m.session.Open()
if err != nil {
return err
}
dConn := drpcconn.New(conn)
defer func() {
_ = dConn.Close()
}()
return dConn.Invoke(ctx, rpc, enc, inMessage, outMessage)
}
func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
conn, err := m.session.Open()
if err != nil {
return nil, err
}
dConn := drpcconn.New(conn)
stream, err := dConn.NewStream(ctx, rpc, enc)
if err == nil {
go func() {
<-stream.Context().Done()
_ = dConn.Close()
}()
}
return stream, err
}
func MemTransportPipe() (drpc.Conn, net.Listener) {
m := &memDRPC{
closed: make(chan struct{}),
l: fasthttputil.NewInmemoryListener(),
}
return m, m.l
}
type memDRPC struct {
closeOnce sync.Once
closed chan struct{}
l *fasthttputil.InmemoryListener
}
func (m *memDRPC) Close() error {
err := m.l.Close()
m.closeOnce.Do(func() { close(m.closed) })
return err
}
func (m *memDRPC) Closed() <-chan struct{} {
return m.closed
}
func (m *memDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inMessage, outMessage drpc.Message) error {
conn, err := m.l.Dial()
if err != nil {
return err
}
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
defer func() {
_ = dConn.Close()
_ = conn.Close()
}()
return dConn.Invoke(ctx, rpc, enc, inMessage, outMessage)
}
func (m *memDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
conn, err := m.l.Dial()
if err != nil {
return nil, err
}
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
stream, err := dConn.NewStream(ctx, rpc, enc)
if err != nil {
_ = dConn.Close()
_ = conn.Close()
return nil, err
}
go func() {
select {
case <-stream.Context().Done():
case <-m.closed:
}
_ = dConn.Close()
_ = conn.Close()
}()
return stream, nil
}