feat: add support for networked provisioners (#9593)

* Refactor provisionerd to use an interface to connect to provisioners

Signed-off-by: Spike Curtis <spike@coder.com>

* feat: add support for networked provisioners

Signed-off-by: Spike Curtis <spike@coder.com>

* fix token length and linting

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
Spike Curtis 2023-09-08 13:53:48 +04:00 committed by GitHub
parent 8b51a2f3c5
commit 11b6068112
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 994 additions and 45 deletions

View File

@ -1309,7 +1309,7 @@ func newProvisionerDaemon(
return nil, xerrors.Errorf("mkdir work dir: %w", err)
}
provisioners := provisionerd.Provisioners{}
connector := provisionerd.LocalProvisioners{}
if cfg.Provisioner.DaemonsEcho {
echoClient, echoServer := provisionersdk.MemTransportPipe()
wg.Add(1)
@ -1336,7 +1336,7 @@ func newProvisionerDaemon(
}
}
}()
provisioners[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient)
connector[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient)
} else {
tfDir := filepath.Join(cacheDir, "tf")
err = os.MkdirAll(tfDir, 0o700)
@ -1375,7 +1375,7 @@ func newProvisionerDaemon(
}
}()
provisioners[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient)
connector[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient)
}
debounce := time.Second
@ -1390,7 +1390,7 @@ func newProvisionerDaemon(
JobPollDebounce: debounce,
UpdateInterval: time.Second,
ForceCancelInterval: cfg.Provisioner.ForceCancelInterval.Value(),
Provisioners: provisioners,
Connector: connector,
TracerProvider: coderAPI.TracerProvider,
Metrics: &metrics,
}), nil

View File

@ -484,7 +484,7 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer {
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
Provisioners: provisionerd.Provisioners{
Connector: provisionerd.LocalProvisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
},
})
@ -524,7 +524,7 @@ func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uui
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
Provisioners: provisionerd.Provisioners{
Connector: provisionerd.LocalProvisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
},
})

View File

@ -124,7 +124,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
logger.Info(ctx, "starting provisioner daemon", slog.F("tags", tags))
provisioners := provisionerd.Provisioners{
connector := provisionerd.LocalProvisioners{
string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(terraformClient),
}
srv := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
@ -140,7 +140,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
JobPollInterval: pollInterval,
JobPollJitter: pollJitter,
UpdateInterval: 500 * time.Millisecond,
Provisioners: provisioners,
Connector: connector,
})
var exitErr error

View File

@ -253,7 +253,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
errCh <- err
}()
provisioners := provisionerd.Provisioners{
connector := provisionerd.LocalProvisioners{
string(database.ProvisionerTypeEcho): proto.NewDRPCProvisionerClient(terraformClient),
}
another := codersdk.New(client.URL)
@ -269,8 +269,8 @@ func TestProvisionerDaemonServe(t *testing.T) {
PreSharedKey: "provisionersftw",
})
}, &provisionerd.Options{
Logger: logger.Named("provisionerd"),
Provisioners: provisioners,
Logger: logger.Named("provisionerd"),
Connector: connector,
})
defer pd.Close()

View File

@ -0,0 +1,466 @@
package provisionerd
import (
"bufio"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"golang.org/x/xerrors"
"storj.io/drpc/drpcconn"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/provisioner/echo"
agpl "github.com/coder/coder/v2/provisionerd"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)
// Executor is responsible for executing the remote provisioners.
//
// TODO: this interface is where we will run Kubernetes Jobs in a future
// version; right now, only the unit tests implement this interface.
type Executor interface {
// Execute a provisioner that connects back to the remoteConnector. errCh
// allows signalling of errors asynchronously and is closed on completion
// with no error.
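// The provisioner started by Execute is expected to dial back to
// daemonAddress over TLS using daemonCert, authenticate with jobID and
// token, and then serve dRPC on that connection, as EphemeralEcho below
// does via DialTLS and AuthenticateProvisioner.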
Execute(
ctx context.Context,
provisionerType database.ProvisionerType,
jobID, token, daemonCert, daemonAddress string) (errCh <-chan error)
}
type waiter struct {
ctx context.Context
job *proto.AcquiredJob
respCh chan<- agpl.ConnectResponse
token string
}
type remoteConnector struct {
ctx context.Context
executor Executor
cert string
addr string
listener net.Listener
logger slog.Logger
tlsCfg *tls.Config
mu sync.Mutex
waiters map[string]waiter
}
func NewRemoteConnector(ctx context.Context, logger slog.Logger, exec Executor) (agpl.Connector, error) {
// nolint: gosec
listener, err := net.Listen("tcp", ":0")
if err != nil {
return nil, xerrors.Errorf("failed to listen: %w", err)
}
go func() {
<-ctx.Done()
ce := listener.Close()
logger.Debug(ctx, "listener closed", slog.Error(ce))
}()
r := &remoteConnector{
ctx: ctx,
executor: exec,
listener: listener,
addr: listener.Addr().String(),
logger: logger,
waiters: make(map[string]waiter),
}
err = r.genCert()
if err != nil {
return nil, xerrors.Errorf("failed to generate certificate: %w", err)
}
go r.listenLoop()
return r, nil
}
func (r *remoteConnector) genCert() error {
privateKey, cert, err := GenCert()
if err != nil {
return err
}
r.cert = string(cert)
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return xerrors.Errorf("failed to marshal private key: %w", err)
}
pkPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes})
certKey, err := tls.X509KeyPair(cert, pkPEM)
if err != nil {
return xerrors.Errorf("failed to create TLS certificate: %w", err)
}
r.tlsCfg = &tls.Config{Certificates: []tls.Certificate{certKey}, MinVersion: tls.VersionTLS13}
return nil
}
// GenCert is a helper function that generates a private key and certificate. It
// is exported so that we can test a certificate generated in exactly the same
// way, but with a different private key.
func GenCert() (*ecdsa.PrivateKey, []byte, error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, xerrors.Errorf("generate private key: %w", err)
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "Coder Provisioner Daemon",
},
DNSNames: []string{serverName},
NotBefore: time.Now(),
// cert is valid for 5 years, which is much longer than we expect this
// process to stay up. The idea is that the certificate is self-signed
// and is valid for as long as the daemon is up and starting new remote
// provisioners
NotAfter: time.Now().Add(time.Hour * 24 * 365 * 5),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, nil, xerrors.Errorf("failed to create certificate: %w", err)
}
cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
return privateKey, cert, nil
}
func (r *remoteConnector) listenLoop() {
for {
conn, err := r.listener.Accept()
if err != nil {
r.logger.Info(r.ctx, "stopping listenLoop", slog.Error(err))
return
}
go r.handleConn(conn)
}
}
func (r *remoteConnector) handleConn(conn net.Conn) {
logger := r.logger.With(slog.F("remote_addr", conn.RemoteAddr()))
// If we hit an error while setting up, we want to close the connection.
// This construction makes closing the default until we explicitly set
// closeConn = false just before handing the connection over on the respCh.
closeConn := true
defer func() {
if closeConn {
ce := conn.Close()
logger.Debug(r.ctx, "closed connection", slog.Error(ce))
}
}()
tlsConn := tls.Server(conn, r.tlsCfg)
err := tlsConn.HandshakeContext(r.ctx)
if err != nil {
logger.Info(r.ctx, "failed TLS handshake", slog.Error(err))
return
}
w, err := r.authenticate(tlsConn)
if err != nil {
logger.Info(r.ctx, "failed provisioner authentication", slog.Error(err))
return
}
logger = logger.With(slog.F("job_id", w.job.JobId))
logger.Info(r.ctx, "provisioner connected")
closeConn = false // we're passing the conn over the channel
w.respCh <- agpl.ConnectResponse{
Job: w.job,
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.New(tlsConn)),
}
}
var (
errInvalidJobID = xerrors.New("invalid jobID")
errInvalidToken = xerrors.New("invalid token")
)
func (r *remoteConnector) pullWaiter(jobID, token string) (waiter, error) {
r.mu.Lock()
defer r.mu.Unlock()
// provisioners authenticate with a jobID and token. The jobID is required
// because we need to use public information for the lookup, to avoid timing
// attacks against the token.
w, ok := r.waiters[jobID]
if !ok {
return waiter{}, errInvalidJobID
}
if subtle.ConstantTimeCompare([]byte(token), []byte(w.token)) == 1 {
delete(r.waiters, jobID)
return w, nil
}
return waiter{}, errInvalidToken
}
func (r *remoteConnector) Connect(
ctx context.Context, job *proto.AcquiredJob, respCh chan<- agpl.ConnectResponse,
) {
pt := database.ProvisionerType(job.Provisioner)
if !pt.Valid() {
go errResponse(job, respCh, xerrors.Errorf("invalid provisioner type: %s", job.Provisioner))
return
}
tb := make([]byte, 16) // 128-bit token
n, err := rand.Read(tb)
if err != nil {
go errResponse(job, respCh, err)
return
}
if n != 16 {
go errResponse(job, respCh, xerrors.New("short read generating token"))
return
}
token := base64.StdEncoding.EncodeToString(tb)
r.mu.Lock()
defer r.mu.Unlock()
r.waiters[job.JobId] = waiter{
ctx: ctx,
job: job,
respCh: respCh,
token: token,
}
go r.handleContextExpired(ctx, job.JobId)
errCh := r.executor.Execute(ctx, pt, job.JobId, token, r.cert, r.addr)
go r.handleExecError(job.JobId, errCh)
}
func (r *remoteConnector) handleContextExpired(ctx context.Context, jobID string) {
<-ctx.Done()
r.mu.Lock()
defer r.mu.Unlock()
w, ok := r.waiters[jobID]
if !ok {
// something else already responded.
return
}
delete(r.waiters, jobID)
// separate goroutine, so we don't hold the lock while trying to write
// to the channel.
go func() {
w.respCh <- agpl.ConnectResponse{
Job: w.job,
Error: ctx.Err(),
}
}()
}
func (r *remoteConnector) handleExecError(jobID string, errCh <-chan error) {
err := <-errCh
if err == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
w, ok := r.waiters[jobID]
if !ok {
// something else already responded.
return
}
delete(r.waiters, jobID)
// separate goroutine, so we don't hold the lock while trying to write
// to the channel.
go func() {
w.respCh <- agpl.ConnectResponse{
Job: w.job,
Error: err,
}
}()
}
func errResponse(job *proto.AcquiredJob, respCh chan<- agpl.ConnectResponse, err error) {
respCh <- agpl.ConnectResponse{
Job: job,
Error: err,
}
}
// EphemeralEcho starts an Echo provisioner that connects to the provisioner daemon,
// handles one job, then exits.
func EphemeralEcho(
ctx context.Context,
logger slog.Logger,
cacheDir, jobID, token, daemonCert, daemonAddress string,
) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
workdir := filepath.Join(cacheDir, "echo")
err := os.MkdirAll(workdir, 0o777)
if err != nil {
return xerrors.Errorf("create workdir %s: %w", workdir, err)
}
conn, err := DialTLS(ctx, daemonCert, daemonAddress)
if err != nil {
return err
}
defer conn.Close()
err = AuthenticateProvisioner(conn, token, jobID)
if err != nil {
return err
}
// It's a little confusing: the provisioner is the client with respect to
// TLS, but the server with respect to dRPC.
exitErr := echo.Serve(ctx, &provisionersdk.ServeOptions{
Conn: conn,
Logger: logger.Named("echo"),
WorkDirectory: workdir,
})
logger.Debug(ctx, "echo.Serve done", slog.Error(exitErr))
if xerrors.Is(exitErr, context.Canceled) {
return nil
}
return exitErr
}
// DialTLS establishes a TLS connection to the given addr using the given cert
// as the root CA
func DialTLS(ctx context.Context, cert, addr string) (*tls.Conn, error) {
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM([]byte(cert))
if !ok {
return nil, xerrors.New("failed to parse daemon certificate")
}
cfg := &tls.Config{RootCAs: roots, MinVersion: tls.VersionTLS13, ServerName: serverName}
d := net.Dialer{}
nc, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, xerrors.Errorf("dial: %w", err)
}
tc := tls.Client(nc, cfg)
// Explicitly handshake so we don't have to mess with setting read
// and write deadlines.
err = tc.HandshakeContext(ctx)
if err != nil {
_ = nc.Close()
return nil, xerrors.Errorf("TLS handshake: %w", err)
}
return tc, nil
}
// Authentication Protocol:
//
// Ephemeral provisioners connect to the connector using TLS. This allows the
// provisioner to authenticate the daemon/connector based on the TLS certificate
// delivered to the provisioner out-of-band.
//
// The daemon/connector authenticates the provisioner by jobID and token, which
// are sent over the TLS connection separated by newlines. The daemon/connector
// responds with a 3-byte response to complete the handshake.
//
// Although the token is unique to the job and unambiguous, we also send the
// jobID. This allows the daemon/connector to look up the job based on public
// information (jobID), shielding the token from timing attacks. I'm not sure
// how practical a timing attack against an in-memory golang map is, but it's
// better to avoid it entirely. After the job is looked up by jobID, we do a
// constant time compare on the token to authenticate.
//
// Also note that we don't really have to worry about cross-version
// compatibility in this protocol, since the provisioners are always started by
// the same daemon/connector as they connect to.
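//
// An illustrative exchange (the job ID and token values are placeholders):
//
//	provisioner -> daemon: "<jobID>\n<token>\n"
//	daemon -> provisioner: "OK\n" on success, or "IJ\n" / "IT\n" for an
//	invalid job ID or an invalid token respectively.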
// Responses are all exactly 3 bytes so that we don't have to use a scanner,
// which might accidentally buffer some of the first dRPC request.
const (
responseOK = "OK\n"
responseInvalidJobID = "IJ\n"
responseInvalidToken = "IT\n"
)
// serverName is the name on the x509 certificate the daemon/connector
// generates. The name doesn't matter as long as both sides agree, since the
// provisioners get the IP address directly. It is also fine to reuse: each
// daemon generates a unique private key and self-signs, so a provisioner will
// not successfully authenticate to a different provisionerd.
const serverName = "provisionerd"
// AuthenticateProvisioner performs the provisioner's side of the authentication
// protocol.
func AuthenticateProvisioner(conn io.ReadWriter, token, jobID string) error {
sb := strings.Builder{}
_, _ = sb.WriteString(jobID)
_, _ = sb.WriteString("\n")
_, _ = sb.WriteString(token)
_, _ = sb.WriteString("\n")
_, err := conn.Write([]byte(sb.String()))
if err != nil {
return xerrors.Errorf("failed to write token: %w", err)
}
b := make([]byte, 3)
_, err = conn.Read(b)
if err != nil {
return xerrors.Errorf("failed to read token resp: %w", err)
}
if string(b) != responseOK {
// convert to a human-readable format
var reason string
switch string(b) {
case responseInvalidJobID:
reason = "invalid job ID"
case responseInvalidToken:
reason = "invalid token"
default:
reason = fmt.Sprintf("unknown response code: %s", b)
}
return xerrors.Errorf("authenticate protocol error: %s", reason)
}
return nil
}
// authenticate performs the daemon/connector's side of the authentication
// protocol.
func (r *remoteConnector) authenticate(conn io.ReadWriter) (waiter, error) {
// it's fine to use a scanner here because the provisioner side doesn't hand
// off the connection to the dRPC handler until after we send our response.
scn := bufio.NewScanner(conn)
if ok := scn.Scan(); !ok {
return waiter{}, xerrors.Errorf("failed to receive jobID: %w", scn.Err())
}
jobID := scn.Text()
if ok := scn.Scan(); !ok {
return waiter{}, xerrors.Errorf("failed to receive job token: %w", scn.Err())
}
token := scn.Text()
w, err := r.pullWaiter(jobID, token)
if err == nil {
_, err = conn.Write([]byte(responseOK))
if err != nil {
err = xerrors.Errorf("failed to write authentication response: %w", err)
// if we fail here, it's our responsibility to send the error response on the respCh
// because we're not going to return the waiter to the caller.
go errResponse(w.job, w.respCh, err)
return waiter{}, err
}
return w, nil
}
if xerrors.Is(err, errInvalidJobID) {
_, wErr := conn.Write([]byte(responseInvalidJobID))
r.logger.Debug(r.ctx, "responded invalid jobID", slog.Error(wErr))
}
if xerrors.Is(err, errInvalidToken) {
_, wErr := conn.Write([]byte(responseInvalidToken))
r.logger.Debug(r.ctx, "responded invalid token", slog.Error(wErr))
}
return waiter{}, err
}

View File

@ -0,0 +1,371 @@
package provisionerd_test
import (
"context"
"crypto/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/enterprise/provisionerd"
"github.com/coder/coder/v2/provisioner/echo"
agpl "github.com/coder/coder/v2/provisionerd"
"github.com/coder/coder/v2/provisionerd/proto"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestRemoteConnector_Mainline(t *testing.T) {
t.Parallel()
cases := []struct {
name string
smokescreen bool
}{
{name: "NoSmokescreen", smokescreen: false},
{name: "Smokescreen", smokescreen: true},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
exec := &testExecutor{
t: t,
logger: logger,
smokescreen: tc.smokescreen,
}
uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
require.NoError(t, err)
respCh := make(chan agpl.ConnectResponse)
job := &proto.AcquiredJob{
JobId: "test-job",
Provisioner: string(database.ProvisionerTypeEcho),
}
uut.Connect(ctx, job, respCh)
var resp agpl.ConnectResponse
select {
case <-ctx.Done():
t.Error("timeout waiting for connect response")
case resp = <-respCh:
// OK
}
require.NoError(t, resp.Error)
require.Equal(t, job, resp.Job)
require.NotNil(t, resp.Client)
// check that we can communicate with the provisioner
er := &echo.Responses{
Parse: echo.ParseComplete,
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: echo.PlanComplete,
}
arc, err := echo.Tar(er)
require.NoError(t, err)
c := resp.Client
s, err := c.Session(ctx)
require.NoError(t, err)
err = s.Send(&sdkproto.Request{Type: &sdkproto.Request_Config{Config: &sdkproto.Config{
TemplateSourceArchive: arc,
}}})
require.NoError(t, err)
err = s.Send(&sdkproto.Request{Type: &sdkproto.Request_Parse{Parse: &sdkproto.ParseRequest{}}})
require.NoError(t, err)
r, err := s.Recv()
require.NoError(t, err)
require.IsType(t, &sdkproto.Response_Parse{}, r.Type)
})
}
}
func TestRemoteConnector_BadToken(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
exec := &testExecutor{
t: t,
logger: logger,
overrideToken: "bad-token",
}
uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
require.NoError(t, err)
respCh := make(chan agpl.ConnectResponse)
job := &proto.AcquiredJob{
JobId: "test-job",
Provisioner: string(database.ProvisionerTypeEcho),
}
uut.Connect(ctx, job, respCh)
var resp agpl.ConnectResponse
select {
case <-ctx.Done():
t.Fatal("timeout waiting for connect response")
case resp = <-respCh:
// OK
}
require.Equal(t, job, resp.Job)
require.ErrorContains(t, resp.Error, "invalid token")
}
func TestRemoteConnector_BadJobID(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
exec := &testExecutor{
t: t,
logger: logger,
overrideJobID: "bad-job",
}
uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
require.NoError(t, err)
respCh := make(chan agpl.ConnectResponse)
job := &proto.AcquiredJob{
JobId: "test-job",
Provisioner: string(database.ProvisionerTypeEcho),
}
uut.Connect(ctx, job, respCh)
var resp agpl.ConnectResponse
select {
case <-ctx.Done():
t.Fatal("timeout waiting for connect response")
case resp = <-respCh:
// OK
}
require.Equal(t, job, resp.Job)
require.ErrorContains(t, resp.Error, "invalid job ID")
}
func TestRemoteConnector_BadCert(t *testing.T) {
t.Parallel()
_, cert, err := provisionerd.GenCert()
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
exec := &testExecutor{
t: t,
logger: logger,
overrideCert: string(cert),
}
uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
require.NoError(t, err)
respCh := make(chan agpl.ConnectResponse)
job := &proto.AcquiredJob{
JobId: "test-job",
Provisioner: string(database.ProvisionerTypeEcho),
}
uut.Connect(ctx, job, respCh)
var resp agpl.ConnectResponse
select {
case <-ctx.Done():
t.Fatal("timeout waiting for connect response")
case resp = <-respCh:
// OK
}
require.Equal(t, job, resp.Job)
require.ErrorContains(t, resp.Error, "certificate signed by unknown authority")
}
func TestRemoteConnector_Fuzz(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
exec := newFuzzExecutor(t, logger)
uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
require.NoError(t, err)
respCh := make(chan agpl.ConnectResponse)
job := &proto.AcquiredJob{
JobId: "test-job",
Provisioner: string(database.ProvisionerTypeEcho),
}
connectCtx, connectCtxCancel := context.WithCancel(ctx)
defer connectCtxCancel()
uut.Connect(connectCtx, job, respCh)
select {
case <-ctx.Done():
t.Fatal("timeout waiting for fuzzer")
case <-exec.done:
// Connector hung up on the fuzzer
}
require.Less(t, exec.bytesFuzzed, 2<<20, "should not allow more than 1 MiB")
connectCtxCancel()
var resp agpl.ConnectResponse
select {
case <-ctx.Done():
t.Fatal("timeout waiting for connect response")
case resp = <-respCh:
// OK
}
require.Equal(t, job, resp.Job)
require.ErrorIs(t, resp.Error, context.Canceled)
}
func TestRemoteConnector_CancelConnect(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
exec := &testExecutor{
t: t,
logger: logger,
dontStart: true,
}
uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
require.NoError(t, err)
respCh := make(chan agpl.ConnectResponse)
job := &proto.AcquiredJob{
JobId: "test-job",
Provisioner: string(database.ProvisionerTypeEcho),
}
connectCtx, connectCtxCancel := context.WithCancel(ctx)
defer connectCtxCancel()
uut.Connect(connectCtx, job, respCh)
connectCtxCancel()
var resp agpl.ConnectResponse
select {
case <-ctx.Done():
t.Fatal("timeout waiting for connect response")
case resp = <-respCh:
// OK
}
require.Equal(t, job, resp.Job)
require.ErrorIs(t, resp.Error, context.Canceled)
}
type testExecutor struct {
t *testing.T
logger slog.Logger
overrideToken string
overrideJobID string
overrideCert string
// dontStart simulates the case where everything looks good to the connector
// but the provisioner never starts.
dontStart bool
// smokescreen starts a connection that fails authentication before starting
// the real connection. Tests that failed connections don't interfere with
// real ones.
smokescreen bool
}
func (e *testExecutor) Execute(
ctx context.Context,
provisionerType database.ProvisionerType,
jobID, token, daemonCert, daemonAddress string,
) <-chan error {
assert.Equal(e.t, database.ProvisionerTypeEcho, provisionerType)
if e.overrideToken != "" {
token = e.overrideToken
}
if e.overrideJobID != "" {
jobID = e.overrideJobID
}
if e.overrideCert != "" {
daemonCert = e.overrideCert
}
cacheDir := e.t.TempDir()
errCh := make(chan error)
go func() {
defer close(errCh)
if e.smokescreen {
e.doSmokeScreen(ctx, jobID, daemonCert, daemonAddress)
}
if !e.dontStart {
err := provisionerd.EphemeralEcho(ctx, e.logger, cacheDir, jobID, token, daemonCert, daemonAddress)
e.logger.Debug(ctx, "provisioner done", slog.Error(err))
if err != nil {
errCh <- err
}
}
}()
return errCh
}
func (e *testExecutor) doSmokeScreen(ctx context.Context, jobID, daemonCert, daemonAddress string) {
conn, err := provisionerd.DialTLS(ctx, daemonCert, daemonAddress)
if !assert.NoError(e.t, err) {
return
}
defer conn.Close()
err = provisionerd.AuthenticateProvisioner(conn, "smokescreen", jobID)
assert.ErrorContains(e.t, err, "invalid token")
}
type fuzzExecutor struct {
t *testing.T
logger slog.Logger
done chan struct{}
bytesFuzzed int
}
func newFuzzExecutor(t *testing.T, logger slog.Logger) *fuzzExecutor {
return &fuzzExecutor{
t: t,
logger: logger,
done: make(chan struct{}),
bytesFuzzed: 0,
}
}
func (e *fuzzExecutor) Execute(
ctx context.Context,
_ database.ProvisionerType,
_, _, daemonCert, daemonAddress string,
) <-chan error {
errCh := make(chan error)
go func() {
defer close(errCh)
defer close(e.done)
conn, err := provisionerd.DialTLS(ctx, daemonCert, daemonAddress)
assert.NoError(e.t, err)
rb := make([]byte, 128)
for {
if ctx.Err() != nil {
e.t.Error("context canceled while fuzzing")
return
}
n, err := rand.Read(rb)
if err != nil {
e.t.Errorf("random read: %s", err)
}
if n < 128 {
e.t.Error("short random read")
return
}
// replace newlines so the Connector doesn't think we are done
// with the JobID
for i := 0; i < len(rb); i++ {
if rb[i] == '\n' || rb[i] == '\r' {
rb[i] = 'A'
}
}
n, err = conn.Write(rb)
e.bytesFuzzed += n
if err != nil {
return
}
}
}()
return errCh
}

View File

@ -0,0 +1,28 @@
package provisionerd
import (
"context"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/provisionerd/proto"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)
// LocalProvisioners is a Connector that stores a static set of in-process
// provisioners.
type LocalProvisioners map[string]sdkproto.DRPCProvisionerClient
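// For example, the in-process daemons elsewhere in this change construct one
// as (echoClient being an in-memory dRPC transport to an echo provisioner):
//
//	LocalProvisioners{
//		string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
//	}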
func (l LocalProvisioners) Connect(_ context.Context, job *proto.AcquiredJob, respCh chan<- ConnectResponse) {
r := ConnectResponse{Job: job}
p, ok := l[job.Provisioner]
if ok {
r.Client = p
} else {
r.Error = xerrors.Errorf("missing provisioner type %s", job.Provisioner)
}
go func() {
respCh <- r
}()
}

View File

@ -32,8 +32,24 @@ import (
// Dialer represents the function to create a daemon client connection.
type Dialer func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error)
// Provisioners maps provisioner ID to implementation.
type Provisioners map[string]sdkproto.DRPCProvisionerClient
// ConnectResponse is the response returned asynchronously from Connector.Connect,
// containing either the provisioner Client or an Error. The Job is also returned
// unaltered to disambiguate responses if the respCh is shared among multiple jobs.
type ConnectResponse struct {
Job *proto.AcquiredJob
Client sdkproto.DRPCProvisionerClient
Error error
}
// Connector allows the provisioner daemon to Connect to a provisioner
// for the given job.
type Connector interface {
// Connect to the correct provisioner for the given job. The response is
// delivered asynchronously over the respCh. If the provided context expires,
// the Connector may stop waiting for the provisioner and return an error
// response.
Connect(ctx context.Context, job *proto.AcquiredJob, respCh chan<- ConnectResponse)
}
// Options provides customizations to the behavior of a provisioner daemon.
type Options struct {
@ -47,7 +63,7 @@ type Options struct {
JobPollInterval time.Duration
JobPollJitter time.Duration
JobPollDebounce time.Duration
Provisioners Provisioners
Connector Connector
}
// New creates and starts a provisioner daemon.
@ -375,11 +391,13 @@ func (p *Server) acquireJob(ctx context.Context) {
p.opts.Logger.Debug(ctx, "acquired job", fields...)
provisioner, ok := p.opts.Provisioners[job.Provisioner]
if !ok {
respCh := make(chan ConnectResponse)
p.opts.Connector.Connect(ctx, job, respCh)
resp := <-respCh
if resp.Error != nil {
err := p.FailJob(ctx, &proto.FailedJob{
JobId: job.JobId,
Error: fmt.Sprintf("no provisioner %s", job.Provisioner),
Error: fmt.Sprintf("failed to connect to provisioner: %s", resp.Error),
})
if err != nil {
p.opts.Logger.Error(ctx, "provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err))
@ -394,7 +412,7 @@ func (p *Server) acquireJob(ctx context.Context) {
Updater: p,
QuotaCommitter: p,
Logger: p.opts.Logger.Named("runner"),
Provisioner: provisioner,
Provisioner: resp.Client,
UpdateInterval: p.opts.UpdateInterval,
ForceCancelInterval: p.opts.ForceCancelInterval,
LogDebounceInterval: p.opts.LogBufferInterval,

View File

@ -60,7 +60,7 @@ func TestProvisionerd(t *testing.T) {
})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil
}, provisionerd.Provisioners{})
}, provisionerd.LocalProvisioners{})
require.NoError(t, closer.Close())
})
@ -74,7 +74,7 @@ func TestProvisionerd(t *testing.T) {
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
defer close(completeChan)
return nil, xerrors.New("an error")
}, provisionerd.Provisioners{})
}, provisionerd.LocalProvisioners{})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
@ -101,7 +101,7 @@ func TestProvisionerd(t *testing.T) {
},
updateJob: noopUpdateJob,
}), nil
}, provisionerd.Provisioners{})
}, provisionerd.LocalProvisioners{})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
@ -141,7 +141,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(_ *provisionersdk.Session, _ *sdkproto.ParseRequest, _ <-chan struct{}) *sdkproto.ParseComplete {
closerMutex.Lock()
@ -195,7 +195,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
@ -237,7 +237,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(
_ *provisionersdk.Session,
@ -304,7 +304,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(
s *provisionersdk.Session,
@ -398,7 +398,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
_ *provisionersdk.Session,
@ -472,7 +472,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,
@ -553,7 +553,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,
@ -638,7 +638,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,
@ -714,7 +714,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,
@ -800,7 +800,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,
@ -886,7 +886,7 @@ func TestProvisionerd(t *testing.T) {
}()
}
return client, nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
_ *provisionersdk.Session,
@ -971,7 +971,7 @@ func TestProvisionerd(t *testing.T) {
}()
}
return client, nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
_ *provisionersdk.Session,
@ -1055,7 +1055,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.Provisioners{
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,
@ -1103,12 +1103,12 @@ func createTar(t *testing.T, files map[string]string) []byte {
}
// Creates a provisionerd implementation with the provided dialer and provisioners.
func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners provisionerd.Provisioners) *provisionerd.Server {
func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, connector provisionerd.LocalProvisioners) *provisionerd.Server {
server := provisionerd.New(dialer, &provisionerd.Options{
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("provisionerd").Leveled(slog.LevelDebug),
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 50 * time.Millisecond,
Provisioners: provisioners,
Connector: connector,
})
t.Cleanup(func() {
_ = server.Close()

View File

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/yamux"
"github.com/valyala/fasthttp/fasthttputil"
"golang.org/x/xerrors"
"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
@ -21,8 +22,10 @@ import (
// ServeOptions are configurations to serve a provisioner.
type ServeOptions struct {
// Conn specifies a custom transport to serve the dRPC connection.
Listener net.Listener
// Listener serves multiple connections. Cannot be combined with Conn.
Listener net.Listener
// Conn is a single connection to serve. Cannot be combined with Listener.
Conn drpc.Transport
Logger slog.Logger
WorkDirectory string
}
@ -38,8 +41,11 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error {
if options == nil {
options = &ServeOptions{}
}
// Default to using stdio.
if options.Listener == nil {
if options.Listener != nil && options.Conn != nil {
return xerrors.New("specify Listener or Conn, not both")
}
// Default to using stdio with yamux as a Listener
if options.Listener == nil && options.Conn == nil {
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
stdio, err := yamux.Server(&readWriteCloser{
@ -75,10 +81,12 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error {
return xerrors.Errorf("register provisioner: %w", err)
}
srv := drpcserver.New(&tracing.DRPCHandler{Handler: mux})
// Only serve a single connection on the transport.
// Transports are not multiplexed, and provisioners are
// short-lived processes that can be executed concurrently.
err = srv.Serve(ctx, options.Listener)
if options.Listener != nil {
err = srv.Serve(ctx, options.Listener)
} else if options.Conn != nil {
err = srv.ServeOne(ctx, options.Conn)
}
if err != nil {
if errors.Is(err, io.EOF) ||
errors.Is(err, context.Canceled) ||

View File

@ -2,14 +2,17 @@ package provisionersdk_test
import (
"context"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"storj.io/drpc/drpcconn"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
func TestMain(m *testing.M) {
@ -18,7 +21,7 @@ func TestMain(m *testing.M) {
func TestProvisionerSDK(t *testing.T) {
t.Parallel()
t.Run("Serve", func(t *testing.T) {
t.Run("ServeListener", func(t *testing.T) {
t.Parallel()
client, server := provisionersdk.MemTransportPipe()
defer client.Close()
@ -72,6 +75,61 @@ func TestProvisionerSDK(t *testing.T) {
})
require.NoError(t, err)
})
t.Run("ServeConn", func(t *testing.T) {
t.Parallel()
client, server := net.Pipe()
defer client.Close()
defer server.Close()
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancelFunc()
srvErr := make(chan error, 1)
go func() {
err := provisionersdk.Serve(ctx, unimplementedServer{}, &provisionersdk.ServeOptions{
Conn: server,
WorkDirectory: t.TempDir(),
})
srvErr <- err
}()
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
s, err := api.Session(ctx)
require.NoError(t, err)
err = s.Send(&proto.Request{Type: &proto.Request_Config{Config: &proto.Config{}}})
require.NoError(t, err)
err = s.Send(&proto.Request{Type: &proto.Request_Parse{Parse: &proto.ParseRequest{}}})
require.NoError(t, err)
msg, err := s.Recv()
require.NoError(t, err)
require.Equal(t, "unimplemented", msg.GetParse().GetError())
err = s.Send(&proto.Request{Type: &proto.Request_Plan{Plan: &proto.PlanRequest{}}})
require.NoError(t, err)
msg, err = s.Recv()
require.NoError(t, err)
// Plan has no error so that we're allowed to run Apply
require.Equal(t, "", msg.GetPlan().GetError())
err = s.Send(&proto.Request{Type: &proto.Request_Apply{Apply: &proto.ApplyRequest{}}})
require.NoError(t, err)
msg, err = s.Recv()
require.NoError(t, err)
require.Equal(t, "unimplemented", msg.GetApply().GetError())
// Check provisioner closes when the connection does
err = s.Close()
require.NoError(t, err)
err = api.DRPCConn().Close()
require.NoError(t, err)
select {
case <-ctx.Done():
t.Fatal("timeout waiting for provisioner")
case err = <-srvErr:
require.NoError(t, err)
}
})
}
type unimplementedServer struct{}