mirror of https://github.com/coder/coder.git
130 lines
2.8 KiB
Go
130 lines
2.8 KiB
Go
package provisionersdk
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"sync"
|
|
|
|
"github.com/hashicorp/yamux"
|
|
"github.com/valyala/fasthttp/fasthttputil"
|
|
"storj.io/drpc"
|
|
"storj.io/drpc/drpcconn"
|
|
|
|
"github.com/coder/coder/v2/coderd/tracing"
|
|
)
|
|
|
|
const (
	// MaxMessageSize is the largest payload the transport is expected to
	// carry without error: 4 MiB (4 << 20 bytes).
	MaxMessageSize = 4 * 1024 * 1024
)
|
|
|
|
// MultiplexedConn returns a multiplexed dRPC connection from a yamux Session.
|
|
func MultiplexedConn(session *yamux.Session) drpc.Conn {
|
|
return &multiplexedDRPC{session}
|
|
}
|
|
|
|
// Allows concurrent requests on a single dRPC connection.
|
|
// Required for calling functions concurrently.
|
|
type multiplexedDRPC struct {
|
|
session *yamux.Session
|
|
}
|
|
|
|
func (m *multiplexedDRPC) Close() error {
|
|
return m.session.Close()
|
|
}
|
|
|
|
func (m *multiplexedDRPC) Closed() <-chan struct{} {
|
|
return m.session.CloseChan()
|
|
}
|
|
|
|
func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inMessage, outMessage drpc.Message) error {
|
|
conn, err := m.session.Open()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dConn := drpcconn.New(conn)
|
|
defer func() {
|
|
_ = dConn.Close()
|
|
}()
|
|
return dConn.Invoke(ctx, rpc, enc, inMessage, outMessage)
|
|
}
|
|
|
|
func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
|
|
conn, err := m.session.Open()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dConn := drpcconn.New(conn)
|
|
stream, err := dConn.NewStream(ctx, rpc, enc)
|
|
if err == nil {
|
|
go func() {
|
|
<-stream.Context().Done()
|
|
_ = dConn.Close()
|
|
}()
|
|
}
|
|
return stream, err
|
|
}
|
|
|
|
func MemTransportPipe() (drpc.Conn, net.Listener) {
|
|
m := &memDRPC{
|
|
closed: make(chan struct{}),
|
|
l: fasthttputil.NewInmemoryListener(),
|
|
}
|
|
|
|
return m, m.l
|
|
}
|
|
|
|
type memDRPC struct {
|
|
closeOnce sync.Once
|
|
closed chan struct{}
|
|
l *fasthttputil.InmemoryListener
|
|
}
|
|
|
|
func (m *memDRPC) Close() error {
|
|
err := m.l.Close()
|
|
m.closeOnce.Do(func() { close(m.closed) })
|
|
return err
|
|
}
|
|
|
|
func (m *memDRPC) Closed() <-chan struct{} {
|
|
return m.closed
|
|
}
|
|
|
|
func (m *memDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inMessage, outMessage drpc.Message) error {
|
|
conn, err := m.l.Dial()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
|
|
defer func() {
|
|
_ = dConn.Close()
|
|
_ = conn.Close()
|
|
}()
|
|
return dConn.Invoke(ctx, rpc, enc, inMessage, outMessage)
|
|
}
|
|
|
|
func (m *memDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
|
|
conn, err := m.l.Dial()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
|
|
stream, err := dConn.NewStream(ctx, rpc, enc)
|
|
if err != nil {
|
|
_ = dConn.Close()
|
|
_ = conn.Close()
|
|
return nil, err
|
|
}
|
|
go func() {
|
|
select {
|
|
case <-stream.Context().Done():
|
|
case <-m.closed:
|
|
}
|
|
_ = dConn.Close()
|
|
_ = conn.Close()
|
|
}()
|
|
return stream, nil
|
|
}
|