fix: fix null pointer on external provisioner daemons with daily_cost (#9401)

* fix: fix null pointer on external provisioner daemons with daily_cost

Signed-off-by: Spike Curtis <spike@coder.com>

* Add logging for debounce and job acquire

Signed-off-by: Spike Curtis <spike@coder.com>

* Return error instead of panic

Signed-off-by: Spike Curtis <spike@coder.com>

* remove debounce on external provisioners to fix test flakes

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis 2023-08-30 14:48:35 +04:00 committed by GitHub
parent a415395e9e
commit 90acf998bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 664 additions and 350 deletions

View File

@ -1098,26 +1098,31 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
}
mux := drpcmux.New()
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
AccessURL: api.AccessURL,
ID: daemon.ID,
OIDCConfig: api.OIDCConfig,
Database: api.Database,
Pubsub: api.Pubsub,
Provisioners: daemon.Provisioners,
GitAuthConfigs: api.GitAuthConfigs,
Telemetry: api.Telemetry,
Tracer: tracer,
Tags: tags,
QuotaCommitter: &api.QuotaCommitter,
Auditor: &api.Auditor,
TemplateScheduleStore: api.TemplateScheduleStore,
UserQuietHoursScheduleStore: api.UserQuietHoursScheduleStore,
AcquireJobDebounce: debounce,
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
DeploymentValues: api.DeploymentValues,
})
srv, err := provisionerdserver.NewServer(
api.AccessURL,
daemon.ID,
api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
daemon.Provisioners,
tags,
api.Database,
api.Pubsub,
api.Telemetry,
tracer,
&api.QuotaCommitter,
&api.Auditor,
api.TemplateScheduleStore,
api.UserQuietHoursScheduleStore,
api.DeploymentValues,
debounce,
provisionerdserver.Options{
OIDCConfig: api.OIDCConfig,
GitAuthConfigs: api.GitAuthConfigs,
},
)
if err != nil {
return nil, err
}
err = proto.DRPCRegisterProvisionerDaemon(mux, srv)
if err != nil {
return nil, err
}

View File

@ -48,7 +48,14 @@ var (
lastAcquireMutex sync.RWMutex
)
type Server struct {
// Options holds the optional dependencies and test hooks for NewServer.
// All fields have safe zero values and may be left unset.
type Options struct {
// OIDCConfig, when non-nil, is used to obtain/refresh the workspace
// owner's OIDC access token while acquiring a workspace-build job.
OIDCConfig httpmw.OAuth2Config
// GitAuthConfigs are the deployment's configured git auth providers,
// matched by ID against a template version's git auth providers so
// their links can be refreshed for a build.
GitAuthConfigs []*gitauth.Config
// TimeNowFn overrides the server's clock (see timeNow).
// TimeNowFn is only used in tests.
TimeNowFn func() time.Time
}
type server struct {
AccessURL *url.URL
ID uuid.UUID
Logger slog.Logger
@ -71,17 +78,73 @@ type Server struct {
TimeNowFn func() time.Time
}
// NewServer constructs the provisionerd DRPC server implementation.
// Required dependencies are explicit parameters; optional ones (OIDC
// config, git auth configs, test clock) come in via options. It returns
// an error — instead of allowing a later nil-pointer panic — when any
// required pointer argument is nil.
func NewServer(
accessURL *url.URL,
id uuid.UUID,
logger slog.Logger,
provisioners []database.ProvisionerType,
tags json.RawMessage,
db database.Store,
ps pubsub.Pubsub,
tel telemetry.Reporter,
tracer trace.Tracer,
quotaCommitter *atomic.Pointer[proto.QuotaCommitter],
auditor *atomic.Pointer[audit.Auditor],
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore],
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore],
deploymentValues *codersdk.DeploymentValues,
acquireJobDebounce time.Duration,
options Options,
) (proto.DRPCProvisionerDaemonServer, error) {
// Fail early with an error if any required pointer is nil; these are
// dereferenced later during job handling, where a nil would panic.
if quotaCommitter == nil {
return nil, xerrors.New("quotaCommitter is nil")
}
if auditor == nil {
return nil, xerrors.New("auditor is nil")
}
if templateScheduleStore == nil {
return nil, xerrors.New("templateScheduleStore is nil")
}
if userQuietHoursScheduleStore == nil {
return nil, xerrors.New("userQuietHoursScheduleStore is nil")
}
if deploymentValues == nil {
return nil, xerrors.New("deploymentValues is nil")
}
return &server{
AccessURL: accessURL,
ID: id,
Logger: logger,
Provisioners: provisioners,
GitAuthConfigs: options.GitAuthConfigs,
Tags: tags,
Database: db,
Pubsub: ps,
Telemetry: tel,
Tracer: tracer,
QuotaCommitter: quotaCommitter,
Auditor: auditor,
TemplateScheduleStore: templateScheduleStore,
UserQuietHoursScheduleStore: userQuietHoursScheduleStore,
DeploymentValues: deploymentValues,
AcquireJobDebounce: acquireJobDebounce,
OIDCConfig: options.OIDCConfig,
TimeNowFn: options.TimeNowFn,
}, nil
}
// timeNow should be used when trying to get the current time for math
// calculations regarding workspace start and stop time.
func (server *Server) timeNow() time.Time {
if server.TimeNowFn != nil {
return database.Time(server.TimeNowFn())
func (s *server) timeNow() time.Time {
if s.TimeNowFn != nil {
return database.Time(s.TimeNowFn())
}
return database.Now()
}
// AcquireJob queries the database to lock a job.
func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
func (s *server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
//nolint:gocritic // Provisionerd has specific authz rules.
ctx = dbauthz.AsProvisionerd(ctx)
// This prevents loads of provisioner daemons from consistently
@ -90,23 +153,24 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
// The debounce only occurs when no job is returned, so if loads of
// jobs are added at once, they will start after at most this duration.
lastAcquireMutex.RLock()
if !lastAcquire.IsZero() && time.Since(lastAcquire) < server.AcquireJobDebounce {
if !lastAcquire.IsZero() && time.Since(lastAcquire) < s.AcquireJobDebounce {
s.Logger.Debug(ctx, "debounce acquire job", slog.F("debounce", s.AcquireJobDebounce), slog.F("last_acquire", lastAcquire))
lastAcquireMutex.RUnlock()
return &proto.AcquiredJob{}, nil
}
lastAcquireMutex.RUnlock()
// This marks the job as locked in the database.
job, err := server.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
job, err := s.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
StartedAt: sql.NullTime{
Time: database.Now(),
Valid: true,
},
WorkerID: uuid.NullUUID{
UUID: server.ID,
UUID: s.ID,
Valid: true,
},
Types: server.Provisioners,
Tags: server.Tags,
Types: s.Provisioners,
Tags: s.Tags,
})
if errors.Is(err, sql.ErrNoRows) {
// The provisioner daemon assumes no jobs are available if
@ -119,11 +183,11 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
if err != nil {
return nil, xerrors.Errorf("acquire job: %w", err)
}
server.Logger.Debug(ctx, "locked job from database", slog.F("job_id", job.ID))
s.Logger.Debug(ctx, "locked job from database", slog.F("job_id", job.ID))
// Marks the acquired job as failed with the error message provided.
failJob := func(errorMessage string) error {
err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: job.ID,
CompletedAt: sql.NullTime{
Time: database.Now(),
@ -141,7 +205,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
return xerrors.Errorf("request job was invalidated: %s", errorMessage)
}
user, err := server.Database.GetUserByID(ctx, job.InitiatorID)
user, err := s.Database.GetUserByID(ctx, job.InitiatorID)
if err != nil {
return nil, failJob(fmt.Sprintf("get user: %s", err))
}
@ -169,38 +233,38 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
if err != nil {
return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err))
}
workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
if err != nil {
return nil, failJob(fmt.Sprintf("get workspace build: %s", err))
}
workspace, err := server.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
workspace, err := s.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
if err != nil {
return nil, failJob(fmt.Sprintf("get workspace: %s", err))
}
templateVersion, err := server.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID)
templateVersion, err := s.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID)
if err != nil {
return nil, failJob(fmt.Sprintf("get template version: %s", err))
}
templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, templateVersion.ID)
templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, templateVersion.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, failJob(fmt.Sprintf("get template version variables: %s", err))
}
template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
template, err := s.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
if err != nil {
return nil, failJob(fmt.Sprintf("get template: %s", err))
}
owner, err := server.Database.GetUserByID(ctx, workspace.OwnerID)
owner, err := s.Database.GetUserByID(ctx, workspace.OwnerID)
if err != nil {
return nil, failJob(fmt.Sprintf("get owner: %s", err))
}
err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{})
err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{})
if err != nil {
return nil, failJob(fmt.Sprintf("publish workspace update: %s", err))
}
var workspaceOwnerOIDCAccessToken string
if server.OIDCConfig != nil {
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, server.Database, server.OIDCConfig, owner.ID)
if s.OIDCConfig != nil {
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Database, s.OIDCConfig, owner.ID)
if err != nil {
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
}
@ -209,12 +273,12 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
var sessionToken string
switch workspaceBuild.Transition {
case database.WorkspaceTransitionStart:
sessionToken, err = server.regenerateSessionToken(ctx, owner, workspace)
sessionToken, err = s.regenerateSessionToken(ctx, owner, workspace)
if err != nil {
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
}
case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete:
err = deleteSessionToken(ctx, server.Database, workspace)
err = deleteSessionToken(ctx, s.Database, workspace)
if err != nil {
return nil, failJob(fmt.Sprintf("delete session token: %s", err))
}
@ -225,14 +289,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
return nil, failJob(fmt.Sprintf("convert workspace transition: %s", err))
}
workspaceBuildParameters, err := server.Database.GetWorkspaceBuildParameters(ctx, workspaceBuild.ID)
workspaceBuildParameters, err := s.Database.GetWorkspaceBuildParameters(ctx, workspaceBuild.ID)
if err != nil {
return nil, failJob(fmt.Sprintf("get workspace build parameters: %s", err))
}
gitAuthProviders := []*sdkproto.GitAuthProvider{}
for _, p := range templateVersion.GitAuthProviders {
link, err := server.Database.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
link, err := s.Database.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
ProviderID: p,
UserID: owner.ID,
})
@ -243,7 +307,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
return nil, failJob(fmt.Sprintf("acquire git auth link: %s", err))
}
var config *gitauth.Config
for _, c := range server.GitAuthConfigs {
for _, c := range s.GitAuthConfigs {
if c.ID != p {
continue
}
@ -252,14 +316,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
}
// We weren't able to find a matching config for the ID!
if config == nil {
server.Logger.Warn(ctx, "workspace build job is missing git provider",
s.Logger.Warn(ctx, "workspace build job is missing git provider",
slog.F("git_provider_id", p),
slog.F("template_version_id", templateVersion.ID),
slog.F("workspace_id", workspaceBuild.WorkspaceID))
continue
}
link, valid, err := config.RefreshToken(ctx, server.Database, link)
link, valid, err := config.RefreshToken(ctx, s.Database, link)
if err != nil {
return nil, failJob(fmt.Sprintf("refresh git auth link %q: %s", p, err))
}
@ -281,7 +345,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
VariableValues: asVariableValues(templateVariables),
GitAuthProviders: gitAuthProviders,
Metadata: &sdkproto.Metadata{
CoderUrl: server.AccessURL.String(),
CoderUrl: s.AccessURL.String(),
WorkspaceTransition: transition,
WorkspaceName: workspace.Name,
WorkspaceOwner: owner.Username,
@ -303,11 +367,11 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err))
}
templateVersion, err := server.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID)
templateVersion, err := s.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID)
if err != nil {
return nil, failJob(fmt.Sprintf("get template version: %s", err))
}
templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, templateVersion.ID)
templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, templateVersion.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, failJob(fmt.Sprintf("get template version variables: %s", err))
}
@ -317,7 +381,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
RichParameterValues: convertRichParameterValues(input.RichParameterValues),
VariableValues: asVariableValues(templateVariables),
Metadata: &sdkproto.Metadata{
CoderUrl: server.AccessURL.String(),
CoderUrl: s.AccessURL.String(),
WorkspaceName: input.WorkspaceName,
},
},
@ -329,7 +393,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err))
}
userVariableValues, err := server.includeLastVariableValues(ctx, input.TemplateVersionID, input.UserVariableValues)
userVariableValues, err := s.includeLastVariableValues(ctx, input.TemplateVersionID, input.UserVariableValues)
if err != nil {
return nil, failJob(err.Error())
}
@ -338,14 +402,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
TemplateImport: &proto.AcquiredJob_TemplateImport{
UserVariableValues: convertVariableValues(userVariableValues),
Metadata: &sdkproto.Metadata{
CoderUrl: server.AccessURL.String(),
CoderUrl: s.AccessURL.String(),
},
},
}
}
switch job.StorageMethod {
case database.ProvisionerStorageMethodFile:
file, err := server.Database.GetFileByID(ctx, job.FileID)
file, err := s.Database.GetFileByID(ctx, job.FileID)
if err != nil {
return nil, failJob(fmt.Sprintf("get file by hash: %s", err))
}
@ -360,7 +424,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
return protoJob, err
}
func (server *Server) includeLastVariableValues(ctx context.Context, templateVersionID uuid.UUID, userVariableValues []codersdk.VariableValue) ([]codersdk.VariableValue, error) {
func (s *server) includeLastVariableValues(ctx context.Context, templateVersionID uuid.UUID, userVariableValues []codersdk.VariableValue) ([]codersdk.VariableValue, error) {
var values []codersdk.VariableValue
values = append(values, userVariableValues...)
@ -368,7 +432,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
return values, nil
}
templateVersion, err := server.Database.GetTemplateVersionByID(ctx, templateVersionID)
templateVersion, err := s.Database.GetTemplateVersionByID(ctx, templateVersionID)
if err != nil {
return nil, xerrors.Errorf("get template version: %w", err)
}
@ -377,7 +441,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
return values, nil
}
template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
template, err := s.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
if err != nil {
return nil, xerrors.Errorf("get template: %w", err)
}
@ -386,7 +450,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
return values, nil
}
templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, template.ActiveVersionID)
templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, template.ActiveVersionID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get template version variables: %w", err)
}
@ -412,8 +476,8 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
return values, nil
}
func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
ctx, span := server.startTrace(ctx, tracing.FuncName())
func (s *server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
ctx, span := s.startTrace(ctx, tracing.FuncName())
defer span.End()
//nolint:gocritic // Provisionerd has specific authz rules.
@ -423,7 +487,7 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot
return nil, xerrors.Errorf("parse job id: %w", err)
}
job, err := server.Database.GetProvisionerJobByID(ctx, jobID)
job, err := s.Database.GetProvisionerJobByID(ctx, jobID)
if err != nil {
return nil, xerrors.Errorf("get job: %w", err)
}
@ -431,11 +495,11 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot
return nil, xerrors.New("job isn't running yet")
}
if job.WorkerID.UUID.String() != server.ID.String() {
if job.WorkerID.UUID.String() != s.ID.String() {
return nil, xerrors.New("you don't own this job")
}
q := server.QuotaCommitter.Load()
q := s.QuotaCommitter.Load()
if q == nil {
// We're probably in community edition or a test.
return &proto.CommitQuotaResponse{
@ -446,8 +510,8 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot
return (*q).CommitQuota(ctx, request)
}
func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
ctx, span := server.startTrace(ctx, tracing.FuncName())
func (s *server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
ctx, span := s.startTrace(ctx, tracing.FuncName())
defer span.End()
//nolint:gocritic // Provisionerd has specific authz rules.
@ -456,18 +520,18 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
if err != nil {
return nil, xerrors.Errorf("parse job id: %w", err)
}
server.Logger.Debug(ctx, "stage UpdateJob starting", slog.F("job_id", parsedID))
job, err := server.Database.GetProvisionerJobByID(ctx, parsedID)
s.Logger.Debug(ctx, "stage UpdateJob starting", slog.F("job_id", parsedID))
job, err := s.Database.GetProvisionerJobByID(ctx, parsedID)
if err != nil {
return nil, xerrors.Errorf("get job: %w", err)
}
if !job.WorkerID.Valid {
return nil, xerrors.New("job isn't running yet")
}
if job.WorkerID.UUID.String() != server.ID.String() {
if job.WorkerID.UUID.String() != s.ID.String() {
return nil, xerrors.New("you don't own this job")
}
err = server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
err = s.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
ID: parsedID,
UpdatedAt: database.Now(),
})
@ -493,37 +557,37 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
insertParams.Stage = append(insertParams.Stage, log.Stage)
insertParams.Source = append(insertParams.Source, logSource)
insertParams.Output = append(insertParams.Output, log.Output)
server.Logger.Debug(ctx, "job log",
s.Logger.Debug(ctx, "job log",
slog.F("job_id", parsedID),
slog.F("stage", log.Stage),
slog.F("output", log.Output))
}
logs, err := server.Database.InsertProvisionerJobLogs(ctx, insertParams)
logs, err := s.Database.InsertProvisionerJobLogs(ctx, insertParams)
if err != nil {
server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err))
s.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err))
return nil, xerrors.Errorf("insert job logs: %w", err)
}
// Publish by the lowest log ID inserted so the log stream will fetch
// everything from that point.
lowestID := logs[0].ID
server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
s.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
CreatedAfter: lowestID - 1,
})
if err != nil {
return nil, xerrors.Errorf("marshal: %w", err)
}
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data)
err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data)
if err != nil {
server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err))
s.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err))
return nil, xerrors.Errorf("publish job logs: %w", err)
}
server.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID))
s.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID))
}
if len(request.Readme) > 0 {
err := server.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{
err := s.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{
JobID: job.ID,
Readme: string(request.Readme),
UpdatedAt: database.Now(),
@ -534,16 +598,16 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
}
if len(request.TemplateVariables) > 0 {
templateVersion, err := server.Database.GetTemplateVersionByJobID(ctx, job.ID)
templateVersion, err := s.Database.GetTemplateVersionByJobID(ctx, job.ID)
if err != nil {
server.Logger.Error(ctx, "failed to get the template version", slog.F("job_id", parsedID), slog.Error(err))
s.Logger.Error(ctx, "failed to get the template version", slog.F("job_id", parsedID), slog.Error(err))
return nil, xerrors.Errorf("get template version by job id: %w", err)
}
var variableValues []*sdkproto.VariableValue
var variablesWithMissingValues []string
for _, templateVariable := range request.TemplateVariables {
server.Logger.Debug(ctx, "insert template variable", slog.F("template_version_id", templateVersion.ID), slog.F("template_variable", redactTemplateVariable(templateVariable)))
s.Logger.Debug(ctx, "insert template variable", slog.F("template_version_id", templateVersion.ID), slog.F("template_variable", redactTemplateVariable(templateVariable)))
value := templateVariable.DefaultValue
for _, v := range request.UserVariableValues {
@ -563,7 +627,7 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
Sensitive: templateVariable.Sensitive,
})
_, err = server.Database.InsertTemplateVersionVariable(ctx, database.InsertTemplateVersionVariableParams{
_, err = s.Database.InsertTemplateVersionVariable(ctx, database.InsertTemplateVersionVariableParams{
TemplateVersionID: templateVersion.ID,
Name: templateVariable.Name,
Description: templateVariable.Description,
@ -593,8 +657,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
}, nil
}
func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) {
ctx, span := server.startTrace(ctx, tracing.FuncName())
func (s *server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) {
ctx, span := s.startTrace(ctx, tracing.FuncName())
defer span.End()
//nolint:gocritic // Provisionerd has specific authz rules.
@ -603,12 +667,12 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
if err != nil {
return nil, xerrors.Errorf("parse job id: %w", err)
}
server.Logger.Debug(ctx, "stage FailJob starting", slog.F("job_id", jobID))
job, err := server.Database.GetProvisionerJobByID(ctx, jobID)
s.Logger.Debug(ctx, "stage FailJob starting", slog.F("job_id", jobID))
job, err := s.Database.GetProvisionerJobByID(ctx, jobID)
if err != nil {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
if job.WorkerID.UUID.String() != server.ID.String() {
if job.WorkerID.UUID.String() != s.ID.String() {
return nil, xerrors.New("you don't own this job")
}
if job.CompletedAt.Valid {
@ -627,7 +691,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
Valid: failJob.ErrorCode != "",
}
err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: jobID,
CompletedAt: job.CompletedAt,
UpdatedAt: database.Now(),
@ -637,7 +701,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
if err != nil {
return nil, xerrors.Errorf("update provisioner job: %w", err)
}
server.Telemetry.Report(&telemetry.Snapshot{
s.Telemetry.Report(&telemetry.Snapshot{
ProvisionerJobs: []telemetry.ProvisionerJob{telemetry.ConvertProvisionerJob(job)},
})
@ -650,7 +714,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
}
var build database.WorkspaceBuild
err = server.Database.InTx(func(db database.Store) error {
err = s.Database.InTx(func(db database.Store) error {
build, err = db.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
if err != nil {
return xerrors.Errorf("get workspace build: %w", err)
@ -675,7 +739,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
return nil, err
}
err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), []byte{})
err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), []byte{})
if err != nil {
return nil, xerrors.Errorf("update workspace: %w", err)
}
@ -684,18 +748,18 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
// if failed job is a workspace build, audit the outcome
if job.Type == database.ProvisionerJobTypeWorkspaceBuild {
auditor := server.Auditor.Load()
build, err := server.Database.GetWorkspaceBuildByJobID(ctx, job.ID)
auditor := s.Auditor.Load()
build, err := s.Database.GetWorkspaceBuildByJobID(ctx, job.ID)
if err != nil {
server.Logger.Error(ctx, "audit log - get build", slog.Error(err))
s.Logger.Error(ctx, "audit log - get build", slog.Error(err))
} else {
auditAction := auditActionFromTransition(build.Transition)
workspace, err := server.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
workspace, err := s.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
if err != nil {
server.Logger.Error(ctx, "audit log - get workspace", slog.Error(err))
s.Logger.Error(ctx, "audit log - get workspace", slog.Error(err))
} else {
previousBuildNumber := build.BuildNumber - 1
previousBuild, prevBuildErr := server.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
WorkspaceID: workspace.ID,
BuildNumber: previousBuildNumber,
})
@ -713,12 +777,12 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
wriBytes, err := json.Marshal(buildResourceInfo)
if err != nil {
server.Logger.Error(ctx, "marshal workspace resource info for failed job", slog.Error(err))
s.Logger.Error(ctx, "marshal workspace resource info for failed job", slog.Error(err))
}
audit.BuildAudit(ctx, &audit.BuildAuditParams[database.WorkspaceBuild]{
Audit: *auditor,
Log: server.Logger,
Log: s.Logger,
UserID: job.InitiatorID,
JobID: job.ID,
Action: auditAction,
@ -735,17 +799,17 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
if err != nil {
return nil, xerrors.Errorf("marshal job log: %w", err)
}
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
if err != nil {
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
return nil, xerrors.Errorf("publish end of job logs: %w", err)
}
return &proto.Empty{}, nil
}
// CompleteJob is triggered by a provision daemon to mark a provisioner job as completed.
func (server *Server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
ctx, span := server.startTrace(ctx, tracing.FuncName())
func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
ctx, span := s.startTrace(ctx, tracing.FuncName())
defer span.End()
//nolint:gocritic // Provisionerd has specific authz rules.
@ -754,18 +818,18 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
if err != nil {
return nil, xerrors.Errorf("parse job id: %w", err)
}
server.Logger.Debug(ctx, "stage CompleteJob starting", slog.F("job_id", jobID))
job, err := server.Database.GetProvisionerJobByID(ctx, jobID)
s.Logger.Debug(ctx, "stage CompleteJob starting", slog.F("job_id", jobID))
job, err := s.Database.GetProvisionerJobByID(ctx, jobID)
if err != nil {
return nil, xerrors.Errorf("get job by id: %w", err)
}
if job.WorkerID.UUID.String() != server.ID.String() {
if job.WorkerID.UUID.String() != s.ID.String() {
return nil, xerrors.Errorf("you don't own this job")
}
telemetrySnapshot := &telemetry.Snapshot{}
// Items are added to this snapshot as they complete!
defer server.Telemetry.Report(telemetrySnapshot)
defer s.Telemetry.Report(telemetrySnapshot)
switch jobType := completed.Type.(type) {
case *proto.CompletedJob_TemplateImport_:
@ -780,13 +844,13 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
database.WorkspaceTransitionStop: jobType.TemplateImport.StopResources,
} {
for _, resource := range resources {
server.Logger.Info(ctx, "inserting template import job resource",
s.Logger.Info(ctx, "inserting template import job resource",
slog.F("job_id", job.ID.String()),
slog.F("resource_name", resource.Name),
slog.F("resource_type", resource.Type),
slog.F("transition", transition))
err = InsertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot)
err = InsertWorkspaceResource(ctx, s.Database, jobID, transition, resource, telemetrySnapshot)
if err != nil {
return nil, xerrors.Errorf("insert resource: %w", err)
}
@ -794,7 +858,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
}
for _, richParameter := range jobType.TemplateImport.RichParameters {
server.Logger.Info(ctx, "inserting template import job parameter",
s.Logger.Info(ctx, "inserting template import job parameter",
slog.F("job_id", job.ID.String()),
slog.F("parameter_name", richParameter.Name),
slog.F("type", richParameter.Type),
@ -819,7 +883,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
}
}
_, err = server.Database.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{
_, err = s.Database.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{
TemplateVersionID: input.TemplateVersionID,
Name: richParameter.Name,
DisplayName: richParameter.DisplayName,
@ -847,7 +911,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
for _, gitAuthProvider := range jobType.TemplateImport.GitAuthProviders {
contains := false
for _, configuredProvider := range server.GitAuthConfigs {
for _, configuredProvider := range s.GitAuthConfigs {
if configuredProvider.ID == gitAuthProvider {
contains = true
break
@ -862,7 +926,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
}
}
err = server.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{
err = s.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{
JobID: jobID,
GitAuthProviders: jobType.TemplateImport.GitAuthProviders,
UpdatedAt: database.Now(),
@ -871,7 +935,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
return nil, xerrors.Errorf("update template version git auth providers: %w", err)
}
err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: jobID,
UpdatedAt: database.Now(),
CompletedAt: sql.NullTime{
@ -883,7 +947,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
if err != nil {
return nil, xerrors.Errorf("update provisioner job: %w", err)
}
server.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID))
s.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID))
if err != nil {
return nil, xerrors.Errorf("complete job: %w", err)
}
@ -894,7 +958,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
return nil, xerrors.Errorf("unmarshal job data: %w", err)
}
workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
if err != nil {
return nil, xerrors.Errorf("get workspace build: %w", err)
}
@ -902,14 +966,14 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
var workspace database.Workspace
var getWorkspaceError error
err = server.Database.InTx(func(db database.Store) error {
// It's important we use server.timeNow() here because we want to be
err = s.Database.InTx(func(db database.Store) error {
// It's important we use s.timeNow() here because we want to be
// able to customize the current time from within tests.
now := server.timeNow()
now := s.timeNow()
workspace, getWorkspaceError = db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
if getWorkspaceError != nil {
server.Logger.Error(ctx,
s.Logger.Error(ctx,
"fetch workspace for build",
slog.F("workspace_build_id", workspaceBuild.ID),
slog.F("workspace_id", workspaceBuild.WorkspaceID),
@ -919,8 +983,8 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
autoStop, err := schedule.CalculateAutostop(ctx, schedule.CalculateAutostopParams{
Database: db,
TemplateScheduleStore: *server.TemplateScheduleStore.Load(),
UserQuietHoursScheduleStore: *server.UserQuietHoursScheduleStore.Load(),
TemplateScheduleStore: *s.TemplateScheduleStore.Load(),
UserQuietHoursScheduleStore: *s.UserQuietHoursScheduleStore.Load(),
Now: now,
Workspace: workspace,
})
@ -976,7 +1040,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
var updates []<-chan time.Time
for _, d := range timeouts {
server.Logger.Debug(ctx, "triggering workspace notification after agent timeout",
s.Logger.Debug(ctx, "triggering workspace notification after agent timeout",
slog.F("workspace_build_id", workspaceBuild.ID),
slog.F("timeout", d),
)
@ -988,11 +1052,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
for _, wait := range updates {
// Wait for the next potential timeout to occur. Note that we
// can't listen on the context here because we will hang around
// after this function has returned. The server also doesn't
// after this function has returned. The s also doesn't
// have a shutdown signal we can listen to.
<-wait
if err := server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil {
server.Logger.Error(ctx, "workspace notification after agent timeout failed",
if err := s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil {
s.Logger.Error(ctx, "workspace notification after agent timeout failed",
slog.F("workspace_build_id", workspaceBuild.ID),
slog.Error(err),
)
@ -1022,11 +1086,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
// audit the outcome of the workspace build
if getWorkspaceError == nil {
auditor := server.Auditor.Load()
auditor := s.Auditor.Load()
auditAction := auditActionFromTransition(workspaceBuild.Transition)
previousBuildNumber := workspaceBuild.BuildNumber - 1
previousBuild, prevBuildErr := server.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
WorkspaceID: workspace.ID,
BuildNumber: previousBuildNumber,
})
@ -1044,12 +1108,12 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
wriBytes, err := json.Marshal(buildResourceInfo)
if err != nil {
server.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err))
s.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err))
}
audit.BuildAudit(ctx, &audit.BuildAuditParams[database.WorkspaceBuild]{
Audit: *auditor,
Log: server.Logger,
Log: s.Logger,
UserID: job.InitiatorID,
JobID: job.ID,
Action: auditAction,
@ -1060,24 +1124,24 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
})
}
err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{})
err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{})
if err != nil {
return nil, xerrors.Errorf("update workspace: %w", err)
}
case *proto.CompletedJob_TemplateDryRun_:
for _, resource := range jobType.TemplateDryRun.Resources {
server.Logger.Info(ctx, "inserting template dry-run job resource",
s.Logger.Info(ctx, "inserting template dry-run job resource",
slog.F("job_id", job.ID.String()),
slog.F("resource_name", resource.Name),
slog.F("resource_type", resource.Type))
err = InsertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot)
err = InsertWorkspaceResource(ctx, s.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot)
if err != nil {
return nil, xerrors.Errorf("insert resource: %w", err)
}
}
err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: jobID,
UpdatedAt: database.Now(),
CompletedAt: sql.NullTime{
@ -1088,7 +1152,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
if err != nil {
return nil, xerrors.Errorf("update provisioner job: %w", err)
}
server.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID))
s.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID))
if err != nil {
return nil, xerrors.Errorf("complete job: %w", err)
}
@ -1105,18 +1169,18 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
if err != nil {
return nil, xerrors.Errorf("marshal job log: %w", err)
}
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
if err != nil {
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
return nil, xerrors.Errorf("publish end of job logs: %w", err)
}
server.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID))
s.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID))
return &proto.Empty{}, nil
}
func (server *Server) startTrace(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
return server.Tracer.Start(ctx, name, append(opts, trace.WithAttributes(
func (s *server) startTrace(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
return s.Tracer.Start(ctx, name, append(opts, trace.WithAttributes(
semconv.ServiceNameKey.String("coderd.provisionerd"),
))...)
}
@ -1316,19 +1380,19 @@ func workspaceSessionTokenName(workspace database.Workspace) string {
return fmt.Sprintf("%s_%s_session_token", workspace.OwnerID, workspace.ID)
}
func (server *Server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) {
func (s *server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) {
newkey, sessionToken, err := apikey.Generate(apikey.CreateParams{
UserID: user.ID,
LoginType: user.LoginType,
DeploymentValues: server.DeploymentValues,
DeploymentValues: s.DeploymentValues,
TokenName: workspaceSessionTokenName(workspace),
LifetimeSeconds: int64(server.DeploymentValues.MaxTokenLifetime.Value().Seconds()),
LifetimeSeconds: int64(s.DeploymentValues.MaxTokenLifetime.Value().Seconds()),
})
if err != nil {
return "", xerrors.Errorf("generate API key: %w", err)
}
err = server.Database.InTx(func(tx database.Store) error {
err = s.Database.InTx(func(tx database.Store) error {
err := deleteSessionToken(ctx, tx, workspace)
if err != nil {
return xerrors.Errorf("delete session token: %w", err)

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@ import (
"net"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
@ -243,23 +244,33 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
return
}
mux := drpcmux.New()
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
AccessURL: api.AccessURL,
GitAuthConfigs: api.GitAuthConfigs,
OIDCConfig: api.OIDCConfig,
ID: daemon.ID,
Database: api.Database,
Pubsub: api.Pubsub,
Provisioners: daemon.Provisioners,
Telemetry: api.Telemetry,
Auditor: &api.AGPL.Auditor,
TemplateScheduleStore: api.AGPL.TemplateScheduleStore,
UserQuietHoursScheduleStore: api.AGPL.UserQuietHoursScheduleStore,
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
Tags: rawTags,
Tracer: trace.NewNoopTracerProvider().Tracer("noop"),
DeploymentValues: api.DeploymentValues,
})
srv, err := provisionerdserver.NewServer(
api.AccessURL,
daemon.ID,
api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
daemon.Provisioners,
rawTags,
api.Database,
api.Pubsub,
api.Telemetry,
trace.NewNoopTracerProvider().Tracer("noop"),
&api.AGPL.QuotaCommitter,
&api.AGPL.Auditor,
api.AGPL.TemplateScheduleStore,
api.AGPL.UserQuietHoursScheduleStore,
api.DeploymentValues,
// TODO(spikecurtis) - fix debounce to not cause flaky tests.
time.Duration(0),
provisionerdserver.Options{
GitAuthConfigs: api.GitAuthConfigs,
OIDCConfig: api.OIDCConfig,
},
)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("create provisioner daemon server: %s", err))
return
}
err = proto.DRPCRegisterProvisionerDaemon(mux, srv)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))
return

View File

@ -9,13 +9,20 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionerd"
provisionerdproto "github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
@ -212,6 +219,107 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Len(t, daemons, 1)
})
t.Run("PSK_daily_cost", func(t *testing.T) {
t.Parallel()
client, user := coderdenttest.New(t, &coderdenttest.Options{
UserWorkspaceQuota: 10,
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
codersdk.FeatureTemplateRBAC: 1,
},
},
ProvisionerDaemonPSK: "provisionersftw",
})
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
terraformClient, terraformServer := provisionersdk.MemTransportPipe()
go func() {
<-ctx.Done()
_ = terraformClient.Close()
_ = terraformServer.Close()
}()
tempDir := t.TempDir()
errCh := make(chan error)
go func() {
err := echo.Serve(ctx, &provisionersdk.ServeOptions{
Listener: terraformServer,
Logger: logger.Named("echo"),
WorkDirectory: tempDir,
})
errCh <- err
}()
provisioners := provisionerd.Provisioners{
string(database.ProvisionerTypeEcho): proto.NewDRPCProvisionerClient(terraformClient),
}
another := codersdk.New(client.URL)
pd := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{
Organization: user.OrganizationID,
Provisioners: []codersdk.ProvisionerType{
codersdk.ProvisionerTypeEcho,
},
Tags: map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
},
PreSharedKey: "provisionersftw",
})
}, &provisionerd.Options{
Logger: logger.Named("provisionerd"),
Provisioners: provisioners,
})
defer pd.Close()
// Patch the 'Everyone' group to give the user quota to build their workspace.
_, err := client.PatchGroup(ctx, user.OrganizationID, codersdk.PatchGroupRequest{
QuotaAllowance: ptr.Ref(1),
})
require.NoError(t, err)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionApply: []*proto.Response{{
Type: &proto.Response_Apply{
Apply: &proto.ApplyComplete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
DailyCost: 1,
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Name: "example",
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
err = pd.Shutdown(ctx)
require.NoError(t, err)
err = terraformServer.Close()
require.NoError(t, err)
select {
case <-ctx.Done():
t.Error("timeout waiting for server to shut down")
case err := <-errCh:
require.NoError(t, err)
}
})
t.Run("BadPSK", func(t *testing.T) {
t.Parallel()
client, user := coderdenttest.New(t, &coderdenttest.Options{

View File

@ -308,6 +308,7 @@ func (p *Server) acquireJob(ctx context.Context) {
lastAcquireMutex.RLock()
if !lastAcquire.IsZero() && time.Since(lastAcquire) < p.opts.JobPollDebounce {
lastAcquireMutex.RUnlock()
p.opts.Logger.Debug(ctx, "debounce acquire job")
return
}
lastAcquireMutex.RUnlock()
@ -319,6 +320,7 @@ func (p *Server) acquireJob(ctx context.Context) {
}
job, err := client.AcquireJob(ctx, &proto.Empty{})
p.opts.Logger.Debug(ctx, "called AcquireJob on client", slog.F("job_id", job.GetJobId()), slog.Error(err))
if err != nil {
if errors.Is(err, context.Canceled) ||
errors.Is(err, yamux.ErrSessionShutdown) ||