fix: don't use yamux for in-memory provisioner{,d} streams (#5136)

This commit is contained in:
Colin Adler 2022-11-22 12:19:32 -06:00 committed by GitHub
parent 2b6c229e4e
commit 1f20cab110
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 105 additions and 57 deletions

View File

@ -890,7 +890,7 @@ func newProvisionerDaemon(
return nil, xerrors.Errorf("mkdir %q: %w", cfg.CacheDirectory.Value, err)
}
terraformClient, terraformServer := provisionersdk.TransportPipe()
terraformClient, terraformServer := provisionersdk.MemTransportPipe()
go func() {
<-ctx.Done()
_ = terraformClient.Close()
@ -920,11 +920,11 @@ func newProvisionerDaemon(
}
provisioners := provisionerd.Provisioners{
string(database.ProvisionerTypeTerraform): sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(terraformClient)),
string(database.ProvisionerTypeTerraform): sdkproto.NewDRPCProvisionerClient(terraformClient),
}
// include echo provisioner when in dev mode
if dev {
echoClient, echoServer := provisionersdk.TransportPipe()
echoClient, echoServer := provisionersdk.MemTransportPipe()
go func() {
<-ctx.Done()
_ = echoClient.Close()
@ -941,7 +941,7 @@ func newProvisionerDaemon(
}
}
}()
provisioners[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient))
provisioners[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient)
}
return provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
// This debounces calls to listen every second. Read the comment

View File

@ -644,7 +644,7 @@ func compressHandler(h http.Handler) http.Handler {
// CreateInMemoryProvisionerDaemon is an in-memory connection to a provisionerd. Useful when starting coderd and provisionerd
// in the same process.
func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) {
clientSession, serverSession := provisionersdk.TransportPipe()
clientSession, serverSession := provisionersdk.MemTransportPipe()
defer func() {
if err != nil {
_ = clientSession.Close()
@ -705,5 +705,5 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
_ = serverSession.Close()
}()
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil
return proto.NewDRPCProvisionerDaemonClient(clientSession), nil
}

View File

@ -315,7 +315,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
// well with coderd testing. It registers the "echo" provisioner for
// quick testing.
func NewProvisionerDaemon(t *testing.T, coderAPI *coderd.API) io.Closer {
echoClient, echoServer := provisionersdk.TransportPipe()
echoClient, echoServer := provisionersdk.MemTransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(func() {
_ = echoClient.Close()
@ -339,7 +339,7 @@ func NewProvisionerDaemon(t *testing.T, coderAPI *coderd.API) io.Closer {
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
Provisioners: provisionerd.Provisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient)),
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
},
WorkDirectory: t.TempDir(),
})
@ -350,7 +350,7 @@ func NewProvisionerDaemon(t *testing.T, coderAPI *coderd.API) io.Closer {
}
func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uuid.UUID, tags map[string]string) io.Closer {
echoClient, echoServer := provisionersdk.TransportPipe()
echoClient, echoServer := provisionersdk.MemTransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(func() {
_ = echoClient.Close()
@ -374,7 +374,7 @@ func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uui
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
Provisioners: provisionerd.Provisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient)),
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
},
WorkDirectory: t.TempDir(),
})

View File

@ -212,5 +212,5 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U
if err != nil {
return nil, xerrors.Errorf("multiplex client: %w", err)
}
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(session)), nil
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil
}

View File

@ -69,7 +69,7 @@ func provisionerDaemonStart() *cobra.Command {
return xerrors.Errorf("mkdir %q: %w", cacheDir, err)
}
terraformClient, terraformServer := provisionersdk.TransportPipe()
terraformClient, terraformServer := provisionersdk.MemTransportPipe()
go func() {
<-ctx.Done()
_ = terraformClient.Close()
@ -104,7 +104,7 @@ func provisionerDaemonStart() *cobra.Command {
logger.Info(ctx, "starting provisioner daemon", slog.F("tags", tags))
provisioners := provisionerd.Provisioners{
string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(provisionersdk.Conn(terraformClient)),
string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(terraformClient),
}
srv := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return client.ServeProvisionerDaemon(ctx, org.ID, []codersdk.ProvisionerType{

3
go.mod
View File

@ -53,7 +53,6 @@ replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20220811105153-
require (
cdr.dev/slog v1.4.2-0.20220525200111-18dce5c2cd5f
cloud.google.com/go/compute v1.12.1 // indirect
cloud.google.com/go/compute/metadata v0.2.1
github.com/AlecAivazis/survey/v2 v2.3.5
github.com/adrg/xdg v0.4.0
@ -129,6 +128,7 @@ require (
github.com/tabbed/pqtype v0.1.1
github.com/u-root/u-root v0.10.0
github.com/unrolled/secure v1.13.0
github.com/valyala/fasthttp v1.41.0
go.mozilla.org/pkcs7 v0.0.0-20200128120323-432b2356ecb1
go.nhat.io/otelsql v0.7.0
go.opentelemetry.io/otel v1.11.1
@ -166,6 +166,7 @@ require (
)
require (
cloud.google.com/go/compute v1.12.1 // indirect
filippo.io/edwards25519 v1.0.0-rc.1 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.5.2 // indirect

5
go.sum
View File

@ -1776,8 +1776,11 @@ github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijb
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/uudashr/gocognit v1.0.5/go.mod h1:wgYz0mitoKOTysqxTDMOUXg+Jb5SvtihkfmugIZYpEA=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.30.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus=
github.com/valyala/fasthttp v1.41.0 h1:zeR0Z1my1wDHTRiamBCXVglQdbUwgb9uWG3k1HQz6jY=
github.com/valyala/fasthttp v1.41.0/go.mod h1:f6VbjjoI3z1NDOZOv17o6RvtRSWxC77seBFc2uWtgiY=
github.com/valyala/quicktemplate v1.7.0/go.mod h1:sqKJnoaOF88V07vkO+9FL8fb9uZg/VPSJnLYn+LmLk8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE=
@ -1975,6 +1978,7 @@ golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -2117,6 +2121,7 @@ golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/oauth2 v0.0.0-20180227000427-d7d64896b5ff/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=

View File

@ -23,7 +23,7 @@ func TestEcho(t *testing.T) {
fs := afero.NewMemMapFs()
// Create an in-memory provisioner to communicate with.
client, server := provisionersdk.TransportPipe()
client, server := provisionersdk.MemTransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(func() {
_ = client.Close()
@ -36,7 +36,7 @@ func TestEcho(t *testing.T) {
})
assert.NoError(t, err)
}()
api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client))
api := proto.NewDRPCProvisionerClient(client)
t.Run("Parse", func(t *testing.T) {
t.Parallel()

View File

@ -36,7 +36,7 @@ func setupProvisioner(t *testing.T, opts *provisionerServeOptions) (context.Cont
opts = &provisionerServeOptions{}
}
cachePath := t.TempDir()
client, server := provisionersdk.TransportPipe()
client, server := provisionersdk.MemTransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
serverErr := make(chan error, 1)
t.Cleanup(func() {
@ -59,7 +59,7 @@ func setupProvisioner(t *testing.T, opts *provisionerServeOptions) (context.Cont
ExitTimeout: opts.exitTimeout,
})
}()
api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client))
api := proto.NewDRPCProvisionerClient(client)
return ctx, api
}

View File

@ -14,6 +14,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/spf13/afero"
"github.com/valyala/fasthttp/fasthttputil"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
"go.opentelemetry.io/otel/trace"
@ -344,7 +345,7 @@ func (p *Server) acquireJob(ctx context.Context) {
}
func retryable(err error) bool {
return xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) ||
return xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) || xerrors.Is(err, fasthttputil.ErrInmemoryListenerClosed) ||
// annoyingly, dRPC sometimes returns context.Canceled if the transport was closed, even if the context for
// the RPC *is not canceled*. Retrying is fine if the RPC context is not canceled.
xerrors.Is(err, context.Canceled)

View File

@ -843,6 +843,7 @@ func TestProvisionerd(t *testing.T) {
<-failChan
_ = client.DRPCConn().Close()
second.Store(true)
time.Sleep(50 * time.Millisecond)
failedOnce.Do(func() { close(failedChan) })
}()
}
@ -1075,7 +1076,7 @@ func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestSer
return &proto.Empty{}, nil
}
}
clientPipe, serverPipe := provisionersdk.TransportPipe()
clientPipe, serverPipe := provisionersdk.MemTransportPipe()
t.Cleanup(func() {
_ = clientPipe.Close()
_ = serverPipe.Close()
@ -1089,14 +1090,14 @@ func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestSer
go func() {
_ = srv.Serve(ctx, serverPipe)
}()
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientPipe))
return proto.NewDRPCProvisionerDaemonClient(clientPipe)
}
// Creates a provisioner protobuf client that's connected
// to the server implementation provided.
func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
t.Helper()
clientPipe, serverPipe := provisionersdk.TransportPipe()
clientPipe, serverPipe := provisionersdk.MemTransportPipe()
t.Cleanup(func() {
_ = clientPipe.Close()
_ = serverPipe.Close()
@ -1110,7 +1111,7 @@ func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkprot
go func() {
_ = srv.Serve(ctx, serverPipe)
}()
return sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(clientPipe))
return sdkproto.NewDRPCProvisionerClient(clientPipe)
}
type provisionerTestServer struct {

View File

@ -7,12 +7,12 @@ import (
"net"
"os"
"github.com/hashicorp/yamux"
"github.com/valyala/fasthttp/fasthttputil"
"golang.org/x/xerrors"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"github.com/hashicorp/yamux"
"github.com/coder/coder/provisionersdk/proto"
)
@ -58,18 +58,14 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser
// short-lived processes that can be executed concurrently.
err = srv.Serve(ctx, options.Listener)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
if errors.Is(err, context.Canceled) {
return nil
}
if errors.Is(err, io.ErrClosedPipe) {
return nil
}
if errors.Is(err, yamux.ErrSessionShutdown) {
if errors.Is(err, io.EOF) ||
errors.Is(err, context.Canceled) ||
errors.Is(err, io.ErrClosedPipe) ||
errors.Is(err, yamux.ErrSessionShutdown) ||
errors.Is(err, fasthttputil.ErrInmemoryListenerClosed) {
return nil
}
return xerrors.Errorf("serve transport: %w", err)
}
return nil

View File

@ -21,7 +21,7 @@ func TestProvisionerSDK(t *testing.T) {
t.Parallel()
t.Run("Serve", func(t *testing.T) {
t.Parallel()
client, server := provisionersdk.TransportPipe()
client, server := provisionersdk.MemTransportPipe()
defer client.Close()
defer server.Close()
@ -34,7 +34,7 @@ func TestProvisionerSDK(t *testing.T) {
assert.NoError(t, err)
}()
api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client))
api := proto.NewDRPCProvisionerClient(client)
stream, err := api.Parse(context.Background(), &proto.Parse_Request{})
require.NoError(t, err)
_, err = stream.Recv()
@ -43,7 +43,7 @@ func TestProvisionerSDK(t *testing.T) {
t.Run("ServeClosedPipe", func(t *testing.T) {
t.Parallel()
client, server := provisionersdk.TransportPipe()
client, server := provisionersdk.MemTransportPipe()
_ = client.Close()
_ = server.Close()

View File

@ -2,10 +2,11 @@ package provisionersdk
import (
"context"
"io"
"net"
"sync"
"github.com/hashicorp/yamux"
"github.com/valyala/fasthttp/fasthttputil"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
)
@ -16,24 +17,8 @@ const (
MaxMessageSize = 4 << 20
)
// TransportPipe creates an in-memory pipe for dRPC transport.
func TransportPipe() (*yamux.Session, *yamux.Session) {
c1, c2 := net.Pipe()
yamuxConfig := yamux.DefaultConfig()
yamuxConfig.LogOutput = io.Discard
client, err := yamux.Client(c1, yamuxConfig)
if err != nil {
panic(err)
}
server, err := yamux.Server(c2, yamuxConfig)
if err != nil {
panic(err)
}
return client, server
}
// Conn returns a multiplexed dRPC connection from a yamux session.
func Conn(session *yamux.Session) drpc.Conn {
// MultiplexedConn returns a multiplexed dRPC connection from a yamux session.
func MultiplexedConn(session *yamux.Session) drpc.Conn {
return &multiplexedDRPC{session}
}
@ -78,3 +63,62 @@ func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.En
}
return stream, err
}
func MemTransportPipe() (drpc.Conn, net.Listener) {
m := &memDRPC{
closed: make(chan struct{}),
l: fasthttputil.NewInmemoryListener(),
}
return m, m.l
}
type memDRPC struct {
closeOnce sync.Once
closed chan struct{}
l *fasthttputil.InmemoryListener
}
func (m *memDRPC) Close() error {
err := m.l.Close()
m.closeOnce.Do(func() { close(m.closed) })
return err
}
func (m *memDRPC) Closed() <-chan struct{} {
return m.closed
}
func (m *memDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inMessage, outMessage drpc.Message) error {
conn, err := m.l.Dial()
if err != nil {
return err
}
dConn := drpcconn.New(conn)
defer func() {
_ = dConn.Close()
_ = conn.Close()
}()
return dConn.Invoke(ctx, rpc, enc, inMessage, outMessage)
}
func (m *memDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
conn, err := m.l.Dial()
if err != nil {
return nil, err
}
dConn := drpcconn.New(conn)
stream, err := dConn.NewStream(ctx, rpc, enc)
if err == nil {
go func() {
select {
case <-stream.Context().Done():
case <-m.closed:
}
_ = dConn.Close()
_ = conn.Close()
}()
}
return stream, err
}