mirror of https://github.com/coder/coder.git
fix: fix null pointer on external provisioner daemons with daily_cost (#9401)
* fix: fix null pointer on external provisioner daemons with daily_cost Signed-off-by: Spike Curtis <spike@coder.com> * Add logging for debounce and job acquire Signed-off-by: Spike Curtis <spike@coder.com> * Return error instead of panic Signed-off-by: Spike Curtis <spike@coder.com> * remove debounce on external provisioners to fix test flakes Signed-off-by: Spike Curtis <spike@coder.com> --------- Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent
a415395e9e
commit
90acf998bf
|
@ -1098,26 +1098,31 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
|
|||
}
|
||||
|
||||
mux := drpcmux.New()
|
||||
|
||||
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
|
||||
AccessURL: api.AccessURL,
|
||||
ID: daemon.ID,
|
||||
OIDCConfig: api.OIDCConfig,
|
||||
Database: api.Database,
|
||||
Pubsub: api.Pubsub,
|
||||
Provisioners: daemon.Provisioners,
|
||||
GitAuthConfigs: api.GitAuthConfigs,
|
||||
Telemetry: api.Telemetry,
|
||||
Tracer: tracer,
|
||||
Tags: tags,
|
||||
QuotaCommitter: &api.QuotaCommitter,
|
||||
Auditor: &api.Auditor,
|
||||
TemplateScheduleStore: api.TemplateScheduleStore,
|
||||
UserQuietHoursScheduleStore: api.UserQuietHoursScheduleStore,
|
||||
AcquireJobDebounce: debounce,
|
||||
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
|
||||
DeploymentValues: api.DeploymentValues,
|
||||
})
|
||||
srv, err := provisionerdserver.NewServer(
|
||||
api.AccessURL,
|
||||
daemon.ID,
|
||||
api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
|
||||
daemon.Provisioners,
|
||||
tags,
|
||||
api.Database,
|
||||
api.Pubsub,
|
||||
api.Telemetry,
|
||||
tracer,
|
||||
&api.QuotaCommitter,
|
||||
&api.Auditor,
|
||||
api.TemplateScheduleStore,
|
||||
api.UserQuietHoursScheduleStore,
|
||||
api.DeploymentValues,
|
||||
debounce,
|
||||
provisionerdserver.Options{
|
||||
OIDCConfig: api.OIDCConfig,
|
||||
GitAuthConfigs: api.GitAuthConfigs,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = proto.DRPCRegisterProvisionerDaemon(mux, srv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -48,7 +48,14 @@ var (
|
|||
lastAcquireMutex sync.RWMutex
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
type Options struct {
|
||||
OIDCConfig httpmw.OAuth2Config
|
||||
GitAuthConfigs []*gitauth.Config
|
||||
// TimeNowFn is only used in tests
|
||||
TimeNowFn func() time.Time
|
||||
}
|
||||
|
||||
type server struct {
|
||||
AccessURL *url.URL
|
||||
ID uuid.UUID
|
||||
Logger slog.Logger
|
||||
|
@ -71,17 +78,73 @@ type Server struct {
|
|||
TimeNowFn func() time.Time
|
||||
}
|
||||
|
||||
func NewServer(
|
||||
accessURL *url.URL,
|
||||
id uuid.UUID,
|
||||
logger slog.Logger,
|
||||
provisioners []database.ProvisionerType,
|
||||
tags json.RawMessage,
|
||||
db database.Store,
|
||||
ps pubsub.Pubsub,
|
||||
tel telemetry.Reporter,
|
||||
tracer trace.Tracer,
|
||||
quotaCommitter *atomic.Pointer[proto.QuotaCommitter],
|
||||
auditor *atomic.Pointer[audit.Auditor],
|
||||
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore],
|
||||
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore],
|
||||
deploymentValues *codersdk.DeploymentValues,
|
||||
acquireJobDebounce time.Duration,
|
||||
options Options,
|
||||
) (proto.DRPCProvisionerDaemonServer, error) {
|
||||
// Panic early if pointers are nil
|
||||
if quotaCommitter == nil {
|
||||
return nil, xerrors.New("quotaCommitter is nil")
|
||||
}
|
||||
if auditor == nil {
|
||||
return nil, xerrors.New("auditor is nil")
|
||||
}
|
||||
if templateScheduleStore == nil {
|
||||
return nil, xerrors.New("templateScheduleStore is nil")
|
||||
}
|
||||
if userQuietHoursScheduleStore == nil {
|
||||
return nil, xerrors.New("userQuietHoursScheduleStore is nil")
|
||||
}
|
||||
if deploymentValues == nil {
|
||||
return nil, xerrors.New("deploymentValues is nil")
|
||||
}
|
||||
return &server{
|
||||
AccessURL: accessURL,
|
||||
ID: id,
|
||||
Logger: logger,
|
||||
Provisioners: provisioners,
|
||||
GitAuthConfigs: options.GitAuthConfigs,
|
||||
Tags: tags,
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
Telemetry: tel,
|
||||
Tracer: tracer,
|
||||
QuotaCommitter: quotaCommitter,
|
||||
Auditor: auditor,
|
||||
TemplateScheduleStore: templateScheduleStore,
|
||||
UserQuietHoursScheduleStore: userQuietHoursScheduleStore,
|
||||
DeploymentValues: deploymentValues,
|
||||
AcquireJobDebounce: acquireJobDebounce,
|
||||
OIDCConfig: options.OIDCConfig,
|
||||
TimeNowFn: options.TimeNowFn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// timeNow should be used when trying to get the current time for math
|
||||
// calculations regarding workspace start and stop time.
|
||||
func (server *Server) timeNow() time.Time {
|
||||
if server.TimeNowFn != nil {
|
||||
return database.Time(server.TimeNowFn())
|
||||
func (s *server) timeNow() time.Time {
|
||||
if s.TimeNowFn != nil {
|
||||
return database.Time(s.TimeNowFn())
|
||||
}
|
||||
return database.Now()
|
||||
}
|
||||
|
||||
// AcquireJob queries the database to lock a job.
|
||||
func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
func (s *server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
//nolint:gocritic // Provisionerd has specific authz rules.
|
||||
ctx = dbauthz.AsProvisionerd(ctx)
|
||||
// This prevents loads of provisioner daemons from consistently
|
||||
|
@ -90,23 +153,24 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
// The debounce only occurs when no job is returned, so if loads of
|
||||
// jobs are added at once, they will start after at most this duration.
|
||||
lastAcquireMutex.RLock()
|
||||
if !lastAcquire.IsZero() && time.Since(lastAcquire) < server.AcquireJobDebounce {
|
||||
if !lastAcquire.IsZero() && time.Since(lastAcquire) < s.AcquireJobDebounce {
|
||||
s.Logger.Debug(ctx, "debounce acquire job", slog.F("debounce", s.AcquireJobDebounce), slog.F("last_acquire", lastAcquire))
|
||||
lastAcquireMutex.RUnlock()
|
||||
return &proto.AcquiredJob{}, nil
|
||||
}
|
||||
lastAcquireMutex.RUnlock()
|
||||
// This marks the job as locked in the database.
|
||||
job, err := server.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
job, err := s.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
StartedAt: sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: server.ID,
|
||||
UUID: s.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: server.Provisioners,
|
||||
Tags: server.Tags,
|
||||
Types: s.Provisioners,
|
||||
Tags: s.Tags,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The provisioner daemon assumes no jobs are available if
|
||||
|
@ -119,11 +183,11 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("acquire job: %w", err)
|
||||
}
|
||||
server.Logger.Debug(ctx, "locked job from database", slog.F("job_id", job.ID))
|
||||
s.Logger.Debug(ctx, "locked job from database", slog.F("job_id", job.ID))
|
||||
|
||||
// Marks the acquired job as failed with the error message provided.
|
||||
failJob := func(errorMessage string) error {
|
||||
err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: job.ID,
|
||||
CompletedAt: sql.NullTime{
|
||||
Time: database.Now(),
|
||||
|
@ -141,7 +205,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
return xerrors.Errorf("request job was invalidated: %s", errorMessage)
|
||||
}
|
||||
|
||||
user, err := server.Database.GetUserByID(ctx, job.InitiatorID)
|
||||
user, err := s.Database.GetUserByID(ctx, job.InitiatorID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get user: %s", err))
|
||||
}
|
||||
|
@ -169,38 +233,38 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err))
|
||||
}
|
||||
workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
|
||||
workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get workspace build: %s", err))
|
||||
}
|
||||
workspace, err := server.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
|
||||
workspace, err := s.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get workspace: %s", err))
|
||||
}
|
||||
templateVersion, err := server.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID)
|
||||
templateVersion, err := s.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get template version: %s", err))
|
||||
}
|
||||
templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, templateVersion.ID)
|
||||
templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, templateVersion.ID)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, failJob(fmt.Sprintf("get template version variables: %s", err))
|
||||
}
|
||||
template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
|
||||
template, err := s.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get template: %s", err))
|
||||
}
|
||||
owner, err := server.Database.GetUserByID(ctx, workspace.OwnerID)
|
||||
owner, err := s.Database.GetUserByID(ctx, workspace.OwnerID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get owner: %s", err))
|
||||
}
|
||||
err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{})
|
||||
err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspace.ID), []byte{})
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("publish workspace update: %s", err))
|
||||
}
|
||||
|
||||
var workspaceOwnerOIDCAccessToken string
|
||||
if server.OIDCConfig != nil {
|
||||
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, server.Database, server.OIDCConfig, owner.ID)
|
||||
if s.OIDCConfig != nil {
|
||||
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Database, s.OIDCConfig, owner.ID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
|
||||
}
|
||||
|
@ -209,12 +273,12 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
var sessionToken string
|
||||
switch workspaceBuild.Transition {
|
||||
case database.WorkspaceTransitionStart:
|
||||
sessionToken, err = server.regenerateSessionToken(ctx, owner, workspace)
|
||||
sessionToken, err = s.regenerateSessionToken(ctx, owner, workspace)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
|
||||
}
|
||||
case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete:
|
||||
err = deleteSessionToken(ctx, server.Database, workspace)
|
||||
err = deleteSessionToken(ctx, s.Database, workspace)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("delete session token: %s", err))
|
||||
}
|
||||
|
@ -225,14 +289,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
return nil, failJob(fmt.Sprintf("convert workspace transition: %s", err))
|
||||
}
|
||||
|
||||
workspaceBuildParameters, err := server.Database.GetWorkspaceBuildParameters(ctx, workspaceBuild.ID)
|
||||
workspaceBuildParameters, err := s.Database.GetWorkspaceBuildParameters(ctx, workspaceBuild.ID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get workspace build parameters: %s", err))
|
||||
}
|
||||
|
||||
gitAuthProviders := []*sdkproto.GitAuthProvider{}
|
||||
for _, p := range templateVersion.GitAuthProviders {
|
||||
link, err := server.Database.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
|
||||
link, err := s.Database.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
|
||||
ProviderID: p,
|
||||
UserID: owner.ID,
|
||||
})
|
||||
|
@ -243,7 +307,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
return nil, failJob(fmt.Sprintf("acquire git auth link: %s", err))
|
||||
}
|
||||
var config *gitauth.Config
|
||||
for _, c := range server.GitAuthConfigs {
|
||||
for _, c := range s.GitAuthConfigs {
|
||||
if c.ID != p {
|
||||
continue
|
||||
}
|
||||
|
@ -252,14 +316,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
}
|
||||
// We weren't able to find a matching config for the ID!
|
||||
if config == nil {
|
||||
server.Logger.Warn(ctx, "workspace build job is missing git provider",
|
||||
s.Logger.Warn(ctx, "workspace build job is missing git provider",
|
||||
slog.F("git_provider_id", p),
|
||||
slog.F("template_version_id", templateVersion.ID),
|
||||
slog.F("workspace_id", workspaceBuild.WorkspaceID))
|
||||
continue
|
||||
}
|
||||
|
||||
link, valid, err := config.RefreshToken(ctx, server.Database, link)
|
||||
link, valid, err := config.RefreshToken(ctx, s.Database, link)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("refresh git auth link %q: %s", p, err))
|
||||
}
|
||||
|
@ -281,7 +345,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
VariableValues: asVariableValues(templateVariables),
|
||||
GitAuthProviders: gitAuthProviders,
|
||||
Metadata: &sdkproto.Metadata{
|
||||
CoderUrl: server.AccessURL.String(),
|
||||
CoderUrl: s.AccessURL.String(),
|
||||
WorkspaceTransition: transition,
|
||||
WorkspaceName: workspace.Name,
|
||||
WorkspaceOwner: owner.Username,
|
||||
|
@ -303,11 +367,11 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err))
|
||||
}
|
||||
|
||||
templateVersion, err := server.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID)
|
||||
templateVersion, err := s.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get template version: %s", err))
|
||||
}
|
||||
templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, templateVersion.ID)
|
||||
templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, templateVersion.ID)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, failJob(fmt.Sprintf("get template version variables: %s", err))
|
||||
}
|
||||
|
@ -317,7 +381,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
RichParameterValues: convertRichParameterValues(input.RichParameterValues),
|
||||
VariableValues: asVariableValues(templateVariables),
|
||||
Metadata: &sdkproto.Metadata{
|
||||
CoderUrl: server.AccessURL.String(),
|
||||
CoderUrl: s.AccessURL.String(),
|
||||
WorkspaceName: input.WorkspaceName,
|
||||
},
|
||||
},
|
||||
|
@ -329,7 +393,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err))
|
||||
}
|
||||
|
||||
userVariableValues, err := server.includeLastVariableValues(ctx, input.TemplateVersionID, input.UserVariableValues)
|
||||
userVariableValues, err := s.includeLastVariableValues(ctx, input.TemplateVersionID, input.UserVariableValues)
|
||||
if err != nil {
|
||||
return nil, failJob(err.Error())
|
||||
}
|
||||
|
@ -338,14 +402,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
TemplateImport: &proto.AcquiredJob_TemplateImport{
|
||||
UserVariableValues: convertVariableValues(userVariableValues),
|
||||
Metadata: &sdkproto.Metadata{
|
||||
CoderUrl: server.AccessURL.String(),
|
||||
CoderUrl: s.AccessURL.String(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
switch job.StorageMethod {
|
||||
case database.ProvisionerStorageMethodFile:
|
||||
file, err := server.Database.GetFileByID(ctx, job.FileID)
|
||||
file, err := s.Database.GetFileByID(ctx, job.FileID)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("get file by hash: %s", err))
|
||||
}
|
||||
|
@ -360,7 +424,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
|||
return protoJob, err
|
||||
}
|
||||
|
||||
func (server *Server) includeLastVariableValues(ctx context.Context, templateVersionID uuid.UUID, userVariableValues []codersdk.VariableValue) ([]codersdk.VariableValue, error) {
|
||||
func (s *server) includeLastVariableValues(ctx context.Context, templateVersionID uuid.UUID, userVariableValues []codersdk.VariableValue) ([]codersdk.VariableValue, error) {
|
||||
var values []codersdk.VariableValue
|
||||
values = append(values, userVariableValues...)
|
||||
|
||||
|
@ -368,7 +432,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
|
|||
return values, nil
|
||||
}
|
||||
|
||||
templateVersion, err := server.Database.GetTemplateVersionByID(ctx, templateVersionID)
|
||||
templateVersion, err := s.Database.GetTemplateVersionByID(ctx, templateVersionID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get template version: %w", err)
|
||||
}
|
||||
|
@ -377,7 +441,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
|
|||
return values, nil
|
||||
}
|
||||
|
||||
template, err := server.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
|
||||
template, err := s.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get template: %w", err)
|
||||
}
|
||||
|
@ -386,7 +450,7 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
|
|||
return values, nil
|
||||
}
|
||||
|
||||
templateVariables, err := server.Database.GetTemplateVersionVariables(ctx, template.ActiveVersionID)
|
||||
templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, template.ActiveVersionID)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get template version variables: %w", err)
|
||||
}
|
||||
|
@ -412,8 +476,8 @@ func (server *Server) includeLastVariableValues(ctx context.Context, templateVer
|
|||
return values, nil
|
||||
}
|
||||
|
||||
func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
|
||||
ctx, span := server.startTrace(ctx, tracing.FuncName())
|
||||
func (s *server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
|
||||
ctx, span := s.startTrace(ctx, tracing.FuncName())
|
||||
defer span.End()
|
||||
|
||||
//nolint:gocritic // Provisionerd has specific authz rules.
|
||||
|
@ -423,7 +487,7 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot
|
|||
return nil, xerrors.Errorf("parse job id: %w", err)
|
||||
}
|
||||
|
||||
job, err := server.Database.GetProvisionerJobByID(ctx, jobID)
|
||||
job, err := s.Database.GetProvisionerJobByID(ctx, jobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get job: %w", err)
|
||||
}
|
||||
|
@ -431,11 +495,11 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot
|
|||
return nil, xerrors.New("job isn't running yet")
|
||||
}
|
||||
|
||||
if job.WorkerID.UUID.String() != server.ID.String() {
|
||||
if job.WorkerID.UUID.String() != s.ID.String() {
|
||||
return nil, xerrors.New("you don't own this job")
|
||||
}
|
||||
|
||||
q := server.QuotaCommitter.Load()
|
||||
q := s.QuotaCommitter.Load()
|
||||
if q == nil {
|
||||
// We're probably in community edition or a test.
|
||||
return &proto.CommitQuotaResponse{
|
||||
|
@ -446,8 +510,8 @@ func (server *Server) CommitQuota(ctx context.Context, request *proto.CommitQuot
|
|||
return (*q).CommitQuota(ctx, request)
|
||||
}
|
||||
|
||||
func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
|
||||
ctx, span := server.startTrace(ctx, tracing.FuncName())
|
||||
func (s *server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
|
||||
ctx, span := s.startTrace(ctx, tracing.FuncName())
|
||||
defer span.End()
|
||||
|
||||
//nolint:gocritic // Provisionerd has specific authz rules.
|
||||
|
@ -456,18 +520,18 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse job id: %w", err)
|
||||
}
|
||||
server.Logger.Debug(ctx, "stage UpdateJob starting", slog.F("job_id", parsedID))
|
||||
job, err := server.Database.GetProvisionerJobByID(ctx, parsedID)
|
||||
s.Logger.Debug(ctx, "stage UpdateJob starting", slog.F("job_id", parsedID))
|
||||
job, err := s.Database.GetProvisionerJobByID(ctx, parsedID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get job: %w", err)
|
||||
}
|
||||
if !job.WorkerID.Valid {
|
||||
return nil, xerrors.New("job isn't running yet")
|
||||
}
|
||||
if job.WorkerID.UUID.String() != server.ID.String() {
|
||||
if job.WorkerID.UUID.String() != s.ID.String() {
|
||||
return nil, xerrors.New("you don't own this job")
|
||||
}
|
||||
err = server.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
|
||||
err = s.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
|
||||
ID: parsedID,
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
|
@ -493,37 +557,37 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
|
|||
insertParams.Stage = append(insertParams.Stage, log.Stage)
|
||||
insertParams.Source = append(insertParams.Source, logSource)
|
||||
insertParams.Output = append(insertParams.Output, log.Output)
|
||||
server.Logger.Debug(ctx, "job log",
|
||||
s.Logger.Debug(ctx, "job log",
|
||||
slog.F("job_id", parsedID),
|
||||
slog.F("stage", log.Stage),
|
||||
slog.F("output", log.Output))
|
||||
}
|
||||
|
||||
logs, err := server.Database.InsertProvisionerJobLogs(ctx, insertParams)
|
||||
logs, err := s.Database.InsertProvisionerJobLogs(ctx, insertParams)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err))
|
||||
s.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("insert job logs: %w", err)
|
||||
}
|
||||
// Publish by the lowest log ID inserted so the log stream will fetch
|
||||
// everything from that point.
|
||||
lowestID := logs[0].ID
|
||||
server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
|
||||
s.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
|
||||
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
|
||||
CreatedAfter: lowestID - 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal: %w", err)
|
||||
}
|
||||
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data)
|
||||
err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err))
|
||||
s.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("publish job logs: %w", err)
|
||||
}
|
||||
server.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID))
|
||||
s.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID))
|
||||
}
|
||||
|
||||
if len(request.Readme) > 0 {
|
||||
err := server.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{
|
||||
err := s.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{
|
||||
JobID: job.ID,
|
||||
Readme: string(request.Readme),
|
||||
UpdatedAt: database.Now(),
|
||||
|
@ -534,16 +598,16 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
|
|||
}
|
||||
|
||||
if len(request.TemplateVariables) > 0 {
|
||||
templateVersion, err := server.Database.GetTemplateVersionByJobID(ctx, job.ID)
|
||||
templateVersion, err := s.Database.GetTemplateVersionByJobID(ctx, job.ID)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to get the template version", slog.F("job_id", parsedID), slog.Error(err))
|
||||
s.Logger.Error(ctx, "failed to get the template version", slog.F("job_id", parsedID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("get template version by job id: %w", err)
|
||||
}
|
||||
|
||||
var variableValues []*sdkproto.VariableValue
|
||||
var variablesWithMissingValues []string
|
||||
for _, templateVariable := range request.TemplateVariables {
|
||||
server.Logger.Debug(ctx, "insert template variable", slog.F("template_version_id", templateVersion.ID), slog.F("template_variable", redactTemplateVariable(templateVariable)))
|
||||
s.Logger.Debug(ctx, "insert template variable", slog.F("template_version_id", templateVersion.ID), slog.F("template_variable", redactTemplateVariable(templateVariable)))
|
||||
|
||||
value := templateVariable.DefaultValue
|
||||
for _, v := range request.UserVariableValues {
|
||||
|
@ -563,7 +627,7 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
|
|||
Sensitive: templateVariable.Sensitive,
|
||||
})
|
||||
|
||||
_, err = server.Database.InsertTemplateVersionVariable(ctx, database.InsertTemplateVersionVariableParams{
|
||||
_, err = s.Database.InsertTemplateVersionVariable(ctx, database.InsertTemplateVersionVariableParams{
|
||||
TemplateVersionID: templateVersion.ID,
|
||||
Name: templateVariable.Name,
|
||||
Description: templateVariable.Description,
|
||||
|
@ -593,8 +657,8 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) {
|
||||
ctx, span := server.startTrace(ctx, tracing.FuncName())
|
||||
func (s *server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) {
|
||||
ctx, span := s.startTrace(ctx, tracing.FuncName())
|
||||
defer span.End()
|
||||
|
||||
//nolint:gocritic // Provisionerd has specific authz rules.
|
||||
|
@ -603,12 +667,12 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse job id: %w", err)
|
||||
}
|
||||
server.Logger.Debug(ctx, "stage FailJob starting", slog.F("job_id", jobID))
|
||||
job, err := server.Database.GetProvisionerJobByID(ctx, jobID)
|
||||
s.Logger.Debug(ctx, "stage FailJob starting", slog.F("job_id", jobID))
|
||||
job, err := s.Database.GetProvisionerJobByID(ctx, jobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
if job.WorkerID.UUID.String() != server.ID.String() {
|
||||
if job.WorkerID.UUID.String() != s.ID.String() {
|
||||
return nil, xerrors.New("you don't own this job")
|
||||
}
|
||||
if job.CompletedAt.Valid {
|
||||
|
@ -627,7 +691,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
Valid: failJob.ErrorCode != "",
|
||||
}
|
||||
|
||||
err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: jobID,
|
||||
CompletedAt: job.CompletedAt,
|
||||
UpdatedAt: database.Now(),
|
||||
|
@ -637,7 +701,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("update provisioner job: %w", err)
|
||||
}
|
||||
server.Telemetry.Report(&telemetry.Snapshot{
|
||||
s.Telemetry.Report(&telemetry.Snapshot{
|
||||
ProvisionerJobs: []telemetry.ProvisionerJob{telemetry.ConvertProvisionerJob(job)},
|
||||
})
|
||||
|
||||
|
@ -650,7 +714,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
}
|
||||
|
||||
var build database.WorkspaceBuild
|
||||
err = server.Database.InTx(func(db database.Store) error {
|
||||
err = s.Database.InTx(func(db database.Store) error {
|
||||
build, err = db.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace build: %w", err)
|
||||
|
@ -675,7 +739,7 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), []byte{})
|
||||
err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(build.WorkspaceID), []byte{})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update workspace: %w", err)
|
||||
}
|
||||
|
@ -684,18 +748,18 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
|
||||
// if failed job is a workspace build, audit the outcome
|
||||
if job.Type == database.ProvisionerJobTypeWorkspaceBuild {
|
||||
auditor := server.Auditor.Load()
|
||||
build, err := server.Database.GetWorkspaceBuildByJobID(ctx, job.ID)
|
||||
auditor := s.Auditor.Load()
|
||||
build, err := s.Database.GetWorkspaceBuildByJobID(ctx, job.ID)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "audit log - get build", slog.Error(err))
|
||||
s.Logger.Error(ctx, "audit log - get build", slog.Error(err))
|
||||
} else {
|
||||
auditAction := auditActionFromTransition(build.Transition)
|
||||
workspace, err := server.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
|
||||
workspace, err := s.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "audit log - get workspace", slog.Error(err))
|
||||
s.Logger.Error(ctx, "audit log - get workspace", slog.Error(err))
|
||||
} else {
|
||||
previousBuildNumber := build.BuildNumber - 1
|
||||
previousBuild, prevBuildErr := server.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: previousBuildNumber,
|
||||
})
|
||||
|
@ -713,12 +777,12 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
|
||||
wriBytes, err := json.Marshal(buildResourceInfo)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "marshal workspace resource info for failed job", slog.Error(err))
|
||||
s.Logger.Error(ctx, "marshal workspace resource info for failed job", slog.Error(err))
|
||||
}
|
||||
|
||||
audit.BuildAudit(ctx, &audit.BuildAuditParams[database.WorkspaceBuild]{
|
||||
Audit: *auditor,
|
||||
Log: server.Logger,
|
||||
Log: s.Logger,
|
||||
UserID: job.InitiatorID,
|
||||
JobID: job.ID,
|
||||
Action: auditAction,
|
||||
|
@ -735,17 +799,17 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal job log: %w", err)
|
||||
}
|
||||
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
|
||||
s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("publish end of job logs: %w", err)
|
||||
}
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// CompleteJob is triggered by a provision daemon to mark a provisioner job as completed.
|
||||
func (server *Server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
|
||||
ctx, span := server.startTrace(ctx, tracing.FuncName())
|
||||
func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) {
|
||||
ctx, span := s.startTrace(ctx, tracing.FuncName())
|
||||
defer span.End()
|
||||
|
||||
//nolint:gocritic // Provisionerd has specific authz rules.
|
||||
|
@ -754,18 +818,18 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse job id: %w", err)
|
||||
}
|
||||
server.Logger.Debug(ctx, "stage CompleteJob starting", slog.F("job_id", jobID))
|
||||
job, err := server.Database.GetProvisionerJobByID(ctx, jobID)
|
||||
s.Logger.Debug(ctx, "stage CompleteJob starting", slog.F("job_id", jobID))
|
||||
job, err := s.Database.GetProvisionerJobByID(ctx, jobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get job by id: %w", err)
|
||||
}
|
||||
if job.WorkerID.UUID.String() != server.ID.String() {
|
||||
if job.WorkerID.UUID.String() != s.ID.String() {
|
||||
return nil, xerrors.Errorf("you don't own this job")
|
||||
}
|
||||
|
||||
telemetrySnapshot := &telemetry.Snapshot{}
|
||||
// Items are added to this snapshot as they complete!
|
||||
defer server.Telemetry.Report(telemetrySnapshot)
|
||||
defer s.Telemetry.Report(telemetrySnapshot)
|
||||
|
||||
switch jobType := completed.Type.(type) {
|
||||
case *proto.CompletedJob_TemplateImport_:
|
||||
|
@ -780,13 +844,13 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
database.WorkspaceTransitionStop: jobType.TemplateImport.StopResources,
|
||||
} {
|
||||
for _, resource := range resources {
|
||||
server.Logger.Info(ctx, "inserting template import job resource",
|
||||
s.Logger.Info(ctx, "inserting template import job resource",
|
||||
slog.F("job_id", job.ID.String()),
|
||||
slog.F("resource_name", resource.Name),
|
||||
slog.F("resource_type", resource.Type),
|
||||
slog.F("transition", transition))
|
||||
|
||||
err = InsertWorkspaceResource(ctx, server.Database, jobID, transition, resource, telemetrySnapshot)
|
||||
err = InsertWorkspaceResource(ctx, s.Database, jobID, transition, resource, telemetrySnapshot)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert resource: %w", err)
|
||||
}
|
||||
|
@ -794,7 +858,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
}
|
||||
|
||||
for _, richParameter := range jobType.TemplateImport.RichParameters {
|
||||
server.Logger.Info(ctx, "inserting template import job parameter",
|
||||
s.Logger.Info(ctx, "inserting template import job parameter",
|
||||
slog.F("job_id", job.ID.String()),
|
||||
slog.F("parameter_name", richParameter.Name),
|
||||
slog.F("type", richParameter.Type),
|
||||
|
@ -819,7 +883,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
}
|
||||
}
|
||||
|
||||
_, err = server.Database.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{
|
||||
_, err = s.Database.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{
|
||||
TemplateVersionID: input.TemplateVersionID,
|
||||
Name: richParameter.Name,
|
||||
DisplayName: richParameter.DisplayName,
|
||||
|
@ -847,7 +911,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
|
||||
for _, gitAuthProvider := range jobType.TemplateImport.GitAuthProviders {
|
||||
contains := false
|
||||
for _, configuredProvider := range server.GitAuthConfigs {
|
||||
for _, configuredProvider := range s.GitAuthConfigs {
|
||||
if configuredProvider.ID == gitAuthProvider {
|
||||
contains = true
|
||||
break
|
||||
|
@ -862,7 +926,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
}
|
||||
}
|
||||
|
||||
err = server.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{
|
||||
err = s.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{
|
||||
JobID: jobID,
|
||||
GitAuthProviders: jobType.TemplateImport.GitAuthProviders,
|
||||
UpdatedAt: database.Now(),
|
||||
|
@ -871,7 +935,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
return nil, xerrors.Errorf("update template version git auth providers: %w", err)
|
||||
}
|
||||
|
||||
err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: jobID,
|
||||
UpdatedAt: database.Now(),
|
||||
CompletedAt: sql.NullTime{
|
||||
|
@ -883,7 +947,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("update provisioner job: %w", err)
|
||||
}
|
||||
server.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID))
|
||||
s.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("complete job: %w", err)
|
||||
}
|
||||
|
@ -894,7 +958,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
return nil, xerrors.Errorf("unmarshal job data: %w", err)
|
||||
}
|
||||
|
||||
workspaceBuild, err := server.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
|
||||
workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace build: %w", err)
|
||||
}
|
||||
|
@ -902,14 +966,14 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
var workspace database.Workspace
|
||||
var getWorkspaceError error
|
||||
|
||||
err = server.Database.InTx(func(db database.Store) error {
|
||||
// It's important we use server.timeNow() here because we want to be
|
||||
err = s.Database.InTx(func(db database.Store) error {
|
||||
// It's important we use s.timeNow() here because we want to be
|
||||
// able to customize the current time from within tests.
|
||||
now := server.timeNow()
|
||||
now := s.timeNow()
|
||||
|
||||
workspace, getWorkspaceError = db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
|
||||
if getWorkspaceError != nil {
|
||||
server.Logger.Error(ctx,
|
||||
s.Logger.Error(ctx,
|
||||
"fetch workspace for build",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.F("workspace_id", workspaceBuild.WorkspaceID),
|
||||
|
@ -919,8 +983,8 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
|
||||
autoStop, err := schedule.CalculateAutostop(ctx, schedule.CalculateAutostopParams{
|
||||
Database: db,
|
||||
TemplateScheduleStore: *server.TemplateScheduleStore.Load(),
|
||||
UserQuietHoursScheduleStore: *server.UserQuietHoursScheduleStore.Load(),
|
||||
TemplateScheduleStore: *s.TemplateScheduleStore.Load(),
|
||||
UserQuietHoursScheduleStore: *s.UserQuietHoursScheduleStore.Load(),
|
||||
Now: now,
|
||||
Workspace: workspace,
|
||||
})
|
||||
|
@ -976,7 +1040,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
|
||||
var updates []<-chan time.Time
|
||||
for _, d := range timeouts {
|
||||
server.Logger.Debug(ctx, "triggering workspace notification after agent timeout",
|
||||
s.Logger.Debug(ctx, "triggering workspace notification after agent timeout",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.F("timeout", d),
|
||||
)
|
||||
|
@ -988,11 +1052,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
for _, wait := range updates {
|
||||
// Wait for the next potential timeout to occur. Note that we
|
||||
// can't listen on the context here because we will hang around
|
||||
// after this function has returned. The server also doesn't
|
||||
// after this function has returned. The s also doesn't
|
||||
// have a shutdown signal we can listen to.
|
||||
<-wait
|
||||
if err := server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil {
|
||||
server.Logger.Error(ctx, "workspace notification after agent timeout failed",
|
||||
if err := s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{}); err != nil {
|
||||
s.Logger.Error(ctx, "workspace notification after agent timeout failed",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
|
@ -1022,11 +1086,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
|
||||
// audit the outcome of the workspace build
|
||||
if getWorkspaceError == nil {
|
||||
auditor := server.Auditor.Load()
|
||||
auditor := s.Auditor.Load()
|
||||
auditAction := auditActionFromTransition(workspaceBuild.Transition)
|
||||
|
||||
previousBuildNumber := workspaceBuild.BuildNumber - 1
|
||||
previousBuild, prevBuildErr := server.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: previousBuildNumber,
|
||||
})
|
||||
|
@ -1044,12 +1108,12 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
|
||||
wriBytes, err := json.Marshal(buildResourceInfo)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err))
|
||||
s.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err))
|
||||
}
|
||||
|
||||
audit.BuildAudit(ctx, &audit.BuildAuditParams[database.WorkspaceBuild]{
|
||||
Audit: *auditor,
|
||||
Log: server.Logger,
|
||||
Log: s.Logger,
|
||||
UserID: job.InitiatorID,
|
||||
JobID: job.ID,
|
||||
Action: auditAction,
|
||||
|
@ -1060,24 +1124,24 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
})
|
||||
}
|
||||
|
||||
err = server.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{})
|
||||
err = s.Pubsub.Publish(codersdk.WorkspaceNotifyChannel(workspaceBuild.WorkspaceID), []byte{})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update workspace: %w", err)
|
||||
}
|
||||
case *proto.CompletedJob_TemplateDryRun_:
|
||||
for _, resource := range jobType.TemplateDryRun.Resources {
|
||||
server.Logger.Info(ctx, "inserting template dry-run job resource",
|
||||
s.Logger.Info(ctx, "inserting template dry-run job resource",
|
||||
slog.F("job_id", job.ID.String()),
|
||||
slog.F("resource_name", resource.Name),
|
||||
slog.F("resource_type", resource.Type))
|
||||
|
||||
err = InsertWorkspaceResource(ctx, server.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot)
|
||||
err = InsertWorkspaceResource(ctx, s.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert resource: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = server.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: jobID,
|
||||
UpdatedAt: database.Now(),
|
||||
CompletedAt: sql.NullTime{
|
||||
|
@ -1088,7 +1152,7 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("update provisioner job: %w", err)
|
||||
}
|
||||
server.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID))
|
||||
s.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("complete job: %w", err)
|
||||
}
|
||||
|
@ -1105,18 +1169,18 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal job log: %w", err)
|
||||
}
|
||||
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
|
||||
s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("publish end of job logs: %w", err)
|
||||
}
|
||||
|
||||
server.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID))
|
||||
s.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID))
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
func (server *Server) startTrace(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
return server.Tracer.Start(ctx, name, append(opts, trace.WithAttributes(
|
||||
func (s *server) startTrace(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
return s.Tracer.Start(ctx, name, append(opts, trace.WithAttributes(
|
||||
semconv.ServiceNameKey.String("coderd.provisionerd"),
|
||||
))...)
|
||||
}
|
||||
|
@ -1316,19 +1380,19 @@ func workspaceSessionTokenName(workspace database.Workspace) string {
|
|||
return fmt.Sprintf("%s_%s_session_token", workspace.OwnerID, workspace.ID)
|
||||
}
|
||||
|
||||
func (server *Server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) {
|
||||
func (s *server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) {
|
||||
newkey, sessionToken, err := apikey.Generate(apikey.CreateParams{
|
||||
UserID: user.ID,
|
||||
LoginType: user.LoginType,
|
||||
DeploymentValues: server.DeploymentValues,
|
||||
DeploymentValues: s.DeploymentValues,
|
||||
TokenName: workspaceSessionTokenName(workspace),
|
||||
LifetimeSeconds: int64(server.DeploymentValues.MaxTokenLifetime.Value().Seconds()),
|
||||
LifetimeSeconds: int64(s.DeploymentValues.MaxTokenLifetime.Value().Seconds()),
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate API key: %w", err)
|
||||
}
|
||||
|
||||
err = server.Database.InTx(func(tx database.Store) error {
|
||||
err = s.Database.InTx(func(tx database.Store) error {
|
||||
err := deleteSessionToken(ctx, tx, workspace)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete session token: %w", err)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -11,6 +11,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/yamux"
|
||||
|
@ -243,23 +244,33 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||
return
|
||||
}
|
||||
mux := drpcmux.New()
|
||||
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
|
||||
AccessURL: api.AccessURL,
|
||||
GitAuthConfigs: api.GitAuthConfigs,
|
||||
OIDCConfig: api.OIDCConfig,
|
||||
ID: daemon.ID,
|
||||
Database: api.Database,
|
||||
Pubsub: api.Pubsub,
|
||||
Provisioners: daemon.Provisioners,
|
||||
Telemetry: api.Telemetry,
|
||||
Auditor: &api.AGPL.Auditor,
|
||||
TemplateScheduleStore: api.AGPL.TemplateScheduleStore,
|
||||
UserQuietHoursScheduleStore: api.AGPL.UserQuietHoursScheduleStore,
|
||||
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
|
||||
Tags: rawTags,
|
||||
Tracer: trace.NewNoopTracerProvider().Tracer("noop"),
|
||||
DeploymentValues: api.DeploymentValues,
|
||||
})
|
||||
srv, err := provisionerdserver.NewServer(
|
||||
api.AccessURL,
|
||||
daemon.ID,
|
||||
api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
|
||||
daemon.Provisioners,
|
||||
rawTags,
|
||||
api.Database,
|
||||
api.Pubsub,
|
||||
api.Telemetry,
|
||||
trace.NewNoopTracerProvider().Tracer("noop"),
|
||||
&api.AGPL.QuotaCommitter,
|
||||
&api.AGPL.Auditor,
|
||||
api.AGPL.TemplateScheduleStore,
|
||||
api.AGPL.UserQuietHoursScheduleStore,
|
||||
api.DeploymentValues,
|
||||
// TODO(spikecurtis) - fix debounce to not cause flaky tests.
|
||||
time.Duration(0),
|
||||
provisionerdserver.Options{
|
||||
GitAuthConfigs: api.GitAuthConfigs,
|
||||
OIDCConfig: api.OIDCConfig,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("create provisioner daemon server: %s", err))
|
||||
return
|
||||
}
|
||||
err = proto.DRPCRegisterProvisionerDaemon(mux, srv)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))
|
||||
return
|
||||
|
|
|
@ -9,13 +9,20 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
"github.com/coder/coder/v2/provisionerd"
|
||||
provisionerdproto "github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
@ -212,6 +219,107 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
require.Len(t, daemons, 1)
|
||||
})
|
||||
|
||||
t.Run("PSK_daily_cost", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||
UserWorkspaceQuota: 10,
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||
codersdk.FeatureTemplateRBAC: 1,
|
||||
},
|
||||
},
|
||||
ProvisionerDaemonPSK: "provisionersftw",
|
||||
})
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
terraformClient, terraformServer := provisionersdk.MemTransportPipe()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = terraformClient.Close()
|
||||
_ = terraformServer.Close()
|
||||
}()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
err := echo.Serve(ctx, &provisionersdk.ServeOptions{
|
||||
Listener: terraformServer,
|
||||
Logger: logger.Named("echo"),
|
||||
WorkDirectory: tempDir,
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
provisioners := provisionerd.Provisioners{
|
||||
string(database.ProvisionerTypeEcho): proto.NewDRPCProvisionerClient(terraformClient),
|
||||
}
|
||||
another := codersdk.New(client.URL)
|
||||
pd := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
|
||||
return another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{
|
||||
Organization: user.OrganizationID,
|
||||
Provisioners: []codersdk.ProvisionerType{
|
||||
codersdk.ProvisionerTypeEcho,
|
||||
},
|
||||
Tags: map[string]string{
|
||||
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
|
||||
},
|
||||
PreSharedKey: "provisionersftw",
|
||||
})
|
||||
}, &provisionerd.Options{
|
||||
Logger: logger.Named("provisionerd"),
|
||||
Provisioners: provisioners,
|
||||
})
|
||||
defer pd.Close()
|
||||
|
||||
// Patch the 'Everyone' group to give the user quota to build their workspace.
|
||||
_, err := client.PatchGroup(ctx, user.OrganizationID, codersdk.PatchGroupRequest{
|
||||
QuotaAllowance: ptr.Ref(1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionApply: []*proto.Response{{
|
||||
Type: &proto.Response_Apply{
|
||||
Apply: &proto.ApplyComplete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "example",
|
||||
Type: "aws_instance",
|
||||
DailyCost: 1,
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "example",
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
|
||||
|
||||
err = pd.Shutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
err = terraformServer.Close()
|
||||
require.NoError(t, err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timeout waiting for server to shut down")
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BadPSK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||
|
|
|
@ -308,6 +308,7 @@ func (p *Server) acquireJob(ctx context.Context) {
|
|||
lastAcquireMutex.RLock()
|
||||
if !lastAcquire.IsZero() && time.Since(lastAcquire) < p.opts.JobPollDebounce {
|
||||
lastAcquireMutex.RUnlock()
|
||||
p.opts.Logger.Debug(ctx, "debounce acquire job")
|
||||
return
|
||||
}
|
||||
lastAcquireMutex.RUnlock()
|
||||
|
@ -319,6 +320,7 @@ func (p *Server) acquireJob(ctx context.Context) {
|
|||
}
|
||||
|
||||
job, err := client.AcquireJob(ctx, &proto.Empty{})
|
||||
p.opts.Logger.Debug(ctx, "called AcquireJob on client", slog.F("job_id", job.GetJobId()), slog.Error(err))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, yamux.ErrSessionShutdown) ||
|
||||
|
|
Loading…
Reference in New Issue