mirror of https://github.com/coder/coder.git
feat: add support for networked provisioners (#9593)
* Refactor provisionerd to use interface to connect to provisioners

  Signed-off-by: Spike Curtis <spike@coder.com>

* feat: add support for networked provisioners

  Signed-off-by: Spike Curtis <spike@coder.com>

* fix token length and linting

  Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent 8b51a2f3c5
commit 11b6068112
@@ -1309,7 +1309,7 @@ func newProvisionerDaemon(
return nil, xerrors.Errorf("mkdir work dir: %w", err)
}

-provisioners := provisionerd.Provisioners{}
+connector := provisionerd.LocalProvisioners{}
if cfg.Provisioner.DaemonsEcho {
echoClient, echoServer := provisionersdk.MemTransportPipe()
wg.Add(1)

@@ -1336,7 +1336,7 @@ func newProvisionerDaemon(
}
}
}()
-provisioners[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient)
+connector[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient)
} else {
tfDir := filepath.Join(cacheDir, "tf")
err = os.MkdirAll(tfDir, 0o700)

@@ -1375,7 +1375,7 @@ func newProvisionerDaemon(
}
}()

-provisioners[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient)
+connector[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient)
}

debounce := time.Second

@@ -1390,7 +1390,7 @@ func newProvisionerDaemon(
JobPollDebounce: debounce,
UpdateInterval: time.Second,
ForceCancelInterval: cfg.Provisioner.ForceCancelInterval.Value(),
-Provisioners: provisioners,
+Connector: connector,
TracerProvider: coderAPI.TracerProvider,
Metrics: &metrics,
}), nil
@@ -484,7 +484,7 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer {
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
-Provisioners: provisionerd.Provisioners{
+Connector: provisionerd.LocalProvisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
},
})

@@ -524,7 +524,7 @@ func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uui
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
-Provisioners: provisionerd.Provisioners{
+Connector: provisionerd.LocalProvisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
},
})
@@ -124,7 +124,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {

logger.Info(ctx, "starting provisioner daemon", slog.F("tags", tags))

-provisioners := provisionerd.Provisioners{
+connector := provisionerd.LocalProvisioners{
string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(terraformClient),
}
srv := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {

@@ -140,7 +140,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
JobPollInterval: pollInterval,
JobPollJitter: pollJitter,
UpdateInterval: 500 * time.Millisecond,
-Provisioners: provisioners,
+Connector: connector,
})

var exitErr error
@@ -253,7 +253,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
errCh <- err
}()

-provisioners := provisionerd.Provisioners{
+connector := provisionerd.LocalProvisioners{
string(database.ProvisionerTypeEcho): proto.NewDRPCProvisionerClient(terraformClient),
}
another := codersdk.New(client.URL)

@@ -269,8 +269,8 @@ func TestProvisionerDaemonServe(t *testing.T) {
PreSharedKey: "provisionersftw",
})
}, &provisionerd.Options{
-Logger: logger.Named("provisionerd"),
-Provisioners: provisioners,
+Logger: logger.Named("provisionerd"),
+Connector: connector,
})
defer pd.Close()
@@ -0,0 +1,466 @@
package provisionerd

import (
    "bufio"
    "context"
    "crypto/ecdsa"
    "crypto/elliptic"
    "crypto/rand"
    "crypto/subtle"
    "crypto/tls"
    "crypto/x509"
    "crypto/x509/pkix"
    "encoding/base64"
    "encoding/pem"
    "fmt"
    "io"
    "math/big"
    "net"
    "os"
    "path/filepath"
    "strings"
    "sync"
    "time"

    "golang.org/x/xerrors"
    "storj.io/drpc/drpcconn"

    "cdr.dev/slog"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/provisioner/echo"
    agpl "github.com/coder/coder/v2/provisionerd"
    "github.com/coder/coder/v2/provisionerd/proto"
    "github.com/coder/coder/v2/provisionersdk"
    sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)

// Executor is responsible for executing the remote provisioners.
//
// TODO: this interface is where we will run Kubernetes Jobs in a future
// version; right now, only the unit tests implement this interface.
type Executor interface {
    // Execute a provisioner that connects back to the remoteConnector. errCh
    // allows signalling of errors asynchronously and is closed on completion
    // with no error.
    Execute(
        ctx context.Context,
        provisionerType database.ProvisionerType,
        jobID, token, daemonCert, daemonAddress string) (errCh <-chan error)
}

type waiter struct {
    ctx    context.Context
    job    *proto.AcquiredJob
    respCh chan<- agpl.ConnectResponse
    token  string
}

type remoteConnector struct {
    ctx      context.Context
    executor Executor
    cert     string
    addr     string
    listener net.Listener
    logger   slog.Logger
    tlsCfg   *tls.Config

    mu      sync.Mutex
    waiters map[string]waiter
}

func NewRemoteConnector(ctx context.Context, logger slog.Logger, exec Executor) (agpl.Connector, error) {
    // nolint: gosec
    listener, err := net.Listen("tcp", ":0")
    if err != nil {
        return nil, xerrors.Errorf("failed to listen: %w", err)
    }
    go func() {
        <-ctx.Done()
        ce := listener.Close()
        logger.Debug(ctx, "listener closed", slog.Error(ce))
    }()
    r := &remoteConnector{
        ctx:      ctx,
        executor: exec,
        listener: listener,
        addr:     listener.Addr().String(),
        logger:   logger,
        waiters:  make(map[string]waiter),
    }
    err = r.genCert()
    if err != nil {
        return nil, xerrors.Errorf("failed to generate certificate: %w", err)
    }
    go r.listenLoop()
    return r, nil
}

func (r *remoteConnector) genCert() error {
    privateKey, cert, err := GenCert()
    if err != nil {
        return err
    }
    r.cert = string(cert)
    privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
    if err != nil {
        return xerrors.Errorf("failed to marshal private key: %w", err)
    }
    pkPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes})
    certKey, err := tls.X509KeyPair(cert, pkPEM)
    if err != nil {
        return xerrors.Errorf("failed to create TLS certificate: %w", err)
    }
    r.tlsCfg = &tls.Config{Certificates: []tls.Certificate{certKey}, MinVersion: tls.VersionTLS13}
    return nil
}

// GenCert is a helper function that generates a private key and certificate. It
// is exported so that we can test a certificate generated in exactly the same
// way, but with a different private key.
func GenCert() (*ecdsa.PrivateKey, []byte, error) {
    privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    if err != nil {
        return nil, nil, xerrors.Errorf("generate private key: %w", err)
    }
    template := x509.Certificate{
        SerialNumber: big.NewInt(1),
        Subject: pkix.Name{
            CommonName: "Coder Provisioner Daemon",
        },
        DNSNames:  []string{serverName},
        NotBefore: time.Now(),
        // cert is valid for 5 years, which is much longer than we expect this
        // process to stay up. The idea is that the certificate is self-signed
        // and is valid for as long as the daemon is up and starting new remote
        // provisioners
        NotAfter: time.Now().Add(time.Hour * 24 * 365 * 5),

        KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
        ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
        BasicConstraintsValid: true,
    }

    derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
    if err != nil {
        return nil, nil, xerrors.Errorf("failed to create certificate: %w", err)
    }
    cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
    return privateKey, cert, nil
}

func (r *remoteConnector) listenLoop() {
    for {
        conn, err := r.listener.Accept()
        if err != nil {
            r.logger.Info(r.ctx, "stopping listenLoop", slog.Error(err))
            return
        }
        go r.handleConn(conn)
    }
}

func (r *remoteConnector) handleConn(conn net.Conn) {
    logger := r.logger.With(slog.F("remote_addr", conn.RemoteAddr()))

    // If we hit an error while setting up, we want to close the connection.
    // This construction makes the default to close until we explicitly set
    // closeConn = false just before handing the connection over the respCh.
    closeConn := true
    defer func() {
        if closeConn {
            ce := conn.Close()
            logger.Debug(r.ctx, "closed connection", slog.Error(ce))
        }
    }()

    tlsConn := tls.Server(conn, r.tlsCfg)
    err := tlsConn.HandshakeContext(r.ctx)
    if err != nil {
        logger.Info(r.ctx, "failed TLS handshake", slog.Error(err))
        return
    }
    w, err := r.authenticate(tlsConn)
    if err != nil {
        logger.Info(r.ctx, "failed provisioner authentication", slog.Error(err))
        return
    }
    logger = logger.With(slog.F("job_id", w.job.JobId))
    logger.Info(r.ctx, "provisioner connected")
    closeConn = false // we're passing the conn over the channel
    w.respCh <- agpl.ConnectResponse{
        Job:    w.job,
        Client: sdkproto.NewDRPCProvisionerClient(drpcconn.New(tlsConn)),
    }
}

var (
    errInvalidJobID = xerrors.New("invalid jobID")
    errInvalidToken = xerrors.New("invalid token")
)

func (r *remoteConnector) pullWaiter(jobID, token string) (waiter, error) {
    r.mu.Lock()
    defer r.mu.Unlock()
    // provisioners authenticate with a jobID and token. The jobID is required
    // because we need to use public information for the lookup, to avoid timing
    // attacks against the token.
    w, ok := r.waiters[jobID]
    if !ok {
        return waiter{}, errInvalidJobID
    }
    if subtle.ConstantTimeCompare([]byte(token), []byte(w.token)) == 1 {
        delete(r.waiters, jobID)
        return w, nil
    }
    return waiter{}, errInvalidToken
}

func (r *remoteConnector) Connect(
    ctx context.Context, job *proto.AcquiredJob, respCh chan<- agpl.ConnectResponse,
) {
    pt := database.ProvisionerType(job.Provisioner)
    if !pt.Valid() {
        go errResponse(job, respCh, xerrors.Errorf("invalid provisioner type: %s", job.Provisioner))
    }
    tb := make([]byte, 16) // 128-bit token
    n, err := rand.Read(tb)
    if err != nil {
        go errResponse(job, respCh, err)
        return
    }
    if n != 16 {
        go errResponse(job, respCh, xerrors.New("short read generating token"))
    }
    token := base64.StdEncoding.EncodeToString(tb)
    r.mu.Lock()
    defer r.mu.Unlock()
    r.waiters[job.JobId] = waiter{
        ctx:    ctx,
        job:    job,
        respCh: respCh,
        token:  token,
    }
    go r.handleContextExpired(ctx, job.JobId)
    errCh := r.executor.Execute(ctx, pt, job.JobId, token, r.cert, r.addr)
    go r.handleExecError(job.JobId, errCh)
}

func (r *remoteConnector) handleContextExpired(ctx context.Context, jobID string) {
    <-ctx.Done()
    r.mu.Lock()
    defer r.mu.Unlock()
    w, ok := r.waiters[jobID]
    if !ok {
        // something else already responded.
        return
    }
    delete(r.waiters, jobID)
    // separate goroutine, so we don't hold the lock while trying to write
    // to the channel.
    go func() {
        w.respCh <- agpl.ConnectResponse{
            Job:   w.job,
            Error: ctx.Err(),
        }
    }()
}

func (r *remoteConnector) handleExecError(jobID string, errCh <-chan error) {
    err := <-errCh
    if err == nil {
        return
    }
    r.mu.Lock()
    defer r.mu.Unlock()
    w, ok := r.waiters[jobID]
    if !ok {
        // something else already responded.
        return
    }
    delete(r.waiters, jobID)
    // separate goroutine, so we don't hold the lock while trying to write
    // to the channel.
    go func() {
        w.respCh <- agpl.ConnectResponse{
            Job:   w.job,
            Error: err,
        }
    }()
}

func errResponse(job *proto.AcquiredJob, respCh chan<- agpl.ConnectResponse, err error) {
    respCh <- agpl.ConnectResponse{
        Job:   job,
        Error: err,
    }
}

// EphemeralEcho starts an Echo provisioner that connects to provisioner daemon,
// handles one job, then exits.
func EphemeralEcho(
    ctx context.Context,
    logger slog.Logger,
    cacheDir, jobID, token, daemonCert, daemonAddress string,
) error {
    ctx, cancel := context.WithCancel(ctx)
    defer cancel()

    workdir := filepath.Join(cacheDir, "echo")
    err := os.MkdirAll(workdir, 0o777)
    if err != nil {
        return xerrors.Errorf("create workdir %s: %w", workdir, err)
    }
    conn, err := DialTLS(ctx, daemonCert, daemonAddress)
    if err != nil {
        return err
    }
    defer conn.Close()
    err = AuthenticateProvisioner(conn, token, jobID)
    if err != nil {
        return err
    }
    // so it's a little confusing, but the provisioner is the client with
    // respect to TLS, but is the server with respect to dRPC
    exitErr := echo.Serve(ctx, &provisionersdk.ServeOptions{
        Conn:          conn,
        Logger:        logger.Named("echo"),
        WorkDirectory: workdir,
    })
    logger.Debug(ctx, "echo.Serve done", slog.Error(exitErr))

    if xerrors.Is(exitErr, context.Canceled) {
        return nil
    }
    return exitErr
}

// DialTLS establishes a TLS connection to the given addr using the given cert
// as the root CA
func DialTLS(ctx context.Context, cert, addr string) (*tls.Conn, error) {
    roots := x509.NewCertPool()
    ok := roots.AppendCertsFromPEM([]byte(cert))
    if !ok {
        return nil, xerrors.New("failed to parse daemon certificate")
    }
    cfg := &tls.Config{RootCAs: roots, MinVersion: tls.VersionTLS13, ServerName: serverName}
    d := net.Dialer{}
    nc, err := d.DialContext(ctx, "tcp", addr)
    if err != nil {
        return nil, xerrors.Errorf("dial: %w", err)
    }
    tc := tls.Client(nc, cfg)
    // Explicitly handshake so we don't have to mess with setting read
    // and write deadlines.
    err = tc.HandshakeContext(ctx)
    if err != nil {
        _ = nc.Close()
        return nil, xerrors.Errorf("TLS handshake: %w", err)
    }
    return tc, nil
}

// Authentication Protocol:
//
// Ephemeral provisioners connect to the connector using TLS. This allows the
// provisioner to authenticate the daemon/connector based on the TLS certificate
// delivered to the provisioner out-of-band.
//
// The daemon/connector authenticates the provisioner by jobID and token, which
// are sent over the TLS connection separated by newlines. The daemon/connector
// responds with a 3-byte response to complete the handshake.
//
// Although the token is unique to the job and unambiguous, we also send the
// jobID. This allows the daemon/connector to look up the job based on public
// information (jobID), shielding the token from timing attacks. I'm not sure
// how practical a timing attack against an in-memory golang map is, but it's
// better to avoid it entirely. After the job is looked up by jobID, we do a
// constant time compare on the token to authenticate.
//
// Also note that we don't really have to worry about cross-version
// compatibility in this protocol, since the provisioners are always started by
// the same daemon/connector as they connect to.

// Responses are all exactly 3 bytes so that don't have to use a scanner
// which might accidentally buffer some of the first dRPC request.
const (
    responseOK           = "OK\n"
    responseInvalidJobID = "IJ\n"
    responseInvalidToken = "IT\n"
)

// serverName is the name on the x509 certificate the daemon/connector generates
// this name doesn't matter as long as both sides agree, since the provisioners
// get the IP address directly. It is also fine to reuse, since each generates
// a unique private key and self-signs, we will not correctly authenticate to
// a different provisionerd.
const serverName = "provisionerd"

// AuthenticateProvisioner performs the provisioner's side of the authentication
// protocol.
func AuthenticateProvisioner(conn io.ReadWriter, token, jobID string) error {
    sb := strings.Builder{}
    _, _ = sb.WriteString(jobID)
    _, _ = sb.WriteString("\n")
    _, _ = sb.WriteString(token)
    _, _ = sb.WriteString("\n")
    _, err := conn.Write([]byte(sb.String()))
    if err != nil {
        return xerrors.Errorf("failed to write token: %w", err)
    }
    b := make([]byte, 3)
    _, err = conn.Read(b)
    if err != nil {
        return xerrors.Errorf("failed to read token resp: %w", err)
    }
    if string(b) != responseOK {
        // convert to a human-readable format
        var reason string
        switch string(b) {
        case responseInvalidJobID:
            reason = "invalid job ID"
        case responseInvalidToken:
            reason = "invalid token"
        default:
            reason = fmt.Sprintf("unknown response code: %s", b)
        }
        return xerrors.Errorf("authenticate protocol error: %s", reason)
    }
    return nil
}

// authenticate performs the daemon/connector's side of the authentication
// protocol.
func (r *remoteConnector) authenticate(conn io.ReadWriter) (waiter, error) {
    // it's fine to use a scanner here because the provisioner side doesn't hand
    // off the connection to the dRPC handler until after we send our response.
    scn := bufio.NewScanner(conn)
    if ok := scn.Scan(); !ok {
        return waiter{}, xerrors.Errorf("failed to receive jobID: %w", scn.Err())
    }
    jobID := scn.Text()
    if ok := scn.Scan(); !ok {
        return waiter{}, xerrors.Errorf("failed to receive job token: %w", scn.Err())
    }
    token := scn.Text()
    w, err := r.pullWaiter(jobID, token)
    if err == nil {
        _, err = conn.Write([]byte(responseOK))
        if err != nil {
            err = xerrors.Errorf("failed to write authentication response: %w", err)
            // if we fail here, it's our responsibility to send the error response on the respCh
            // because we're not going to return the waiter to the caller.
            go errResponse(w.job, w.respCh, err)
            return waiter{}, err
        }
        return w, nil
    }
    if xerrors.Is(err, errInvalidJobID) {
        _, wErr := conn.Write([]byte(responseInvalidJobID))
        r.logger.Debug(r.ctx, "responded invalid jobID", slog.Error(wErr))
    }
    if xerrors.Is(err, errInvalidToken) {
        _, wErr := conn.Write([]byte(responseInvalidToken))
        r.logger.Debug(r.ctx, "responded invalid token", slog.Error(wErr))
    }
    return waiter{}, err
}
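The diff above does not show how a caller wires the remote connector and an Executor into a running daemon. Below is a minimal sketch of that wiring, not part of this commit: inProcessExecutor and startDaemon are hypothetical names, and the executor simply runs the ephemeral echo provisioner in-process (the place where a Kubernetes Job launcher could later go).

package main // illustrative sketch only

import (
    "context"

    "cdr.dev/slog"
    "github.com/coder/coder/v2/coderd/database"
    entprovisionerd "github.com/coder/coder/v2/enterprise/provisionerd"
    agpl "github.com/coder/coder/v2/provisionerd"
)

// inProcessExecutor is a hypothetical Executor that starts the ephemeral echo
// provisioner in the same process.
type inProcessExecutor struct {
    logger   slog.Logger
    cacheDir string
}

func (e inProcessExecutor) Execute(
    ctx context.Context,
    _ database.ProvisionerType,
    jobID, token, daemonCert, daemonAddress string,
) <-chan error {
    errCh := make(chan error, 1)
    go func() {
        defer close(errCh)
        // EphemeralEcho dials back to daemonAddress over TLS, authenticates
        // with jobID and token, serves exactly one job, then exits.
        err := entprovisionerd.EphemeralEcho(ctx, e.logger, e.cacheDir, jobID, token, daemonCert, daemonAddress)
        if err != nil {
            errCh <- err
        }
    }()
    return errCh
}

// startDaemon shows the connector being handed to provisionerd.New via
// Options.Connector; dialer is whatever Dialer the caller already uses.
func startDaemon(ctx context.Context, logger slog.Logger, dialer agpl.Dialer, cacheDir string) (*agpl.Server, error) {
    connector, err := entprovisionerd.NewRemoteConnector(ctx, logger.Named("connector"),
        inProcessExecutor{logger: logger, cacheDir: cacheDir})
    if err != nil {
        return nil, err
    }
    return agpl.New(dialer, &agpl.Options{
        Logger:    logger,
        Connector: connector,
    }), nil
}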
@@ -0,0 +1,371 @@
package provisionerd_test

import (
    "context"
    "crypto/rand"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "go.uber.org/goleak"

    "cdr.dev/slog"
    "cdr.dev/slog/sloggers/slogtest"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/enterprise/provisionerd"
    "github.com/coder/coder/v2/provisioner/echo"
    agpl "github.com/coder/coder/v2/provisionerd"
    "github.com/coder/coder/v2/provisionerd/proto"
    sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
    "github.com/coder/coder/v2/testutil"
)

func TestMain(m *testing.M) {
    goleak.VerifyTestMain(m)
}

func TestRemoteConnector_Mainline(t *testing.T) {
    t.Parallel()
    cases := []struct {
        name        string
        smokescreen bool
    }{
        {name: "NoSmokescreen", smokescreen: false},
        {name: "Smokescreen", smokescreen: true},
    }
    for _, tc := range cases {
        tc := tc
        t.Run(tc.name, func(t *testing.T) {
            t.Parallel()
            ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
            defer cancel()
            logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
            exec := &testExecutor{
                t:           t,
                logger:      logger,
                smokescreen: tc.smokescreen,
            }
            uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
            require.NoError(t, err)

            respCh := make(chan agpl.ConnectResponse)
            job := &proto.AcquiredJob{
                JobId:       "test-job",
                Provisioner: string(database.ProvisionerTypeEcho),
            }
            uut.Connect(ctx, job, respCh)
            var resp agpl.ConnectResponse
            select {
            case <-ctx.Done():
                t.Error("timeout waiting for connect response")
            case resp = <-respCh:
                // OK
            }
            require.NoError(t, resp.Error)
            require.Equal(t, job, resp.Job)
            require.NotNil(t, resp.Client)

            // check that we can communicate with the provisioner
            er := &echo.Responses{
                Parse:          echo.ParseComplete,
                ProvisionApply: echo.ApplyComplete,
                ProvisionPlan:  echo.PlanComplete,
            }
            arc, err := echo.Tar(er)
            require.NoError(t, err)
            c := resp.Client
            s, err := c.Session(ctx)
            require.NoError(t, err)
            err = s.Send(&sdkproto.Request{Type: &sdkproto.Request_Config{Config: &sdkproto.Config{
                TemplateSourceArchive: arc,
            }}})
            require.NoError(t, err)
            err = s.Send(&sdkproto.Request{Type: &sdkproto.Request_Parse{Parse: &sdkproto.ParseRequest{}}})
            require.NoError(t, err)
            r, err := s.Recv()
            require.NoError(t, err)
            require.IsType(t, &sdkproto.Response_Parse{}, r.Type)
        })
    }
}

func TestRemoteConnector_BadToken(t *testing.T) {
    t.Parallel()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    exec := &testExecutor{
        t:             t,
        logger:        logger,
        overrideToken: "bad-token",
    }
    uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
    require.NoError(t, err)

    respCh := make(chan agpl.ConnectResponse)
    job := &proto.AcquiredJob{
        JobId:       "test-job",
        Provisioner: string(database.ProvisionerTypeEcho),
    }
    uut.Connect(ctx, job, respCh)
    var resp agpl.ConnectResponse
    select {
    case <-ctx.Done():
        t.Fatal("timeout waiting for connect response")
    case resp = <-respCh:
        // OK
    }
    require.Equal(t, job, resp.Job)
    require.ErrorContains(t, resp.Error, "invalid token")
}

func TestRemoteConnector_BadJobID(t *testing.T) {
    t.Parallel()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    exec := &testExecutor{
        t:             t,
        logger:        logger,
        overrideJobID: "bad-job",
    }
    uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
    require.NoError(t, err)

    respCh := make(chan agpl.ConnectResponse)
    job := &proto.AcquiredJob{
        JobId:       "test-job",
        Provisioner: string(database.ProvisionerTypeEcho),
    }
    uut.Connect(ctx, job, respCh)
    var resp agpl.ConnectResponse
    select {
    case <-ctx.Done():
        t.Fatal("timeout waiting for connect response")
    case resp = <-respCh:
        // OK
    }
    require.Equal(t, job, resp.Job)
    require.ErrorContains(t, resp.Error, "invalid job ID")
}

func TestRemoteConnector_BadCert(t *testing.T) {
    t.Parallel()
    _, cert, err := provisionerd.GenCert()
    require.NoError(t, err)
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    exec := &testExecutor{
        t:            t,
        logger:       logger,
        overrideCert: string(cert),
    }
    uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
    require.NoError(t, err)

    respCh := make(chan agpl.ConnectResponse)
    job := &proto.AcquiredJob{
        JobId:       "test-job",
        Provisioner: string(database.ProvisionerTypeEcho),
    }
    uut.Connect(ctx, job, respCh)
    var resp agpl.ConnectResponse
    select {
    case <-ctx.Done():
        t.Fatal("timeout waiting for connect response")
    case resp = <-respCh:
        // OK
    }
    require.Equal(t, job, resp.Job)
    require.ErrorContains(t, resp.Error, "certificate signed by unknown authority")
}

func TestRemoteConnector_Fuzz(t *testing.T) {
    t.Parallel()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    exec := newFuzzExecutor(t, logger)
    uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
    require.NoError(t, err)

    respCh := make(chan agpl.ConnectResponse)
    job := &proto.AcquiredJob{
        JobId:       "test-job",
        Provisioner: string(database.ProvisionerTypeEcho),
    }

    connectCtx, connectCtxCancel := context.WithCancel(ctx)
    defer connectCtxCancel()

    uut.Connect(connectCtx, job, respCh)
    select {
    case <-ctx.Done():
        t.Fatal("timeout waiting for fuzzer")
    case <-exec.done:
        // Connector hung up on the fuzzer
    }
    require.Less(t, exec.bytesFuzzed, 2<<20, "should not allow more than 1 MiB")
    connectCtxCancel()
    var resp agpl.ConnectResponse
    select {
    case <-ctx.Done():
        t.Fatal("timeout waiting for connect response")
    case resp = <-respCh:
        // OK
    }
    require.Equal(t, job, resp.Job)
    require.ErrorIs(t, resp.Error, context.Canceled)
}

func TestRemoteConnector_CancelConnect(t *testing.T) {
    t.Parallel()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    exec := &testExecutor{
        t:         t,
        logger:    logger,
        dontStart: true,
    }
    uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec)
    require.NoError(t, err)

    respCh := make(chan agpl.ConnectResponse)
    job := &proto.AcquiredJob{
        JobId:       "test-job",
        Provisioner: string(database.ProvisionerTypeEcho),
    }

    connectCtx, connectCtxCancel := context.WithCancel(ctx)
    defer connectCtxCancel()

    uut.Connect(connectCtx, job, respCh)
    connectCtxCancel()
    var resp agpl.ConnectResponse
    select {
    case <-ctx.Done():
        t.Fatal("timeout waiting for connect response")
    case resp = <-respCh:
        // OK
    }
    require.Equal(t, job, resp.Job)
    require.ErrorIs(t, resp.Error, context.Canceled)
}

type testExecutor struct {
    t             *testing.T
    logger        slog.Logger
    overrideToken string
    overrideJobID string
    overrideCert  string
    // dontStart simulates when everything looks good to the connector but
    // the provisioner never starts
    dontStart bool
    // smokescreen starts a connection that fails authentication before starting
    // the real connection. Tests that failed connections don't interfere with
    // real ones.
    smokescreen bool
}

func (e *testExecutor) Execute(
    ctx context.Context,
    provisionerType database.ProvisionerType,
    jobID, token, daemonCert, daemonAddress string,
) <-chan error {
    assert.Equal(e.t, database.ProvisionerTypeEcho, provisionerType)
    if e.overrideToken != "" {
        token = e.overrideToken
    }
    if e.overrideJobID != "" {
        jobID = e.overrideJobID
    }
    if e.overrideCert != "" {
        daemonCert = e.overrideCert
    }
    cacheDir := e.t.TempDir()
    errCh := make(chan error)
    go func() {
        defer close(errCh)
        if e.smokescreen {
            e.doSmokeScreen(ctx, jobID, daemonCert, daemonAddress)
        }
        if !e.dontStart {
            err := provisionerd.EphemeralEcho(ctx, e.logger, cacheDir, jobID, token, daemonCert, daemonAddress)
            e.logger.Debug(ctx, "provisioner done", slog.Error(err))
            if err != nil {
                errCh <- err
            }
        }
    }()
    return errCh
}

func (e *testExecutor) doSmokeScreen(ctx context.Context, jobID, daemonCert, daemonAddress string) {
    conn, err := provisionerd.DialTLS(ctx, daemonCert, daemonAddress)
    if !assert.NoError(e.t, err) {
        return
    }
    defer conn.Close()
    err = provisionerd.AuthenticateProvisioner(conn, "smokescreen", jobID)
    assert.ErrorContains(e.t, err, "invalid token")
}

type fuzzExecutor struct {
    t           *testing.T
    logger      slog.Logger
    done        chan struct{}
    bytesFuzzed int
}

func newFuzzExecutor(t *testing.T, logger slog.Logger) *fuzzExecutor {
    return &fuzzExecutor{
        t:           t,
        logger:      logger,
        done:        make(chan struct{}),
        bytesFuzzed: 0,
    }
}

func (e *fuzzExecutor) Execute(
    ctx context.Context,
    _ database.ProvisionerType,
    _, _, daemonCert, daemonAddress string,
) <-chan error {
    errCh := make(chan error)
    go func() {
        defer close(errCh)
        defer close(e.done)
        conn, err := provisionerd.DialTLS(ctx, daemonCert, daemonAddress)
        assert.NoError(e.t, err)
        rb := make([]byte, 128)
        for {
            if ctx.Err() != nil {
                e.t.Error("context canceled while fuzzing")
                return
            }
            n, err := rand.Read(rb)
            if err != nil {
                e.t.Errorf("random read: %s", err)
            }
            if n < 128 {
                e.t.Error("short random read")
                return
            }
            // replace newlines so the Connector doesn't think we are done
            // with the JobID
            for i := 0; i < len(rb); i++ {
                if rb[i] == '\n' || rb[i] == '\r' {
                    rb[i] = 'A'
                }
            }
            n, err = conn.Write(rb)
            e.bytesFuzzed += n
            if err != nil {
                return
            }
        }
    }()
    return errCh
}
@@ -0,0 +1,28 @@
package provisionerd

import (
    "context"

    "golang.org/x/xerrors"

    "github.com/coder/coder/v2/provisionerd/proto"

    sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)

// LocalProvisioners is a Connector that stores a static set of in-process
// provisioners.
type LocalProvisioners map[string]sdkproto.DRPCProvisionerClient

func (l LocalProvisioners) Connect(_ context.Context, job *proto.AcquiredJob, respCh chan<- ConnectResponse) {
    r := ConnectResponse{Job: job}
    p, ok := l[job.Provisioner]
    if ok {
        r.Client = p
    } else {
        r.Error = xerrors.Errorf("missing provisioner type %s", job.Provisioner)
    }
    go func() {
        respCh <- r
    }()
}
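For contrast with the networked path, here is a minimal sketch of the local path: an in-process echo provisioner served over provisionersdk.MemTransportPipe and registered through LocalProvisioners, the same pattern the in-tree callers in this diff use. startLocalDaemon and dialer are illustrative names, and the exact MemTransportPipe/echo.Serve wiring is assumed from the surrounding hunks, not shown in full here.

package main // illustrative sketch only

import (
    "context"

    "cdr.dev/slog"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/provisioner/echo"
    "github.com/coder/coder/v2/provisionerd"
    "github.com/coder/coder/v2/provisionersdk"
    sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)

func startLocalDaemon(ctx context.Context, logger slog.Logger, dialer provisionerd.Dialer, workDir string) *provisionerd.Server {
    echoClient, echoServer := provisionersdk.MemTransportPipe()
    go func() {
        // Serve the echo provisioner on the server half of the in-memory pipe.
        _ = echo.Serve(ctx, &provisionersdk.ServeOptions{
            Listener:      echoServer,
            Logger:        logger.Named("echo"),
            WorkDirectory: workDir,
        })
    }()
    // LocalProvisioners satisfies the Connector interface directly.
    connector := provisionerd.LocalProvisioners{
        string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
    }
    return provisionerd.New(dialer, &provisionerd.Options{
        Logger:    logger,
        Connector: connector,
    })
}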
@@ -32,8 +32,24 @@ import (
// Dialer represents the function to create a daemon client connection.
type Dialer func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error)

-// Provisioners maps provisioner ID to implementation.
-type Provisioners map[string]sdkproto.DRPCProvisionerClient
+// ConnectResponse is the response returned asynchronously from Connector.Connect
+// containing either the Provisioner Client or an Error. The Job is also returned
+// unaltered to disambiguate responses if the respCh is shared among multiple jobs
+type ConnectResponse struct {
+    Job    *proto.AcquiredJob
+    Client sdkproto.DRPCProvisionerClient
+    Error  error
+}
+
+// Connector allows the provisioner daemon to Connect to a provisioner
+// for the given job.
+type Connector interface {
+    // Connect to the correct provisioner for the given job. The response is
+    // delivered asynchronously over the respCh. If the provided context expires,
+    // the Connector may stop waiting for the provisioner and return an error
+    // response.
+    Connect(ctx context.Context, job *proto.AcquiredJob, respCh chan<- ConnectResponse)
+}

// Options provides customizations to the behavior of a provisioner daemon.
type Options struct {

@@ -47,7 +63,7 @@ type Options struct {
JobPollInterval time.Duration
JobPollJitter time.Duration
JobPollDebounce time.Duration
-Provisioners Provisioners
+Connector Connector
}

// New creates and starts a provisioner daemon.

@@ -375,11 +391,13 @@ func (p *Server) acquireJob(ctx context.Context) {

p.opts.Logger.Debug(ctx, "acquired job", fields...)

-provisioner, ok := p.opts.Provisioners[job.Provisioner]
-if !ok {
+respCh := make(chan ConnectResponse)
+p.opts.Connector.Connect(ctx, job, respCh)
+resp := <-respCh
+if resp.Error != nil {
err := p.FailJob(ctx, &proto.FailedJob{
JobId: job.JobId,
-Error: fmt.Sprintf("no provisioner %s", job.Provisioner),
+Error: fmt.Sprintf("failed to connect to provisioner: %s", resp.Error),
})
if err != nil {
p.opts.Logger.Error(ctx, "provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err))

@@ -394,7 +412,7 @@ func (p *Server) acquireJob(ctx context.Context) {
Updater: p,
QuotaCommitter: p,
Logger: p.opts.Logger.Named("runner"),
-Provisioner: provisioner,
+Provisioner: resp.Client,
UpdateInterval: p.opts.UpdateInterval,
ForceCancelInterval: p.opts.ForceCancelInterval,
LogDebounceInterval: p.opts.LogBufferInterval,
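The Connector contract introduced above is small: deliver exactly one ConnectResponse for the job on respCh, carrying the original Job plus either a Client or an Error. To illustrate the contract (this is not part of the commit), a hypothetical decorator that logs connection attempts while delegating to another Connector could look like the sketch below; loggingConnector is an invented name.

package main // illustrative sketch only

import (
    "context"

    "cdr.dev/slog"
    "github.com/coder/coder/v2/provisionerd"
    "github.com/coder/coder/v2/provisionerd/proto"
)

// loggingConnector wraps another Connector and observes the single
// ConnectResponse before passing it through unchanged.
type loggingConnector struct {
    logger slog.Logger
    inner  provisionerd.Connector
}

func (c loggingConnector) Connect(ctx context.Context, job *proto.AcquiredJob, respCh chan<- provisionerd.ConnectResponse) {
    c.logger.Info(ctx, "connecting to provisioner",
        slog.F("job_id", job.JobId), slog.F("provisioner", job.Provisioner))
    // Intercept the inner response on a buffered channel so the inner
    // Connector never blocks, then forward it to the daemon's respCh.
    innerCh := make(chan provisionerd.ConnectResponse, 1)
    c.inner.Connect(ctx, job, innerCh)
    go func() {
        resp := <-innerCh
        if resp.Error != nil {
            c.logger.Warn(ctx, "provisioner connection failed", slog.Error(resp.Error))
        }
        respCh <- resp
    }()
}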
@@ -60,7 +60,7 @@ func TestProvisionerd(t *testing.T) {
})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil
-}, provisionerd.Provisioners{})
+}, provisionerd.LocalProvisioners{})
require.NoError(t, closer.Close())
})

@@ -74,7 +74,7 @@ func TestProvisionerd(t *testing.T) {
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
defer close(completeChan)
return nil, xerrors.New("an error")
-}, provisionerd.Provisioners{})
+}, provisionerd.LocalProvisioners{})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})

@@ -101,7 +101,7 @@ func TestProvisionerd(t *testing.T) {
},
updateJob: noopUpdateJob,
}), nil
-}, provisionerd.Provisioners{})
+}, provisionerd.LocalProvisioners{})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})

@@ -141,7 +141,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(_ *provisionersdk.Session, _ *sdkproto.ParseRequest, _ <-chan struct{}) *sdkproto.ParseComplete {
closerMutex.Lock()

@@ -195,7 +195,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))

@@ -237,7 +237,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(
_ *provisionersdk.Session,

@@ -304,7 +304,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(
s *provisionersdk.Session,

@@ -398,7 +398,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
_ *provisionersdk.Session,

@@ -472,7 +472,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,

@@ -553,7 +553,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,

@@ -638,7 +638,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,

@@ -714,7 +714,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,

@@ -800,7 +800,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,

@@ -886,7 +886,7 @@ func TestProvisionerd(t *testing.T) {
}()
}
return client, nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
_ *provisionersdk.Session,

@@ -971,7 +971,7 @@ func TestProvisionerd(t *testing.T) {
}()
}
return client, nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
_ *provisionersdk.Session,

@@ -1055,7 +1055,7 @@ func TestProvisionerd(t *testing.T) {
return &proto.Empty{}, nil
},
}), nil
-}, provisionerd.Provisioners{
+}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
plan: func(
s *provisionersdk.Session,

@@ -1103,12 +1103,12 @@ func createTar(t *testing.T, files map[string]string) []byte {
}

// Creates a provisionerd implementation with the provided dialer and provisioners.
-func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners provisionerd.Provisioners) *provisionerd.Server {
+func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, connector provisionerd.LocalProvisioners) *provisionerd.Server {
server := provisionerd.New(dialer, &provisionerd.Options{
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("provisionerd").Leveled(slog.LevelDebug),
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 50 * time.Millisecond,
-Provisioners: provisioners,
+Connector: connector,
})
t.Cleanup(func() {
_ = server.Close()
@@ -10,6 +10,7 @@ import (
"github.com/hashicorp/yamux"
"github.com/valyala/fasthttp/fasthttputil"
"golang.org/x/xerrors"
+"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"

@@ -21,8 +22,10 @@ import (

// ServeOptions are configurations to serve a provisioner.
type ServeOptions struct {
-// Conn specifies a custom transport to serve the dRPC connection.
-Listener net.Listener
+// Listener serves multiple connections. Cannot be combined with Conn.
+Listener net.Listener
+// Conn is a single connection to serve. Cannot be combined with Listener.
+Conn drpc.Transport
Logger slog.Logger
WorkDirectory string
}

@@ -38,8 +41,11 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error {
if options == nil {
options = &ServeOptions{}
}
-// Default to using stdio.
-if options.Listener == nil {
+if options.Listener != nil && options.Conn != nil {
+return xerrors.New("specify Listener or Conn, not both")
+}
+// Default to using stdio with yamux as a Listener
+if options.Listener == nil && options.Conn == nil {
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
stdio, err := yamux.Server(&readWriteCloser{

@@ -75,10 +81,12 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error {
return xerrors.Errorf("register provisioner: %w", err)
}
srv := drpcserver.New(&tracing.DRPCHandler{Handler: mux})
-// Only serve a single connection on the transport.
-// Transports are not multiplexed, and provisioners are
-// short-lived processes that can be executed concurrently.
-err = srv.Serve(ctx, options.Listener)
+
+if options.Listener != nil {
+err = srv.Serve(ctx, options.Listener)
+} else if options.Conn != nil {
+err = srv.ServeOne(ctx, options.Conn)
+}
if err != nil {
if errors.Is(err, io.EOF) ||
errors.Is(err, context.Canceled) ||
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"storj.io/drpc/drpcconn"
|
||||
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -18,7 +21,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
func TestProvisionerSDK(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Serve", func(t *testing.T) {
|
||||
t.Run("ServeListener", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server := provisionersdk.MemTransportPipe()
|
||||
defer client.Close()
|
||||
|
@ -72,6 +75,61 @@ func TestProvisionerSDK(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("ServeConn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancelFunc()
|
||||
srvErr := make(chan error, 1)
|
||||
go func() {
|
||||
err := provisionersdk.Serve(ctx, unimplementedServer{}, &provisionersdk.ServeOptions{
|
||||
Conn: server,
|
||||
WorkDirectory: t.TempDir(),
|
||||
})
|
||||
srvErr <- err
|
||||
}()
|
||||
|
||||
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
|
||||
s, err := api.Session(ctx)
|
||||
require.NoError(t, err)
|
||||
err = s.Send(&proto.Request{Type: &proto.Request_Config{Config: &proto.Config{}}})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Send(&proto.Request{Type: &proto.Request_Parse{Parse: &proto.ParseRequest{}}})
|
||||
require.NoError(t, err)
|
||||
msg, err := s.Recv()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "unimplemented", msg.GetParse().GetError())
|
||||
|
||||
err = s.Send(&proto.Request{Type: &proto.Request_Plan{Plan: &proto.PlanRequest{}}})
|
||||
require.NoError(t, err)
|
||||
msg, err = s.Recv()
|
||||
require.NoError(t, err)
|
||||
// Plan has no error so that we're allowed to run Apply
|
||||
require.Equal(t, "", msg.GetPlan().GetError())
|
||||
|
||||
err = s.Send(&proto.Request{Type: &proto.Request_Apply{Apply: &proto.ApplyRequest{}}})
|
||||
require.NoError(t, err)
|
||||
msg, err = s.Recv()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "unimplemented", msg.GetApply().GetError())
|
||||
|
||||
// Check provisioner closes when the connection does
|
||||
err = s.Close()
|
||||
require.NoError(t, err)
|
||||
err = api.DRPCConn().Close()
|
||||
require.NoError(t, err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for provisioner")
|
||||
case err = <-srvErr:
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type unimplementedServer struct{}
|
||||
|
|