diff --git a/cli/server.go b/cli/server.go index 9526a944f2..ee6c5878ac 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1309,7 +1309,7 @@ func newProvisionerDaemon( return nil, xerrors.Errorf("mkdir work dir: %w", err) } - provisioners := provisionerd.Provisioners{} + connector := provisionerd.LocalProvisioners{} if cfg.Provisioner.DaemonsEcho { echoClient, echoServer := provisionersdk.MemTransportPipe() wg.Add(1) @@ -1336,7 +1336,7 @@ func newProvisionerDaemon( } } }() - provisioners[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient) + connector[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient) } else { tfDir := filepath.Join(cacheDir, "tf") err = os.MkdirAll(tfDir, 0o700) @@ -1375,7 +1375,7 @@ func newProvisionerDaemon( } }() - provisioners[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient) + connector[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient) } debounce := time.Second @@ -1390,7 +1390,7 @@ func newProvisionerDaemon( JobPollDebounce: debounce, UpdateInterval: time.Second, ForceCancelInterval: cfg.Provisioner.ForceCancelInterval.Value(), - Provisioners: provisioners, + Connector: connector, TracerProvider: coderAPI.TracerProvider, Metrics: &metrics, }), nil diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 5ef17af359..af4314b545 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -484,7 +484,7 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer { JobPollInterval: 50 * time.Millisecond, UpdateInterval: 250 * time.Millisecond, ForceCancelInterval: time.Second, - Provisioners: provisionerd.Provisioners{ + Connector: provisionerd.LocalProvisioners{ string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient), }, }) @@ -524,7 +524,7 @@ func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uui JobPollInterval: 50 * time.Millisecond, UpdateInterval: 250 * time.Millisecond, ForceCancelInterval: time.Second, - Provisioners: provisionerd.Provisioners{ + Connector: provisionerd.LocalProvisioners{ string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient), }, }) diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index e63e7cd46c..2cb5f98d49 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -124,7 +124,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { logger.Info(ctx, "starting provisioner daemon", slog.F("tags", tags)) - provisioners := provisionerd.Provisioners{ + connector := provisionerd.LocalProvisioners{ string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(terraformClient), } srv := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { @@ -140,7 +140,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { JobPollInterval: pollInterval, JobPollJitter: pollJitter, UpdateInterval: 500 * time.Millisecond, - Provisioners: provisioners, + Connector: connector, }) var exitErr error diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index d41bf42385..aa4b1295dd 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -253,7 +253,7 @@ func TestProvisionerDaemonServe(t *testing.T) { errCh <- err }() - 
provisioners := provisionerd.Provisioners{ + connector := provisionerd.LocalProvisioners{ string(database.ProvisionerTypeEcho): proto.NewDRPCProvisionerClient(terraformClient), } another := codersdk.New(client.URL) @@ -269,8 +269,8 @@ func TestProvisionerDaemonServe(t *testing.T) { PreSharedKey: "provisionersftw", }) }, &provisionerd.Options{ - Logger: logger.Named("provisionerd"), - Provisioners: provisioners, + Logger: logger.Named("provisionerd"), + Connector: connector, }) defer pd.Close() diff --git a/enterprise/provisionerd/remoteprovisioners.go b/enterprise/provisionerd/remoteprovisioners.go new file mode 100644 index 0000000000..c56459ef31 --- /dev/null +++ b/enterprise/provisionerd/remoteprovisioners.go @@ -0,0 +1,466 @@ +package provisionerd + +import ( + "bufio" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/subtle" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/pem" + "fmt" + "io" + "math/big" + "net" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "golang.org/x/xerrors" + "storj.io/drpc/drpcconn" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/provisioner/echo" + agpl "github.com/coder/coder/v2/provisionerd" + "github.com/coder/coder/v2/provisionerd/proto" + "github.com/coder/coder/v2/provisionersdk" + sdkproto "github.com/coder/coder/v2/provisionersdk/proto" +) + +// Executor is responsible for executing the remote provisioners. +// +// TODO: this interface is where we will run Kubernetes Jobs in a future +// version; right now, only the unit tests implement this interface. +type Executor interface { + // Execute a provisioner that connects back to the remoteConnector. errCh + // allows signalling of errors asynchronously and is closed on completion + // with no error. 
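+	// Implementations receive everything the provisioner needs to dial back:
+	// the job ID and one-time token used to authenticate, plus the daemon's
+	// self-signed certificate (PEM) and listen address (see the Authentication
+	// Protocol comment later in this file).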
+ Execute( + ctx context.Context, + provisionerType database.ProvisionerType, + jobID, token, daemonCert, daemonAddress string) (errCh <-chan error) +} + +type waiter struct { + ctx context.Context + job *proto.AcquiredJob + respCh chan<- agpl.ConnectResponse + token string +} + +type remoteConnector struct { + ctx context.Context + executor Executor + cert string + addr string + listener net.Listener + logger slog.Logger + tlsCfg *tls.Config + + mu sync.Mutex + waiters map[string]waiter +} + +func NewRemoteConnector(ctx context.Context, logger slog.Logger, exec Executor) (agpl.Connector, error) { + // nolint: gosec + listener, err := net.Listen("tcp", ":0") + if err != nil { + return nil, xerrors.Errorf("failed to listen: %w", err) + } + go func() { + <-ctx.Done() + ce := listener.Close() + logger.Debug(ctx, "listener closed", slog.Error(ce)) + }() + r := &remoteConnector{ + ctx: ctx, + executor: exec, + listener: listener, + addr: listener.Addr().String(), + logger: logger, + waiters: make(map[string]waiter), + } + err = r.genCert() + if err != nil { + return nil, xerrors.Errorf("failed to generate certificate: %w", err) + } + go r.listenLoop() + return r, nil +} + +func (r *remoteConnector) genCert() error { + privateKey, cert, err := GenCert() + if err != nil { + return err + } + r.cert = string(cert) + privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return xerrors.Errorf("failed to marshal private key: %w", err) + } + pkPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}) + certKey, err := tls.X509KeyPair(cert, pkPEM) + if err != nil { + return xerrors.Errorf("failed to create TLS certificate: %w", err) + } + r.tlsCfg = &tls.Config{Certificates: []tls.Certificate{certKey}, MinVersion: tls.VersionTLS13} + return nil +} + +// GenCert is a helper function that generates a private key and certificate. It +// is exported so that we can test a certificate generated in exactly the same +// way, but with a different private key. +func GenCert() (*ecdsa.PrivateKey, []byte, error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, xerrors.Errorf("generate private key: %w", err) + } + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "Coder Provisioner Daemon", + }, + DNSNames: []string{serverName}, + NotBefore: time.Now(), + // cert is valid for 5 years, which is much longer than we expect this + // process to stay up. 
The idea is that the certificate is self-signed + // and is valid for as long as the daemon is up and starting new remote + // provisioners + NotAfter: time.Now().Add(time.Hour * 24 * 365 * 5), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, nil, xerrors.Errorf("failed to create certificate: %w", err) + } + cert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + return privateKey, cert, nil +} + +func (r *remoteConnector) listenLoop() { + for { + conn, err := r.listener.Accept() + if err != nil { + r.logger.Info(r.ctx, "stopping listenLoop", slog.Error(err)) + return + } + go r.handleConn(conn) + } +} + +func (r *remoteConnector) handleConn(conn net.Conn) { + logger := r.logger.With(slog.F("remote_addr", conn.RemoteAddr())) + + // If we hit an error while setting up, we want to close the connection. + // This construction makes the default to close until we explicitly set + // closeConn = false just before handing the connection over the respCh. + closeConn := true + defer func() { + if closeConn { + ce := conn.Close() + logger.Debug(r.ctx, "closed connection", slog.Error(ce)) + } + }() + + tlsConn := tls.Server(conn, r.tlsCfg) + err := tlsConn.HandshakeContext(r.ctx) + if err != nil { + logger.Info(r.ctx, "failed TLS handshake", slog.Error(err)) + return + } + w, err := r.authenticate(tlsConn) + if err != nil { + logger.Info(r.ctx, "failed provisioner authentication", slog.Error(err)) + return + } + logger = logger.With(slog.F("job_id", w.job.JobId)) + logger.Info(r.ctx, "provisioner connected") + closeConn = false // we're passing the conn over the channel + w.respCh <- agpl.ConnectResponse{ + Job: w.job, + Client: sdkproto.NewDRPCProvisionerClient(drpcconn.New(tlsConn)), + } +} + +var ( + errInvalidJobID = xerrors.New("invalid jobID") + errInvalidToken = xerrors.New("invalid token") +) + +func (r *remoteConnector) pullWaiter(jobID, token string) (waiter, error) { + r.mu.Lock() + defer r.mu.Unlock() + // provisioners authenticate with a jobID and token. The jobID is required + // because we need to use public information for the lookup, to avoid timing + // attacks against the token. 
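+	// After the public jobID lookup, the token below is checked with
+	// subtle.ConstantTimeCompare so the comparison itself does not leak how
+	// many bytes of the token matched.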
+	w, ok := r.waiters[jobID]
+	if !ok {
+		return waiter{}, errInvalidJobID
+	}
+	if subtle.ConstantTimeCompare([]byte(token), []byte(w.token)) == 1 {
+		delete(r.waiters, jobID)
+		return w, nil
+	}
+	return waiter{}, errInvalidToken
+}
+
+func (r *remoteConnector) Connect(
+	ctx context.Context, job *proto.AcquiredJob, respCh chan<- agpl.ConnectResponse,
+) {
+	pt := database.ProvisionerType(job.Provisioner)
+	if !pt.Valid() {
+		go errResponse(job, respCh, xerrors.Errorf("invalid provisioner type: %s", job.Provisioner))
+		return
+	}
+	tb := make([]byte, 16) // 128-bit token
+	n, err := rand.Read(tb)
+	if err != nil {
+		go errResponse(job, respCh, err)
+		return
+	}
+	if n != 16 {
+		go errResponse(job, respCh, xerrors.New("short read generating token"))
+		return
+	}
+	token := base64.StdEncoding.EncodeToString(tb)
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	r.waiters[job.JobId] = waiter{
+		ctx:    ctx,
+		job:    job,
+		respCh: respCh,
+		token:  token,
+	}
+	go r.handleContextExpired(ctx, job.JobId)
+	errCh := r.executor.Execute(ctx, pt, job.JobId, token, r.cert, r.addr)
+	go r.handleExecError(job.JobId, errCh)
+}
+
+func (r *remoteConnector) handleContextExpired(ctx context.Context, jobID string) {
+	<-ctx.Done()
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	w, ok := r.waiters[jobID]
+	if !ok {
+		// something else already responded.
+		return
+	}
+	delete(r.waiters, jobID)
+	// separate goroutine, so we don't hold the lock while trying to write
+	// to the channel.
+	go func() {
+		w.respCh <- agpl.ConnectResponse{
+			Job:   w.job,
+			Error: ctx.Err(),
+		}
+	}()
+}
+
+func (r *remoteConnector) handleExecError(jobID string, errCh <-chan error) {
+	err := <-errCh
+	if err == nil {
+		return
+	}
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	w, ok := r.waiters[jobID]
+	if !ok {
+		// something else already responded.
+		return
+	}
+	delete(r.waiters, jobID)
+	// separate goroutine, so we don't hold the lock while trying to write
+	// to the channel.
+	go func() {
+		w.respCh <- agpl.ConnectResponse{
+			Job:   w.job,
+			Error: err,
+		}
+	}()
+}
+
+func errResponse(job *proto.AcquiredJob, respCh chan<- agpl.ConnectResponse, err error) {
+	respCh <- agpl.ConnectResponse{
+		Job:   job,
+		Error: err,
+	}
+}
+
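+// Rough usage sketch (mirroring the unit tests); "executor" and "job" stand in
+// for caller-supplied values:
+//
+//	connector, err := NewRemoteConnector(ctx, logger, executor)
+//	if err != nil {
+//		// handle error
+//	}
+//	respCh := make(chan agpl.ConnectResponse)
+//	connector.Connect(ctx, job, respCh)
+//	resp := <-respCh // carries resp.Client or resp.Error, with resp.Job echoed back
+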
+// EphemeralEcho starts an Echo provisioner that connects back to the
+// provisioner daemon, handles one job, then exits.
+func EphemeralEcho(
+	ctx context.Context,
+	logger slog.Logger,
+	cacheDir, jobID, token, daemonCert, daemonAddress string,
+) error {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	workdir := filepath.Join(cacheDir, "echo")
+	err := os.MkdirAll(workdir, 0o777)
+	if err != nil {
+		return xerrors.Errorf("create workdir %s: %w", workdir, err)
+	}
+	conn, err := DialTLS(ctx, daemonCert, daemonAddress)
+	if err != nil {
+		return err
+	}
+	defer conn.Close()
+	err = AuthenticateProvisioner(conn, token, jobID)
+	if err != nil {
+		return err
+	}
+	// It's a little confusing, but the provisioner is the client with respect
+	// to TLS and the server with respect to dRPC.
+	exitErr := echo.Serve(ctx, &provisionersdk.ServeOptions{
+		Conn:          conn,
+		Logger:        logger.Named("echo"),
+		WorkDirectory: workdir,
+	})
+	logger.Debug(ctx, "echo.Serve done", slog.Error(exitErr))
+
+	if xerrors.Is(exitErr, context.Canceled) {
+		return nil
+	}
+	return exitErr
+}
+
+// DialTLS establishes a TLS connection to the given addr using the given cert
+// as the root CA.
+func DialTLS(ctx context.Context, cert, addr string) (*tls.Conn, error) {
+	roots := x509.NewCertPool()
+	ok := roots.AppendCertsFromPEM([]byte(cert))
+	if !ok {
+		return nil, xerrors.New("failed to parse daemon certificate")
+	}
+	cfg := &tls.Config{RootCAs: roots, MinVersion: tls.VersionTLS13, ServerName: serverName}
+	d := net.Dialer{}
+	nc, err := d.DialContext(ctx, "tcp", addr)
+	if err != nil {
+		return nil, xerrors.Errorf("dial: %w", err)
+	}
+	tc := tls.Client(nc, cfg)
+	// Explicitly handshake so we don't have to mess with setting read
+	// and write deadlines.
+	err = tc.HandshakeContext(ctx)
+	if err != nil {
+		_ = nc.Close()
+		return nil, xerrors.Errorf("TLS handshake: %w", err)
+	}
+	return tc, nil
+}
+
+// Authentication Protocol:
+//
+// Ephemeral provisioners connect to the connector using TLS. This allows the
+// provisioner to authenticate the daemon/connector based on the TLS certificate
+// delivered to the provisioner out-of-band.
+//
+// The daemon/connector authenticates the provisioner by jobID and token, which
+// are sent over the TLS connection separated by newlines. The daemon/connector
+// responds with a 3-byte response to complete the handshake.
+//
+// Although the token is unique to the job and unambiguous, we also send the
+// jobID. This allows the daemon/connector to look up the job based on public
+// information (jobID), shielding the token from timing attacks. I'm not sure
+// how practical a timing attack against an in-memory golang map is, but it's
+// better to avoid it entirely. After the job is looked up by jobID, we do a
+// constant time compare on the token to authenticate.
+//
+// Also note that we don't really have to worry about cross-version
+// compatibility in this protocol, since the provisioners are always started by
+// the same daemon/connector as they connect to.
+
+// Responses are all exactly 3 bytes so that we don't have to use a scanner,
+// which might accidentally buffer some of the first dRPC request.
+const (
+	responseOK           = "OK\n"
+	responseInvalidJobID = "IJ\n"
+	responseInvalidToken = "IT\n"
+)
+
+// serverName is the name on the x509 certificate the daemon/connector
+// generates. The name doesn't matter as long as both sides agree, since the
+// provisioners get the IP address directly. It is also fine to reuse, because
+// each daemon generates a unique private key and self-signs, so a provisioner
+// will never successfully authenticate to a different provisionerd.
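+// DialTLS sets serverName as the tls.Config ServerName when the provisioner
+// dials back, so verification succeeds even though the connection itself is
+// made to a raw IP:port.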
+const serverName = "provisionerd" + +// AuthenticateProvisioner performs the provisioner's side of the authentication +// protocol. +func AuthenticateProvisioner(conn io.ReadWriter, token, jobID string) error { + sb := strings.Builder{} + _, _ = sb.WriteString(jobID) + _, _ = sb.WriteString("\n") + _, _ = sb.WriteString(token) + _, _ = sb.WriteString("\n") + _, err := conn.Write([]byte(sb.String())) + if err != nil { + return xerrors.Errorf("failed to write token: %w", err) + } + b := make([]byte, 3) + _, err = conn.Read(b) + if err != nil { + return xerrors.Errorf("failed to read token resp: %w", err) + } + if string(b) != responseOK { + // convert to a human-readable format + var reason string + switch string(b) { + case responseInvalidJobID: + reason = "invalid job ID" + case responseInvalidToken: + reason = "invalid token" + default: + reason = fmt.Sprintf("unknown response code: %s", b) + } + return xerrors.Errorf("authenticate protocol error: %s", reason) + } + return nil +} + +// authenticate performs the daemon/connector's side of the authentication +// protocol. +func (r *remoteConnector) authenticate(conn io.ReadWriter) (waiter, error) { + // it's fine to use a scanner here because the provisioner side doesn't hand + // off the connection to the dRPC handler until after we send our response. + scn := bufio.NewScanner(conn) + if ok := scn.Scan(); !ok { + return waiter{}, xerrors.Errorf("failed to receive jobID: %w", scn.Err()) + } + jobID := scn.Text() + if ok := scn.Scan(); !ok { + return waiter{}, xerrors.Errorf("failed to receive job token: %w", scn.Err()) + } + token := scn.Text() + w, err := r.pullWaiter(jobID, token) + if err == nil { + _, err = conn.Write([]byte(responseOK)) + if err != nil { + err = xerrors.Errorf("failed to write authentication response: %w", err) + // if we fail here, it's our responsibility to send the error response on the respCh + // because we're not going to return the waiter to the caller. 
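+			// pullWaiter has already removed the waiter from r.waiters, so no
+			// other goroutine can respond for this job.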
+ go errResponse(w.job, w.respCh, err) + return waiter{}, err + } + return w, nil + } + if xerrors.Is(err, errInvalidJobID) { + _, wErr := conn.Write([]byte(responseInvalidJobID)) + r.logger.Debug(r.ctx, "responded invalid jobID", slog.Error(wErr)) + } + if xerrors.Is(err, errInvalidToken) { + _, wErr := conn.Write([]byte(responseInvalidToken)) + r.logger.Debug(r.ctx, "responded invalid token", slog.Error(wErr)) + } + return waiter{}, err +} diff --git a/enterprise/provisionerd/remoteprovisioners_test.go b/enterprise/provisionerd/remoteprovisioners_test.go new file mode 100644 index 0000000000..1e1ca3d788 --- /dev/null +++ b/enterprise/provisionerd/remoteprovisioners_test.go @@ -0,0 +1,371 @@ +package provisionerd_test + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/enterprise/provisionerd" + "github.com/coder/coder/v2/provisioner/echo" + agpl "github.com/coder/coder/v2/provisionerd" + "github.com/coder/coder/v2/provisionerd/proto" + sdkproto "github.com/coder/coder/v2/provisionersdk/proto" + "github.com/coder/coder/v2/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestRemoteConnector_Mainline(t *testing.T) { + t.Parallel() + cases := []struct { + name string + smokescreen bool + }{ + {name: "NoSmokescreen", smokescreen: false}, + {name: "Smokescreen", smokescreen: true}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + exec := &testExecutor{ + t: t, + logger: logger, + smokescreen: tc.smokescreen, + } + uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec) + require.NoError(t, err) + + respCh := make(chan agpl.ConnectResponse) + job := &proto.AcquiredJob{ + JobId: "test-job", + Provisioner: string(database.ProvisionerTypeEcho), + } + uut.Connect(ctx, job, respCh) + var resp agpl.ConnectResponse + select { + case <-ctx.Done(): + t.Error("timeout waiting for connect response") + case resp = <-respCh: + // OK + } + require.NoError(t, resp.Error) + require.Equal(t, job, resp.Job) + require.NotNil(t, resp.Client) + + // check that we can communicate with the provisioner + er := &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionApply: echo.ApplyComplete, + ProvisionPlan: echo.PlanComplete, + } + arc, err := echo.Tar(er) + require.NoError(t, err) + c := resp.Client + s, err := c.Session(ctx) + require.NoError(t, err) + err = s.Send(&sdkproto.Request{Type: &sdkproto.Request_Config{Config: &sdkproto.Config{ + TemplateSourceArchive: arc, + }}}) + require.NoError(t, err) + err = s.Send(&sdkproto.Request{Type: &sdkproto.Request_Parse{Parse: &sdkproto.ParseRequest{}}}) + require.NoError(t, err) + r, err := s.Recv() + require.NoError(t, err) + require.IsType(t, &sdkproto.Response_Parse{}, r.Type) + }) + } +} + +func TestRemoteConnector_BadToken(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + exec := &testExecutor{ + t: t, + logger: logger, + overrideToken: "bad-token", + } + uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec) + 
require.NoError(t, err) + + respCh := make(chan agpl.ConnectResponse) + job := &proto.AcquiredJob{ + JobId: "test-job", + Provisioner: string(database.ProvisionerTypeEcho), + } + uut.Connect(ctx, job, respCh) + var resp agpl.ConnectResponse + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for connect response") + case resp = <-respCh: + // OK + } + require.Equal(t, job, resp.Job) + require.ErrorContains(t, resp.Error, "invalid token") +} + +func TestRemoteConnector_BadJobID(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + exec := &testExecutor{ + t: t, + logger: logger, + overrideJobID: "bad-job", + } + uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec) + require.NoError(t, err) + + respCh := make(chan agpl.ConnectResponse) + job := &proto.AcquiredJob{ + JobId: "test-job", + Provisioner: string(database.ProvisionerTypeEcho), + } + uut.Connect(ctx, job, respCh) + var resp agpl.ConnectResponse + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for connect response") + case resp = <-respCh: + // OK + } + require.Equal(t, job, resp.Job) + require.ErrorContains(t, resp.Error, "invalid job ID") +} + +func TestRemoteConnector_BadCert(t *testing.T) { + t.Parallel() + _, cert, err := provisionerd.GenCert() + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + exec := &testExecutor{ + t: t, + logger: logger, + overrideCert: string(cert), + } + uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec) + require.NoError(t, err) + + respCh := make(chan agpl.ConnectResponse) + job := &proto.AcquiredJob{ + JobId: "test-job", + Provisioner: string(database.ProvisionerTypeEcho), + } + uut.Connect(ctx, job, respCh) + var resp agpl.ConnectResponse + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for connect response") + case resp = <-respCh: + // OK + } + require.Equal(t, job, resp.Job) + require.ErrorContains(t, resp.Error, "certificate signed by unknown authority") +} + +func TestRemoteConnector_Fuzz(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + exec := newFuzzExecutor(t, logger) + uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec) + require.NoError(t, err) + + respCh := make(chan agpl.ConnectResponse) + job := &proto.AcquiredJob{ + JobId: "test-job", + Provisioner: string(database.ProvisionerTypeEcho), + } + + connectCtx, connectCtxCancel := context.WithCancel(ctx) + defer connectCtxCancel() + + uut.Connect(connectCtx, job, respCh) + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for fuzzer") + case <-exec.done: + // Connector hung up on the fuzzer + } + require.Less(t, exec.bytesFuzzed, 2<<20, "should not allow more than 1 MiB") + connectCtxCancel() + var resp agpl.ConnectResponse + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for connect response") + case resp = <-respCh: + // OK + } + require.Equal(t, job, resp.Job) + require.ErrorIs(t, resp.Error, context.Canceled) +} + +func TestRemoteConnector_CancelConnect(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + logger := slogtest.Make(t, 
nil).Leveled(slog.LevelDebug) + exec := &testExecutor{ + t: t, + logger: logger, + dontStart: true, + } + uut, err := provisionerd.NewRemoteConnector(ctx, logger.Named("connector"), exec) + require.NoError(t, err) + + respCh := make(chan agpl.ConnectResponse) + job := &proto.AcquiredJob{ + JobId: "test-job", + Provisioner: string(database.ProvisionerTypeEcho), + } + + connectCtx, connectCtxCancel := context.WithCancel(ctx) + defer connectCtxCancel() + + uut.Connect(connectCtx, job, respCh) + connectCtxCancel() + var resp agpl.ConnectResponse + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for connect response") + case resp = <-respCh: + // OK + } + require.Equal(t, job, resp.Job) + require.ErrorIs(t, resp.Error, context.Canceled) +} + +type testExecutor struct { + t *testing.T + logger slog.Logger + overrideToken string + overrideJobID string + overrideCert string + // dontStart simulates when everything looks good to the connector but + // the provisioner never starts + dontStart bool + // smokescreen starts a connection that fails authentication before starting + // the real connection. Tests that failed connections don't interfere with + // real ones. + smokescreen bool +} + +func (e *testExecutor) Execute( + ctx context.Context, + provisionerType database.ProvisionerType, + jobID, token, daemonCert, daemonAddress string, +) <-chan error { + assert.Equal(e.t, database.ProvisionerTypeEcho, provisionerType) + if e.overrideToken != "" { + token = e.overrideToken + } + if e.overrideJobID != "" { + jobID = e.overrideJobID + } + if e.overrideCert != "" { + daemonCert = e.overrideCert + } + cacheDir := e.t.TempDir() + errCh := make(chan error) + go func() { + defer close(errCh) + if e.smokescreen { + e.doSmokeScreen(ctx, jobID, daemonCert, daemonAddress) + } + if !e.dontStart { + err := provisionerd.EphemeralEcho(ctx, e.logger, cacheDir, jobID, token, daemonCert, daemonAddress) + e.logger.Debug(ctx, "provisioner done", slog.Error(err)) + if err != nil { + errCh <- err + } + } + }() + return errCh +} + +func (e *testExecutor) doSmokeScreen(ctx context.Context, jobID, daemonCert, daemonAddress string) { + conn, err := provisionerd.DialTLS(ctx, daemonCert, daemonAddress) + if !assert.NoError(e.t, err) { + return + } + defer conn.Close() + err = provisionerd.AuthenticateProvisioner(conn, "smokescreen", jobID) + assert.ErrorContains(e.t, err, "invalid token") +} + +type fuzzExecutor struct { + t *testing.T + logger slog.Logger + done chan struct{} + bytesFuzzed int +} + +func newFuzzExecutor(t *testing.T, logger slog.Logger) *fuzzExecutor { + return &fuzzExecutor{ + t: t, + logger: logger, + done: make(chan struct{}), + bytesFuzzed: 0, + } +} + +func (e *fuzzExecutor) Execute( + ctx context.Context, + _ database.ProvisionerType, + _, _, daemonCert, daemonAddress string, +) <-chan error { + errCh := make(chan error) + go func() { + defer close(errCh) + defer close(e.done) + conn, err := provisionerd.DialTLS(ctx, daemonCert, daemonAddress) + assert.NoError(e.t, err) + rb := make([]byte, 128) + for { + if ctx.Err() != nil { + e.t.Error("context canceled while fuzzing") + return + } + n, err := rand.Read(rb) + if err != nil { + e.t.Errorf("random read: %s", err) + } + if n < 128 { + e.t.Error("short random read") + return + } + // replace newlines so the Connector doesn't think we are done + // with the JobID + for i := 0; i < len(rb); i++ { + if rb[i] == '\n' || rb[i] == '\r' { + rb[i] = 'A' + } + } + n, err = conn.Write(rb) + e.bytesFuzzed += n + if err != nil { + return + } + } + }() + 
return errCh +} diff --git a/provisionerd/localprovisioners.go b/provisionerd/localprovisioners.go new file mode 100644 index 0000000000..0e495f536d --- /dev/null +++ b/provisionerd/localprovisioners.go @@ -0,0 +1,28 @@ +package provisionerd + +import ( + "context" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/provisionerd/proto" + + sdkproto "github.com/coder/coder/v2/provisionersdk/proto" +) + +// LocalProvisioners is a Connector that stores a static set of in-process +// provisioners. +type LocalProvisioners map[string]sdkproto.DRPCProvisionerClient + +func (l LocalProvisioners) Connect(_ context.Context, job *proto.AcquiredJob, respCh chan<- ConnectResponse) { + r := ConnectResponse{Job: job} + p, ok := l[job.Provisioner] + if ok { + r.Client = p + } else { + r.Error = xerrors.Errorf("missing provisioner type %s", job.Provisioner) + } + go func() { + respCh <- r + }() +} diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index 2d8788ddec..e873d1901d 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -32,8 +32,24 @@ import ( // Dialer represents the function to create a daemon client connection. type Dialer func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) -// Provisioners maps provisioner ID to implementation. -type Provisioners map[string]sdkproto.DRPCProvisionerClient +// ConnectResponse is the response returned asynchronously from Connector.Connect +// containing either the Provisioner Client or an Error. The Job is also returned +// unaltered to disambiguate responses if the respCh is shared among multiple jobs +type ConnectResponse struct { + Job *proto.AcquiredJob + Client sdkproto.DRPCProvisionerClient + Error error +} + +// Connector allows the provisioner daemon to Connect to a provisioner +// for the given job. +type Connector interface { + // Connect to the correct provisioner for the given job. The response is + // delivered asynchronously over the respCh. If the provided context expires, + // the Connector may stop waiting for the provisioner and return an error + // response. + Connect(ctx context.Context, job *proto.AcquiredJob, respCh chan<- ConnectResponse) +} // Options provides customizations to the behavior of a provisioner daemon. type Options struct { @@ -47,7 +63,7 @@ type Options struct { JobPollInterval time.Duration JobPollJitter time.Duration JobPollDebounce time.Duration - Provisioners Provisioners + Connector Connector } // New creates and starts a provisioner daemon. @@ -375,11 +391,13 @@ func (p *Server) acquireJob(ctx context.Context) { p.opts.Logger.Debug(ctx, "acquired job", fields...) 
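+	// Ask the Connector for a client for this job's provisioner; the response
+	// (either a client or an error) arrives asynchronously on respCh.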
- provisioner, ok := p.opts.Provisioners[job.Provisioner] - if !ok { + respCh := make(chan ConnectResponse) + p.opts.Connector.Connect(ctx, job, respCh) + resp := <-respCh + if resp.Error != nil { err := p.FailJob(ctx, &proto.FailedJob{ JobId: job.JobId, - Error: fmt.Sprintf("no provisioner %s", job.Provisioner), + Error: fmt.Sprintf("failed to connect to provisioner: %s", resp.Error), }) if err != nil { p.opts.Logger.Error(ctx, "provisioner job failed", slog.F("job_id", job.JobId), slog.Error(err)) @@ -394,7 +412,7 @@ func (p *Server) acquireJob(ctx context.Context) { Updater: p, QuotaCommitter: p, Logger: p.opts.Logger.Named("runner"), - Provisioner: provisioner, + Provisioner: resp.Client, UpdateInterval: p.opts.UpdateInterval, ForceCancelInterval: p.opts.ForceCancelInterval, LogDebounceInterval: p.opts.LogBufferInterval, diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index ee379e0ab9..a4fffdf468 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -60,7 +60,7 @@ func TestProvisionerd(t *testing.T) { }) closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil - }, provisionerd.Provisioners{}) + }, provisionerd.LocalProvisioners{}) require.NoError(t, closer.Close()) }) @@ -74,7 +74,7 @@ func TestProvisionerd(t *testing.T) { closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { defer close(completeChan) return nil, xerrors.New("an error") - }, provisionerd.Provisioners{}) + }, provisionerd.LocalProvisioners{}) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, closer.Close()) }) @@ -101,7 +101,7 @@ func TestProvisionerd(t *testing.T) { }, updateJob: noopUpdateJob, }), nil - }, provisionerd.Provisioners{}) + }, provisionerd.LocalProvisioners{}) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) require.NoError(t, closer.Close()) }) @@ -141,7 +141,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func(_ *provisionersdk.Session, _ *sdkproto.ParseRequest, _ <-chan struct{}) *sdkproto.ParseComplete { closerMutex.Lock() @@ -195,7 +195,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}), }) require.Condition(t, closedWithin(completeChan, testutil.WaitShort)) @@ -237,7 +237,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func( _ *provisionersdk.Session, @@ -304,7 +304,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ parse: func( s *provisionersdk.Session, @@ -398,7 +398,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ 
plan: func( _ *provisionersdk.Session, @@ -472,7 +472,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, @@ -553,7 +553,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, @@ -638,7 +638,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, @@ -714,7 +714,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, @@ -800,7 +800,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, @@ -886,7 +886,7 @@ func TestProvisionerd(t *testing.T) { }() } return client, nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( _ *provisionersdk.Session, @@ -971,7 +971,7 @@ func TestProvisionerd(t *testing.T) { }() } return client, nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( _ *provisionersdk.Session, @@ -1055,7 +1055,7 @@ func TestProvisionerd(t *testing.T) { return &proto.Empty{}, nil }, }), nil - }, provisionerd.Provisioners{ + }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ plan: func( s *provisionersdk.Session, @@ -1103,12 +1103,12 @@ func createTar(t *testing.T, files map[string]string) []byte { } // Creates a provisionerd implementation with the provided dialer and provisioners. -func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners provisionerd.Provisioners) *provisionerd.Server { +func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, connector provisionerd.LocalProvisioners) *provisionerd.Server { server := provisionerd.New(dialer, &provisionerd.Options{ Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("provisionerd").Leveled(slog.LevelDebug), JobPollInterval: 50 * time.Millisecond, UpdateInterval: 50 * time.Millisecond, - Provisioners: provisioners, + Connector: connector, }) t.Cleanup(func() { _ = server.Close() diff --git a/provisionersdk/serve.go b/provisionersdk/serve.go index 924c7ad013..baa3cc1412 100644 --- a/provisionersdk/serve.go +++ b/provisionersdk/serve.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/yamux" "github.com/valyala/fasthttp/fasthttputil" "golang.org/x/xerrors" + "storj.io/drpc" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -21,8 +22,10 @@ import ( // ServeOptions are configurations to serve a provisioner. 
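+// At most one of Listener or Conn may be set; when neither is set, Serve
+// falls back to multiplexing stdio with yamux.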
type ServeOptions struct { - // Conn specifies a custom transport to serve the dRPC connection. - Listener net.Listener + // Listener serves multiple connections. Cannot be combined with Conn. + Listener net.Listener + // Conn is a single connection to serve. Cannot be combined with Listener. + Conn drpc.Transport Logger slog.Logger WorkDirectory string } @@ -38,8 +41,11 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error { if options == nil { options = &ServeOptions{} } - // Default to using stdio. - if options.Listener == nil { + if options.Listener != nil && options.Conn != nil { + return xerrors.New("specify Listener or Conn, not both") + } + // Default to using stdio with yamux as a Listener + if options.Listener == nil && options.Conn == nil { config := yamux.DefaultConfig() config.LogOutput = io.Discard stdio, err := yamux.Server(&readWriteCloser{ @@ -75,10 +81,12 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error { return xerrors.Errorf("register provisioner: %w", err) } srv := drpcserver.New(&tracing.DRPCHandler{Handler: mux}) - // Only serve a single connection on the transport. - // Transports are not multiplexed, and provisioners are - // short-lived processes that can be executed concurrently. - err = srv.Serve(ctx, options.Listener) + + if options.Listener != nil { + err = srv.Serve(ctx, options.Listener) + } else if options.Conn != nil { + err = srv.ServeOne(ctx, options.Conn) + } if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || diff --git a/provisionersdk/serve_test.go b/provisionersdk/serve_test.go index baa5d2ba62..7ebfeb6f9b 100644 --- a/provisionersdk/serve_test.go +++ b/provisionersdk/serve_test.go @@ -2,14 +2,17 @@ package provisionersdk_test import ( "context" + "net" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "storj.io/drpc/drpcconn" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/provisionersdk/proto" + "github.com/coder/coder/v2/testutil" ) func TestMain(m *testing.M) { @@ -18,7 +21,7 @@ func TestMain(m *testing.M) { func TestProvisionerSDK(t *testing.T) { t.Parallel() - t.Run("Serve", func(t *testing.T) { + t.Run("ServeListener", func(t *testing.T) { t.Parallel() client, server := provisionersdk.MemTransportPipe() defer client.Close() @@ -72,6 +75,61 @@ func TestProvisionerSDK(t *testing.T) { }) require.NoError(t, err) }) + + t.Run("ServeConn", func(t *testing.T) { + t.Parallel() + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancelFunc() + srvErr := make(chan error, 1) + go func() { + err := provisionersdk.Serve(ctx, unimplementedServer{}, &provisionersdk.ServeOptions{ + Conn: server, + WorkDirectory: t.TempDir(), + }) + srvErr <- err + }() + + api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) + s, err := api.Session(ctx) + require.NoError(t, err) + err = s.Send(&proto.Request{Type: &proto.Request_Config{Config: &proto.Config{}}}) + require.NoError(t, err) + + err = s.Send(&proto.Request{Type: &proto.Request_Parse{Parse: &proto.ParseRequest{}}}) + require.NoError(t, err) + msg, err := s.Recv() + require.NoError(t, err) + require.Equal(t, "unimplemented", msg.GetParse().GetError()) + + err = s.Send(&proto.Request{Type: &proto.Request_Plan{Plan: &proto.PlanRequest{}}}) + require.NoError(t, err) + msg, err = s.Recv() + require.NoError(t, err) + // Plan has no 
error so that we're allowed to run Apply + require.Equal(t, "", msg.GetPlan().GetError()) + + err = s.Send(&proto.Request{Type: &proto.Request_Apply{Apply: &proto.ApplyRequest{}}}) + require.NoError(t, err) + msg, err = s.Recv() + require.NoError(t, err) + require.Equal(t, "unimplemented", msg.GetApply().GetError()) + + // Check provisioner closes when the connection does + err = s.Close() + require.NoError(t, err) + err = api.DRPCConn().Close() + require.NoError(t, err) + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for provisioner") + case err = <-srvErr: + require.NoError(t, err) + } + }) } type unimplementedServer struct{}
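As an illustration of the Executor extension point (which only the tests implement in this change), a non-test implementation might launch each ephemeral provisioner as a separate local process and hand it the dial-back parameters. This is only a sketch: the binary name, flags, and environment variable names below are invented, and the TODO on Executor notes the eventual implementation is expected to run Kubernetes Jobs instead.

// Hypothetical Executor, for illustration only.
package provisionerd

import (
	"context"
	"os"
	"os/exec"

	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/coderd/database"
)

type processExecutor struct{}

func (processExecutor) Execute(
	ctx context.Context,
	provisionerType database.ProvisionerType,
	jobID, token, daemonCert, daemonAddress string,
) <-chan error {
	errCh := make(chan error, 1)
	go func() {
		// Closed with no error on clean completion, matching the Executor contract.
		defer close(errCh)
		// Pass public parameters on the command line (binary name and flags are
		// invented for this sketch)...
		cmd := exec.CommandContext(ctx, "coder-ephemeral-provisioner",
			"--type", string(provisionerType),
			"--job-id", jobID,
			"--daemon-address", daemonAddress,
		)
		// ...and deliver the secret token and the daemon's self-signed cert via
		// the environment (variable names are also invented).
		cmd.Env = append(os.Environ(),
			"CODER_PROVISIONER_TOKEN="+token,
			"CODER_PROVISIONER_DAEMON_CERT="+daemonCert,
		)
		if err := cmd.Run(); err != nil {
			errCh <- xerrors.Errorf("provisioner process: %w", err)
		}
	}()
	return errCh
}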