coder/provisionerd/provisionerd.go

422 lines
10 KiB
Go

package provisionerd
import (
"context"
"errors"
"fmt"
"io"
"reflect"
"strings"
"sync"
"time"
"github.com/hashicorp/yamux"
"github.com/spf13/afero"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"go.opentelemetry.io/otel/trace"
"go.uber.org/atomic"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionerd/runner"
sdkproto "github.com/coder/coder/provisionersdk/proto"
"github.com/coder/retry"
)
// IsMissingParameterError reports whether the provided error message
// describes a missing parameter. The check is a plain substring match of
// runner.MissingParameterErrorText against err, so err is the error's
// message text rather than an error value. A true result can indicate to
// consumers that they should check parameters.
func IsMissingParameterError(err string) bool {
return strings.Contains(err, runner.MissingParameterErrorText)
}
// Dialer represents the function to create a daemon client connection.
// The daemon calls it from its connect loop and retries when it returns an
// error, unless that error is context.Canceled or the daemon has been closed.
type Dialer func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error)
// Provisioners maps provisioner ID to implementation. The key is matched
// against the Provisioner field of an acquired job; a job whose provisioner
// has no entry in the map is failed.
type Provisioners map[string]sdkproto.DRPCProvisionerClient
// Options provides customizations to the behavior of a provisioner daemon.
// New fills in defaults for zero-valued fields directly on the struct it is
// given.
type Options struct {
// Filesystem is handed to the runner of every acquired job.
// New defaults it to afero.NewOsFs() when nil.
Filesystem afero.Fs
// Logger receives the daemon's log output.
Logger slog.Logger
// Tracer provides the tracer used for job spans.
// New defaults it to a no-op tracer provider when nil.
Tracer trace.TracerProvider

// ForceCancelInterval is passed to the job runner; New defaults it to one
// minute. NOTE(review): presumably the grace period before a canceled job
// is force-stopped -- confirm against the runner package.
ForceCancelInterval time.Duration
// UpdateInterval is passed to the job runner; New defaults it to 5 seconds.
// NOTE(review): presumably how often the runner sends job updates -- confirm
// against the runner package.
UpdateInterval time.Duration
// PollInterval is the period of the ticker that triggers job acquisition.
// New defaults it to 5 seconds.
PollInterval time.Duration
// Provisioners holds the available provisioner implementations, looked up
// by the Provisioner field of an acquired job.
Provisioners Provisioners
// WorkDirectory is passed to the job runner.
// NOTE(review): presumably the directory jobs execute in -- confirm.
WorkDirectory string
}
// New creates and starts a provisioner daemon.
//
// clientDialer is used to connect (and reconnect) to coderd from a background
// goroutine that is started before New returns.
//
// opts may be nil, in which case every option takes its default. For a non-nil
// opts, zero-valued fields are filled in on the caller's struct:
//   - PollInterval and UpdateInterval default to 5 seconds
//   - ForceCancelInterval defaults to 1 minute
//   - Filesystem defaults to afero.NewOsFs()
//   - Tracer defaults to a no-op tracer provider
//
// The returned Server runs until Close is called.
func New(clientDialer Dialer, opts *Options) *Server {
	// Guard against a nil options pointer; without this the first field
	// access below would panic.
	if opts == nil {
		opts = &Options{}
	}
	if opts.PollInterval == 0 {
		opts.PollInterval = 5 * time.Second
	}
	if opts.UpdateInterval == 0 {
		opts.UpdateInterval = 5 * time.Second
	}
	if opts.ForceCancelInterval == 0 {
		opts.ForceCancelInterval = time.Minute
	}
	if opts.Filesystem == nil {
		opts.Filesystem = afero.NewOsFs()
	}
	if opts.Tracer == nil {
		opts.Tracer = trace.NewNoopTracerProvider()
	}

	// The close context lives for the lifetime of the daemon and is canceled
	// by closeWithError.
	ctx, ctxCancel := context.WithCancel(context.Background())
	daemon := &Server{
		opts:         opts,
		tracer:       opts.Tracer.Tracer(tracing.TracerName),
		clientDialer: clientDialer,

		closeContext: ctx,
		closeCancel:  ctxCancel,

		shutdown: make(chan struct{}),
	}

	go daemon.connect(ctx)
	return daemon
}
// Server is a provisioner daemon. It keeps a client connection to coderd,
// periodically tries to acquire a job, and runs at most one job at a time
// (a new job is not acquired while the active one is still running).
type Server struct {
// opts are the options given to New, with defaults filled in.
opts *Options
// tracer is obtained from opts.Tracer and used to start job spans.
tracer trace.Tracer
// clientDialer creates the client connection to coderd.
clientDialer Dialer
// clientValue holds the most recently dialed
// proto.DRPCProvisionerDaemonClient; read it through client().
clientValue atomic.Value
// Locked when closing the daemon, shutting down, or starting a new job.
mutex sync.Mutex
// closeContext is canceled (via closeCancel) when the daemon is closed.
closeContext context.Context
closeCancel context.CancelFunc
// closeError is the error given to the first closeWithError call; later
// calls return it.
closeError error
// shutdown is closed by Shutdown; once closed, no further jobs are acquired.
shutdown chan struct{}
// activeJob is the runner of the most recently acquired job. It may be nil
// or already finished; use isRunningJob to tell.
activeJob *runner.Runner
}
// connect establishes a connection to coderd.
//
// It calls clientDialer in a retry loop until a dial succeeds. The loop is
// abandoned when the dial fails with context.Canceled, when the daemon has
// been closed, or when retrier.Wait(ctx) reports that no further attempt
// should be made. On success the client is stored in clientValue and two
// goroutines are started:
//   - one that waits for the connection to close and then calls connect again;
//   - one that calls acquireJob on every PollInterval tick until the daemon is
//     closed or the connection it observes is closed.
func (p *Server) connect(ctx context.Context) {
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
client, err := p.clientDialer(ctx)
if err != nil {
// A canceled dial or a closed daemon means nobody needs the
// connection anymore; give up without logging.
if errors.Is(err, context.Canceled) {
return
}
if p.isClosed() {
return
}
p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
p.clientValue.Store(client)
p.opts.Logger.Debug(context.Background(), "connected")
break
}
// The loop also ends when retrier.Wait(ctx) returns false; do not start the
// background goroutines if ctx is already done.
select {
case <-ctx.Done():
return
default:
}

// Reconnect watcher: blocks until the daemon closes or the connection drops.
go func() {
if p.isClosed() {
return
}
client, ok := p.client()
if !ok {
return
}
select {
case <-p.closeContext.Done():
return
case <-client.DRPCConn().Closed():
// The DRPC connection's Closed channel is what signals that the
// connection has been interrupted. Original note: this works well,
// because logs need to buffer if a job is running in the background.
p.opts.Logger.Debug(context.Background(), "client stream ended")
// Reconnecting starts a fresh pair of goroutines for the new client.
p.connect(ctx)
}
}()

// Job poller: tries to acquire a job on every tick.
go func() {
if p.isClosed() {
return
}
ticker := time.NewTicker(p.opts.PollInterval)
defer ticker.Stop()
for {
// Re-read the client each iteration so the Closed channel being
// watched belongs to the currently stored client.
client, ok := p.client()
if !ok {
return
}
select {
case <-p.closeContext.Done():
return
case <-client.DRPCConn().Closed():
// The reconnect watcher above handles redialing; this poller
// simply ends.
return
case <-ticker.C:
p.acquireJob(ctx)
}
}
}()
}
// client returns the daemon client most recently stored by connect. The
// boolean is false when no client has been stored yet (or the stored value is
// not a proto.DRPCProvisionerDaemonClient), in which case the returned client
// is nil.
func (p *Server) client() (proto.DRPCProvisionerDaemonClient, bool) {
	// A type assertion on a nil interface value yields the zero value and
	// false, which covers the "nothing stored yet" case as well.
	stored, isClient := p.clientValue.Load().(proto.DRPCProvisionerDaemonClient)
	return stored, isClient
}
// isRunningJob reports whether there is an active job that has not finished
// yet. Caller must hold the mutex.
func (p *Server) isRunningJob() bool {
	job := p.activeJob
	if job != nil {
		select {
		case <-job.Done():
			// The job already completed.
		default:
			return true
		}
	}
	return false
}
// acquireJob asks coderd for a job and, if one is handed out, starts a runner
// for it in a new goroutine.
//
// Nothing is acquired when the daemon is closed, a job is still running, the
// daemon is shutting down, or no client is connected. A job whose provisioner
// is not registered in the options is reported back as failed. The mutex is
// held for the whole call.
func (p *Server) acquireJob(ctx context.Context) {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	switch {
	case p.isClosed(), p.isRunningJob():
		return
	case p.isShutdown():
		p.opts.Logger.Debug(context.Background(), "skipping acquire; provisionerd is shutting down...")
		return
	}

	daemonClient, connected := p.client()
	if !connected {
		return
	}

	job, err := daemonClient.AcquireJob(ctx, &proto.Empty{})
	if err != nil {
		// Cancellation and a shut-down yamux session are returned silently;
		// anything else is logged as a warning.
		silent := errors.Is(err, context.Canceled) || errors.Is(err, yamux.ErrSessionShutdown)
		if !silent {
			p.opts.Logger.Warn(ctx, "acquire job", slog.Error(err))
		}
		return
	}
	// An empty job ID means there was nothing to acquire.
	if job.JobId == "" {
		return
	}

	ctx, span := p.tracer.Start(ctx, tracing.FuncName(), trace.WithAttributes(
		semconv.ServiceNameKey.String("coderd.provisionerd"),
		attribute.String("job_id", job.JobId),
		attribute.String("job_type", reflect.TypeOf(job.GetType()).Elem().Name()),
		attribute.Int64("job_created_at", job.CreatedAt),
		attribute.String("initiator_username", job.UserName),
		attribute.String("provisioner", job.Provisioner),
		attribute.Int("template_size_bytes", len(job.TemplateSourceArchive)),
	))
	defer span.End()

	// Workspace builds carry extra metadata worth recording on the span.
	build := job.GetWorkspaceBuild()
	if build != nil {
		meta := build.Metadata
		span.SetAttributes(
			attribute.String("workspace_build_id", build.WorkspaceBuildId),
			attribute.String("workspace_id", meta.WorkspaceId),
			attribute.String("workspace_name", build.WorkspaceName),
			attribute.String("workspace_owner_id", meta.WorkspaceOwnerId),
			attribute.String("workspace_owner", meta.WorkspaceOwner),
			attribute.String("workspace_transition", meta.WorkspaceTransition.String()),
		)
	}

	p.opts.Logger.Info(ctx, "acquired job",
		slog.F("initiator_username", job.UserName),
		slog.F("provisioner", job.Provisioner),
		slog.F("job_id", job.JobId),
	)

	impl, registered := p.opts.Provisioners[job.Provisioner]
	if !registered {
		failErr := p.FailJob(ctx, &proto.FailedJob{
			JobId: job.JobId,
			Error: fmt.Sprintf("no provisioner %s", job.Provisioner),
		})
		if failErr != nil {
			p.opts.Logger.Error(ctx, "fail job", slog.F("job_id", job.JobId), slog.Error(failErr))
		}
		return
	}

	jobRunner := runner.NewRunner(
		ctx,
		job,
		p,
		p.opts.Logger,
		p.opts.Filesystem,
		p.opts.WorkDirectory,
		impl,
		p.opts.UpdateInterval,
		p.opts.ForceCancelInterval,
		p.tracer,
	)
	p.activeJob = jobRunner

	go jobRunner.Run()
}
// retryable reports whether err is one of the errors after which an RPC to
// coderd is attempted again: a shut-down yamux session, io.EOF, or
// context.Canceled.
func retryable(err error) bool {
	switch {
	case xerrors.Is(err, yamux.ErrSessionShutdown):
		return true
	case xerrors.Is(err, io.EOF):
		return true
	case xerrors.Is(err, context.Canceled):
		// annoyingly, dRPC sometimes returns context.Canceled if the transport was closed, even if the context for
		// the RPC *is not canceled*. Retrying is fine if the RPC context is not canceled.
		return true
	default:
		return false
	}
}
// clientDoWithRetries invokes f with the current client, backing off between
// attempts. It keeps trying while no client is available or while f fails with
// a retryable() error, and returns f's result as soon as the error is not
// retryable (including nil). Once the retrier stops waiting on ctx, it returns
// nil and ctx.Err().
func (p *Server) clientDoWithRetries(
	ctx context.Context, f func(context.Context, proto.DRPCProvisionerDaemonClient) (any, error)) (
	any, error) {
	backoff := retry.New(25*time.Millisecond, 5*time.Second)
	for backoff.Wait(ctx) {
		if client, ok := p.client(); ok {
			resp, err := f(ctx, client)
			if !retryable(err) {
				return resp, err
			}
		}
	}
	return nil, ctx.Err()
}
// UpdateJob sends a job update to coderd through the current client, retrying
// on retryable errors, and returns coderd's response.
func (p *Server) UpdateJob(ctx context.Context, in *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
	send := func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) {
		return client.UpdateJob(ctx, in)
	}
	resp, err := p.clientDoWithRetries(ctx, send)
	if err != nil {
		return nil, err
	}
	// nolint: forcetypeassert
	return resp.(*proto.UpdateJobResponse), nil
}
// FailJob reports a failed job to coderd through the current client, retrying
// on retryable errors. The RPC response is discarded; only the error is
// returned.
func (p *Server) FailJob(ctx context.Context, in *proto.FailedJob) error {
	send := func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) {
		return client.FailJob(ctx, in)
	}
	_, err := p.clientDoWithRetries(ctx, send)
	return err
}
// CompleteJob reports a completed job to coderd through the current client,
// retrying on retryable errors. The RPC response is discarded; only the error
// is returned.
func (p *Server) CompleteJob(ctx context.Context, in *proto.CompletedJob) error {
	send := func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) (any, error) {
		return client.CompleteJob(ctx, in)
	}
	_, err := p.clientDoWithRetries(ctx, send)
	return err
}
// isClosed reports whether the daemon has been closed, i.e. whether its close
// context has been canceled.
func (p *Server) isClosed() bool {
	// Err is non-nil exactly when the context's Done channel is closed.
	return p.closeContext.Err() != nil
}
// isShutdown reports whether a graceful shutdown has been initiated, i.e.
// whether the shutdown channel has been closed. It never blocks.
func (p *Server) isShutdown() bool {
	shuttingDown := false
	select {
	case <-p.shutdown:
		shuttingDown = true
	default:
	}
	return shuttingDown
}
// Shutdown triggers a graceful exit of the daemon's active job.
//
// If no job is currently running it returns nil immediately without marking
// the daemon as shutting down. Otherwise it closes the shutdown channel (so no
// further jobs are acquired), cancels the active job, and blocks until that
// job is done or ctx is done. In the latter case ctx.Err() is returned.
//
// The mutex is held for the whole call. Shutdown may be called again after an
// earlier call timed out: the shutdown channel is closed at most once.
func (p *Server) Shutdown(ctx context.Context) error {
	p.mutex.Lock()
	defer p.mutex.Unlock()
	if !p.isRunningJob() {
		return nil
	}
	p.opts.Logger.Info(ctx, "attempting graceful shutdown")
	// Closing an already-closed channel panics. A previous Shutdown call may
	// have closed it and then timed out while the job kept running, so only
	// close it the first time. The mutex makes check-and-close atomic.
	if !p.isShutdown() {
		close(p.shutdown)
	}
	// isRunningJob returned true above, so activeJob is non-nil here.
	// wait for active job
	p.activeJob.Cancel()
	select {
	case <-ctx.Done():
		p.opts.Logger.Warn(ctx, "graceful shutdown failed", slog.Error(ctx.Err()))
		return ctx.Err()
	case <-p.activeJob.Done():
		p.opts.Logger.Info(ctx, "gracefully shutdown")
		return nil
	}
}
// Close ends the provisioner. It will mark any running jobs as failed.
// It is closeWithError with a nil error, so the job is failed with the
// "shutdown gracefully" message and the returned error, if any, comes from
// failing that job. Closing an already-closed daemon returns the error given
// to the first close.
func (p *Server) Close() error {
return p.closeWithError(nil)
}
// closeWithError closes the daemon and records err as the close error.
//
// If the daemon is already closed, the previously recorded close error is
// returned and nothing else happens. Otherwise, when there is an active job,
// it is failed with err's message (or a graceful-shutdown message when err is
// nil) using a one-second timeout, and force-stopped if failing it does not
// succeed. The close context is then canceled. The return value is err, or the
// error from failing the job when err is nil.
func (p *Server) closeWithError(err error) error {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	if p.isClosed() {
		return p.closeError
	}
	p.closeError = err

	reason := "provisioner daemon was shutdown gracefully"
	if err != nil {
		reason = err.Error()
	}

	if job := p.activeJob; job != nil {
		failCtx, stop := context.WithTimeout(context.Background(), time.Second)
		defer stop()

		failErr := job.Fail(failCtx, &proto.FailedJob{Error: reason})
		if failErr != nil {
			// The job could not be failed cleanly; stop it the hard way.
			job.ForceStop()
		}
		if err == nil {
			err = failErr
		}
	}

	p.closeCancel()
	p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err))
	return err
}