feat: add provisioner job hang detector (#7927)

Dean Sheather 2023-06-25 23:17:00 +10:00 committed by GitHub
parent 3671846b1b
commit 98a5ae7f48
28 changed files with 1414 additions and 54 deletions


@ -78,6 +78,7 @@ import (
"github.com/coder/coder/coderd/schedule"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/unhanger"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/coderd/util/slice"
"github.com/coder/coder/coderd/workspaceapps"
@ -898,11 +899,17 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("notify systemd: %w", err)
}
autobuildPoller := time.NewTicker(cfg.AutobuildPollInterval.Value())
defer autobuildPoller.Stop()
autobuildExecutor := autobuild.NewExecutor(ctx, options.Database, coderAPI.TemplateScheduleStore, logger, autobuildPoller.C)
autobuildTicker := time.NewTicker(cfg.AutobuildPollInterval.Value())
defer autobuildTicker.Stop()
autobuildExecutor := autobuild.NewExecutor(ctx, options.Database, coderAPI.TemplateScheduleStore, logger, autobuildTicker.C)
autobuildExecutor.Run()
hangDetectorTicker := time.NewTicker(cfg.JobHangDetectorInterval.Value())
defer hangDetectorTicker.Stop()
hangDetector := unhanger.New(ctx, options.Database, options.Pubsub, logger, hangDetectorTicker.C)
hangDetector.Start()
defer hangDetector.Close()
// Currently there is no way to ask the server to shut
// itself down, so any exit signal will result in a non-zero
// exit of the server.


@ -148,6 +148,9 @@ networking:
# Interval to poll for scheduled workspace builds.
# (default: 1m0s, type: duration)
autobuildPollInterval: 1m0s
# Interval to poll for hung jobs and automatically terminate them.
# (default: 1m0s, type: duration)
jobHangDetectorInterval: 1m0s
introspection:
prometheus:
# Serve prometheus metrics on the address defined by prometheus address.

coderd/apidoc/docs.go (generated)

@ -7239,6 +7239,9 @@ const docTemplate = `{
"in_memory_database": {
"type": "boolean"
},
"job_hang_detector_interval": {
"type": "integer"
},
"logging": {
"$ref": "#/definitions/codersdk.LoggingConfig"
},


@ -6468,6 +6468,9 @@
"in_memory_database": {
"type": "boolean"
},
"job_hang_detector_interval": {
"type": "integer"
},
"logging": {
"$ref": "#/definitions/codersdk.LoggingConfig"
},


@ -68,6 +68,7 @@ import (
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/schedule"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/unhanger"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/coderd/workspaceapps"
@ -256,6 +257,12 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
).WithStatsChannel(options.AutobuildStats)
lifecycleExecutor.Run()
hangDetectorTicker := time.NewTicker(options.DeploymentValues.JobHangDetectorInterval.Value())
defer hangDetectorTicker.Stop()
hangDetector := unhanger.New(ctx, options.Database, options.Pubsub, slogtest.Make(t, nil).Named("unhanger.detector"), hangDetectorTicker.C)
hangDetector.Start()
t.Cleanup(hangDetector.Close)
var mutex sync.RWMutex
var handler http.Handler
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {


@ -18,13 +18,6 @@ import (
"golang.org/x/xerrors"
)
// Well-known lock IDs for lock functions in the database. These should not
// change. If locks are deprecated, they should be kept to avoid reusing the
// same ID.
const (
LockIDDeploymentSetup = iota + 1
)
// Store contains all queryable database functions.
// It extends the generated interface to add transaction support.
type Store interface {


@ -3,7 +3,6 @@ package db2sdk
import (
"encoding/json"
"time"
"github.com/google/uuid"
@ -81,6 +80,9 @@ func TemplateVersionParameter(param database.TemplateVersionParameter) (codersdk
}
func ProvisionerJobStatus(provisionerJob database.ProvisionerJob) codersdk.ProvisionerJobStatus {
// The case where jobs are hung is handled by the unhang package. We can't
// just return Failed here when it's hung because that doesn't reflect in
// the database.
switch {
case provisionerJob.CanceledAt.Valid:
if !provisionerJob.CompletedAt.Valid {
@ -97,8 +99,6 @@ func ProvisionerJobStatus(provisionerJob database.ProvisionerJob) codersdk.Provi
return codersdk.ProvisionerJobSucceeded
}
return codersdk.ProvisionerJobFailed
case database.Now().Sub(provisionerJob.UpdatedAt) > 30*time.Second:
return codersdk.ProvisionerJobFailed
default:
return codersdk.ProvisionerJobRunning
}


@ -96,17 +96,6 @@ func TestProvisionerJobStatus(t *testing.T) {
},
status: codersdk.ProvisionerJobFailed,
},
{
name: "not_updated",
job: database.ProvisionerJob{
StartedAt: sql.NullTime{
Time: database.Now().Add(-time.Minute),
Valid: true,
},
UpdatedAt: database.Now().Add(-31 * time.Second),
},
status: codersdk.ProvisionerJobFailed,
},
{
name: "updated",
job: database.ProvisionerJob{


@ -176,6 +176,25 @@ var (
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
// See unhanger package.
subjectHangDetector = rbac.Subject{
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Name: "hangdetector",
DisplayName: "Hang Detector Daemon",
Site: rbac.Permissions(map[string][]rbac.Action{
rbac.ResourceSystem.Type: {rbac.WildcardSymbol},
rbac.ResourceTemplate.Type: {rbac.ActionRead},
rbac.ResourceWorkspace.Type: {rbac.ActionRead, rbac.ActionUpdate},
}),
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectSystemRestricted = rbac.Subject{
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@ -217,6 +236,12 @@ func AsAutostart(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectAutostart)
}
// AsHangDetector returns a context with an actor that has permissions required
// for unhanger.Detector to function.
func AsHangDetector(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectHangDetector)
}
// AsSystemRestricted returns a context with an actor that has permissions
// required for various system operations (login, logout, metrics cache).
func AsSystemRestricted(ctx context.Context) context.Context {
@ -950,6 +975,14 @@ func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID
return fetchWithPostFilter(q.auth, q.db.GetGroupsByOrganizationID)(ctx, organizationID)
}
// TODO: We need to create a ProvisionerJob resource type
func (q *querier) GetHungProvisionerJobs(ctx context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) {
// if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
// return nil, err
// }
return q.db.GetHungProvisionerJobs(ctx, hungSince)
}
func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return "", err


@ -1753,6 +1753,19 @@ func (q *fakeQuerier) GetGroupsByOrganizationID(_ context.Context, organizationI
return groups, nil
}
func (q *fakeQuerier) GetHungProvisionerJobs(_ context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
hungJobs := []database.ProvisionerJob{}
for _, provisionerJob := range q.provisionerJobs {
if provisionerJob.StartedAt.Valid && !provisionerJob.CompletedAt.Valid && provisionerJob.UpdatedAt.Before(hungSince) {
hungJobs = append(hungJobs, provisionerJob)
}
}
return hungJobs, nil
}
func (q *fakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -2135,7 +2148,7 @@ func (q *fakeQuerier) GetProvisionerLogsAfterID(_ context.Context, arg database.
if jobLog.JobID != arg.JobID {
continue
}
if arg.CreatedAfter != 0 && jobLog.ID < arg.CreatedAfter {
if jobLog.ID <= arg.CreatedAfter {
continue
}
logs = append(logs, jobLog)


@ -399,6 +399,13 @@ func (m metricsStore) GetGroupsByOrganizationID(ctx context.Context, organizatio
return groups, err
}
func (m metricsStore) GetHungProvisionerJobs(ctx context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) {
start := time.Now()
jobs, err := m.s.GetHungProvisionerJobs(ctx, hungSince)
m.queryLatencies.WithLabelValues("GetHungProvisionerJobs").Observe(time.Since(start).Seconds())
return jobs, err
}
func (m metricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) {
start := time.Now()
version, err := m.s.GetLastUpdateCheck(ctx)


@ -701,6 +701,21 @@ func (mr *MockStoreMockRecorder) GetGroupsByOrganizationID(arg0, arg1 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsByOrganizationID", reflect.TypeOf((*MockStore)(nil).GetGroupsByOrganizationID), arg0, arg1)
}
// GetHungProvisionerJobs mocks base method.
func (m *MockStore) GetHungProvisionerJobs(arg0 context.Context, arg1 time.Time) ([]database.ProvisionerJob, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetHungProvisionerJobs", arg0, arg1)
ret0, _ := ret[0].([]database.ProvisionerJob)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetHungProvisionerJobs indicates an expected call of GetHungProvisionerJobs.
func (mr *MockStoreMockRecorder) GetHungProvisionerJobs(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHungProvisionerJobs", reflect.TypeOf((*MockStore)(nil).GetHungProvisionerJobs), arg0, arg1)
}
// GetLastUpdateCheck mocks base method.
func (m *MockStore) GetLastUpdateCheck(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()

coderd/database/lock.go (new file)

@ -0,0 +1,19 @@
package database
import "hash/fnv"
// Well-known lock IDs for lock functions in the database. These should not
// change. If locks are deprecated, they should be kept in this list to avoid
// reusing the same ID.
const (
// Keep the unused iota here so we don't need + 1 every time
lockIDUnused = iota
LockIDDeploymentSetup
)
// GenLockID generates a unique and consistent lock ID from a given string.
func GenLockID(name string) int64 {
hash := fnv.New64()
_, _ = hash.Write([]byte(name))
return int64(hash.Sum64())
}
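
As a quick illustration (not part of the commit): the unhanger below derives one advisory-lock key per job by hashing the string "hang-detector:<job id>" with this helper, so only one coderd replica terminates a given hung job at a time. A minimal, self-contained sketch; the job UUID is a placeholder.

package main

import (
	"fmt"
	"hash/fnv"
)

// genLockID mirrors database.GenLockID above: an FNV-64 hash of the name,
// cast to int64 so it fits a Postgres advisory lock key.
func genLockID(name string) int64 {
	hash := fnv.New64()
	_, _ = hash.Write([]byte(name))
	return int64(hash.Sum64())
}

func main() {
	// Placeholder UUID; the detector uses the real provisioner job ID.
	jobID := "00000000-0000-0000-0000-000000000000"
	fmt.Println(genLockID(fmt.Sprintf("hang-detector:%s", jobID)))
}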


@ -16,8 +16,6 @@ type sqlcQuerier interface {
//
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
//
// Use database.LockID() to generate a unique lock ID from a string.
AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error
// Acquires the lock for a single job that isn't started, completed,
// canceled, and that matches an array of provisioner types.
@ -75,6 +73,7 @@ type sqlcQuerier interface {
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error)
GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error)
GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error)
GetLastUpdateCheck(ctx context.Context) (string, error)
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
@ -217,8 +216,6 @@ type sqlcQuerier interface {
//
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
//
// Use database.LockID() to generate a unique lock ID from a string.
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLinkParams) (GitAuthLink, error)


@ -1527,8 +1527,6 @@ SELECT pg_advisory_xact_lock($1)
//
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
//
// Use database.LockID() to generate a unique lock ID from a string.
func (q *sqlQuerier) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
_, err := q.db.ExecContext(ctx, acquireLock, pgAdvisoryXactLock)
return err
@ -1542,8 +1540,6 @@ SELECT pg_try_advisory_xact_lock($1)
//
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
//
// Use database.LockID() to generate a unique lock ID from a string.
func (q *sqlQuerier) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
row := q.db.QueryRowContext(ctx, tryAcquireLock, pgTryAdvisoryXactLock)
var pg_try_advisory_xact_lock bool
@ -2201,6 +2197,59 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi
return i, err
}
const getHungProvisionerJobs = `-- name: GetHungProvisionerJobs :many
SELECT
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags, error_code, trace_metadata
FROM
provisioner_jobs
WHERE
updated_at < $1
AND started_at IS NOT NULL
AND completed_at IS NULL
`
func (q *sqlQuerier) GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error) {
rows, err := q.db.QueryContext(ctx, getHungProvisionerJobs, updatedAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ProvisionerJob
for rows.Next() {
var i ProvisionerJob
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.StartedAt,
&i.CanceledAt,
&i.CompletedAt,
&i.Error,
&i.OrganizationID,
&i.InitiatorID,
&i.Provisioner,
&i.StorageMethod,
&i.Type,
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
&i.ErrorCode,
&i.TraceMetadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getProvisionerJobByID = `-- name: GetProvisionerJobByID :one
SELECT
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags, error_code, trace_metadata


@ -3,8 +3,6 @@
--
-- This must be called from within a transaction. The lock will be automatically
-- released when the transaction ends.
--
-- Use database.LockID() to generate a unique lock ID from a string.
SELECT pg_advisory_xact_lock($1);
-- name: TryAcquireLock :one
@ -12,6 +10,4 @@ SELECT pg_advisory_xact_lock($1);
--
-- This must be called from within a transaction. The lock will be automatically
-- released when the transaction ends.
--
-- Use database.LockID() to generate a unique lock ID from a string.
SELECT pg_try_advisory_xact_lock($1);


@ -128,3 +128,13 @@ SET
error_code = $5
WHERE
id = $1;
-- name: GetHungProvisionerJobs :many
SELECT
*
FROM
provisioner_jobs
WHERE
updated_at < $1
AND started_at IS NOT NULL
AND completed_at IS NULL;
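
In Go terms, the predicate above (and the in-memory fake earlier in this diff) boils down to: started, not completed, and not updated since the cutoff. A minimal sketch with made-up timestamps:

package main

import (
	"database/sql"
	"fmt"
	"time"
)

// isHung mirrors the WHERE clause of GetHungProvisionerJobs.
func isHung(startedAt, completedAt sql.NullTime, updatedAt, cutoff time.Time) bool {
	return startedAt.Valid && !completedAt.Valid && updatedAt.Before(cutoff)
}

func main() {
	now := time.Now()
	cutoff := now.Add(-5 * time.Minute) // unhanger.HungJobDuration in this commit
	fmt.Println(isHung(
		sql.NullTime{Time: now.Add(-10 * time.Minute), Valid: true}, // started 10m ago
		sql.NullTime{},          // never completed
		now.Add(-6*time.Minute), // last update 6m ago
		cutoff,
	)) // prints: true
}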

coderd/unhanger/detector.go (new file)

@ -0,0 +1,363 @@
package unhanger
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math/rand" //#nosec // this is only used for shuffling an array to pick random jobs to unhang
"time"
"golang.org/x/xerrors"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/db2sdk"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
)
const (
// HungJobDuration is the duration of time since the last update to a job
// before it is considered hung.
HungJobDuration = 5 * time.Minute
// HungJobExitTimeout is the duration of time that provisioners should allow
// for a graceful exit upon cancellation due to failing to send an update to
// a job.
//
// Provisioners should avoid keeping a job "running" for longer than this
// time after failing to send an update to the job.
HungJobExitTimeout = 3 * time.Minute
// MaxJobsPerRun is the maximum number of hung jobs that the detector will
// terminate in a single run.
MaxJobsPerRun = 10
)
// HungJobLogMessages are written to provisioner job logs when a job is hung and
// terminated.
var HungJobLogMessages = []string{
"",
"====================",
"Coder: Build has been detected as hung for 5 minutes and will be terminated.",
"====================",
"",
}
// acquireLockError is returned when the detector fails to acquire a lock and
// cancels the current run.
type acquireLockError struct{}
// Error implements error.
func (acquireLockError) Error() string {
return "lock is held by another client"
}
// jobInelligibleError is returned when a job is not eligible to be terminated
// anymore.
type jobInelligibleError struct {
Err error
}
// Error implements error.
func (e jobInelligibleError) Error() string {
return fmt.Sprintf("job is no longer eligible to be terminated: %s", e.Err)
}
// Detector automatically detects hung provisioner jobs, sends messages into the
// build log and terminates them as failed.
type Detector struct {
ctx context.Context
cancel context.CancelFunc
done chan struct{}
db database.Store
pubsub pubsub.Pubsub
log slog.Logger
tick <-chan time.Time
stats chan<- Stats
}
// Stats contains statistics about the last run of the detector.
type Stats struct {
// TerminatedJobIDs contains the IDs of all jobs that were detected as hung and
// terminated.
TerminatedJobIDs []uuid.UUID
// Error is the fatal error that occurred during the last run of the
// detector, if any. Error may be set to AcquireLockError if the detector
// failed to acquire a lock.
Error error
}
// New returns a new hang detector.
func New(ctx context.Context, db database.Store, pub pubsub.Pubsub, log slog.Logger, tick <-chan time.Time) *Detector {
//nolint:gocritic // Hang detector has a limited set of permissions.
ctx, cancel := context.WithCancel(dbauthz.AsHangDetector(ctx))
d := &Detector{
ctx: ctx,
cancel: cancel,
done: make(chan struct{}),
db: db,
pubsub: pub,
log: log,
tick: tick,
stats: nil,
}
return d
}
// WithStatsChannel will cause the Detector to push Stats to ch after
// every tick. This push is blocking, so if ch is not read, the detector will
// hang. This should only be used in tests.
func (d *Detector) WithStatsChannel(ch chan<- Stats) *Detector {
d.stats = ch
return d
}
// Start will cause the detector to detect and unhang provisioner jobs on every
// tick from its channel. It will stop when its context is Done, or when its
// channel is closed.
//
// Start should only be called once.
func (d *Detector) Start() {
go func() {
defer close(d.done)
defer d.cancel()
for {
select {
case <-d.ctx.Done():
return
case t, ok := <-d.tick:
if !ok {
return
}
stats := d.run(t)
if stats.Error != nil && !xerrors.As(stats.Error, &acquireLockError{}) {
d.log.Warn(d.ctx, "error running workspace build hang detector once", slog.Error(stats.Error))
}
if len(stats.TerminatedJobIDs) != 0 {
d.log.Warn(d.ctx, "detected (and terminated) hung provisioner jobs", slog.F("job_ids", stats.TerminatedJobIDs))
}
if d.stats != nil {
select {
case <-d.ctx.Done():
return
case d.stats <- stats:
}
}
}
}
}()
}
// Wait will block until the detector is stopped.
func (d *Detector) Wait() {
<-d.done
}
// Close will stop the detector.
func (d *Detector) Close() {
d.cancel()
<-d.done
}
func (d *Detector) run(t time.Time) Stats {
ctx, cancel := context.WithTimeout(d.ctx, 5*time.Minute)
defer cancel()
stats := Stats{
TerminatedJobIDs: []uuid.UUID{},
Error: nil,
}
// Find all provisioner jobs that are currently running but have not
// received an update in the last 5 minutes.
jobs, err := d.db.GetHungProvisionerJobs(ctx, t.Add(-HungJobDuration))
if err != nil {
stats.Error = xerrors.Errorf("get hung provisioner jobs: %w", err)
return stats
}
// Limit the number of jobs we'll unhang in a single run to avoid
// timing out.
if len(jobs) > MaxJobsPerRun {
// Pick a random subset of the jobs to unhang.
rand.Shuffle(len(jobs), func(i, j int) {
jobs[i], jobs[j] = jobs[j], jobs[i]
})
jobs = jobs[:MaxJobsPerRun]
}
// Send a message into the build log for each hung job saying that it
// has been detected and will be terminated, then mark the job as
// failed.
for _, job := range jobs {
log := d.log.With(slog.F("job_id", job.ID))
err := unhangJob(ctx, log, d.db, d.pubsub, job.ID)
if err != nil && !(xerrors.As(err, &acquireLockError{}) || xerrors.As(err, &jobInelligibleError{})) {
log.Error(ctx, "error forcefully terminating hung provisioner job", slog.Error(err))
continue
}
stats.TerminatedJobIDs = append(stats.TerminatedJobIDs, job.ID)
}
return stats
}
func unhangJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub.Pubsub, jobID uuid.UUID) error {
var lowestLogID int64
err := db.InTx(func(db database.Store) error {
locked, err := db.TryAcquireLock(ctx, database.GenLockID(fmt.Sprintf("hang-detector:%s", jobID)))
if err != nil {
return xerrors.Errorf("acquire lock: %w", err)
}
if !locked {
// This error is ignored.
return acquireLockError{}
}
// Refetch the job while we hold the lock.
job, err := db.GetProvisionerJobByID(ctx, jobID)
if err != nil {
return xerrors.Errorf("get provisioner job: %w", err)
}
// Check if we should still unhang it.
jobStatus := db2sdk.ProvisionerJobStatus(job)
if jobStatus != codersdk.ProvisionerJobRunning {
return jobInelligibleError{
Err: xerrors.Errorf("job is not running (status %s)", jobStatus),
}
}
if job.UpdatedAt.After(time.Now().Add(-HungJobDuration)) {
return jobInelligibleError{
Err: xerrors.New("job has been updated recently"),
}
}
log.Info(ctx, "detected hung (>5m) provisioner job, forcefully terminating")
// First, get the latest logs from the build so we can make sure
// our messages are in the latest stage.
logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
JobID: job.ID,
CreatedAfter: 0,
})
if err != nil {
return xerrors.Errorf("get logs for hung job: %w", err)
}
logStage := ""
if len(logs) != 0 {
logStage = logs[len(logs)-1].Stage
}
if logStage == "" {
logStage = "Unknown"
}
// Insert the messages into the build log.
insertParams := database.InsertProvisionerJobLogsParams{
JobID: job.ID,
}
now := database.Now()
for i, msg := range HungJobLogMessages {
// Set the created at in a way that ensures each message has
// a unique timestamp so they will be sorted correctly.
insertParams.CreatedAt = append(insertParams.CreatedAt, now.Add(time.Millisecond*time.Duration(i)))
insertParams.Level = append(insertParams.Level, database.LogLevelError)
insertParams.Stage = append(insertParams.Stage, logStage)
insertParams.Source = append(insertParams.Source, database.LogSourceProvisionerDaemon)
insertParams.Output = append(insertParams.Output, msg)
}
newLogs, err := db.InsertProvisionerJobLogs(ctx, insertParams)
if err != nil {
return xerrors.Errorf("insert logs for hung job: %w", err)
}
lowestLogID = newLogs[0].ID
// Mark the job as failed.
now = database.Now()
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: job.ID,
UpdatedAt: now,
CompletedAt: sql.NullTime{
Time: now,
Valid: true,
},
Error: sql.NullString{
String: "Coder: Build has been detected as hung for 5 minutes and has been terminated by hang detector.",
Valid: true,
},
ErrorCode: sql.NullString{
Valid: false,
},
})
if err != nil {
return xerrors.Errorf("mark job as failed: %w", err)
}
// If the provisioner job is a workspace build, copy the
// provisioner state from the previous build to this workspace
// build.
if job.Type == database.ProvisionerJobTypeWorkspaceBuild {
build, err := db.GetWorkspaceBuildByJobID(ctx, job.ID)
if err != nil {
return xerrors.Errorf("get workspace build for workspace build job by job id: %w", err)
}
// Only copy the provisioner state if there's no state in
// the current build.
if len(build.ProvisionerState) == 0 {
// Get the previous build if it exists.
prevBuild, err := db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
WorkspaceID: build.WorkspaceID,
BuildNumber: build.BuildNumber - 1,
})
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get previous workspace build: %w", err)
}
if err == nil {
_, err = db.UpdateWorkspaceBuildByID(ctx, database.UpdateWorkspaceBuildByIDParams{
ID: build.ID,
UpdatedAt: database.Now(),
ProvisionerState: prevBuild.ProvisionerState,
Deadline: time.Time{},
MaxDeadline: time.Time{},
})
if err != nil {
return xerrors.Errorf("update workspace build by id: %w", err)
}
}
}
}
return nil
}, nil)
if err != nil {
return xerrors.Errorf("in tx: %w", err)
}
// Publish the new log notification to pubsub. Use the lowest log ID
// inserted so the log stream will fetch everything after that point.
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
CreatedAfter: lowestLogID - 1,
EndOfLogs: true,
})
if err != nil {
return xerrors.Errorf("marshal log notification: %w", err)
}
err = pub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
if err != nil {
return xerrors.Errorf("publish log notification: %w", err)
}
return nil
}
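
For orientation, a hedged sketch of the detector's lifecycle as wired up in cli/server.go and coderdtest above. runHangDetector is a hypothetical helper, and db, ps and logger are assumed to be constructed elsewhere; it is not a drop-in for the server code.

package example

import (
	"context"
	"time"

	"cdr.dev/slog"

	"github.com/coder/coder/coderd/database"
	"github.com/coder/coder/coderd/database/pubsub"
	"github.com/coder/coder/coderd/unhanger"
)

// runHangDetector (hypothetical) mirrors the cli/server.go wiring: tick on an
// interval, start the detector, and hand back a stop function for shutdown.
func runHangDetector(ctx context.Context, db database.Store, ps pubsub.Pubsub, logger slog.Logger, interval time.Duration) (*unhanger.Detector, func()) {
	ticker := time.NewTicker(interval)
	d := unhanger.New(ctx, db, ps, logger, ticker.C)
	d.Start()
	return d, func() {
		ticker.Stop()
		// Close cancels the run context and waits for the goroutine to exit.
		d.Close()
	}
}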


@ -0,0 +1,724 @@
package unhanger_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/unhanger"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDetectorNoJobs(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- time.Now()
stats := <-statsCh
require.NoError(t, stats.Error)
require.Empty(t, stats.TerminatedJobIDs)
detector.Close()
detector.Wait()
}
func TestDetectorNoHungJobs(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
// Insert some jobs that are running and haven't been updated in a while,
// but not enough to be considered hung.
now := time.Now()
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{})
file := dbgen.File(t, db, database.File{})
for i := 0; i < 5; i++ {
dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: now.Add(-time.Minute * 5),
UpdatedAt: now.Add(-time.Minute * time.Duration(i)),
StartedAt: sql.NullTime{
Time: now.Add(-time.Minute * 5),
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
}
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Empty(t, stats.TerminatedJobIDs)
detector.Close()
detector.Wait()
}
func TestDetectorHungWorkspaceBuild(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
twentyMinAgo = now.Add(-time.Minute * 20)
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
template = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
CreatedBy: user.ID,
})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: template.ID,
})
// Previous build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: twentyMinAgo,
UpdatedAt: twentyMinAgo,
StartedAt: sql.NullTime{
Time: twentyMinAgo,
Valid: true,
},
CompletedAt: sql.NullTime{
Time: twentyMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 1,
ProvisionerState: expectedWorkspaceBuildState,
JobID: previousWorkspaceBuildJob.ID,
})
// Current build.
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 2,
JobID: currentWorkspaceBuildJob.ID,
// No provisioner state.
})
)
t.Log("previous job ID: ", previousWorkspaceBuildJob.ID)
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0])
// Check that the current provisioner job was updated.
job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
// Check that the provisioner state was copied.
build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
}
func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
twentyMinAgo = now.Add(-time.Minute * 20)
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
template = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
CreatedBy: user.ID,
})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: template.ID,
})
// Previous build.
previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: twentyMinAgo,
UpdatedAt: twentyMinAgo,
StartedAt: sql.NullTime{
Time: twentyMinAgo,
Valid: true,
},
CompletedAt: sql.NullTime{
Time: twentyMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 1,
ProvisionerState: []byte(`{"dean":"NOT cool","colin":"also NOT cool"}`),
JobID: previousWorkspaceBuildJob.ID,
})
// Current build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 2,
JobID: currentWorkspaceBuildJob.ID,
// Should not be overridden.
ProvisionerState: expectedWorkspaceBuildState,
})
)
t.Log("previous job ID: ", previousWorkspaceBuildJob.ID)
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0])
// Check that the current provisioner job was updated.
job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
// Check that the provisioner state was NOT copied.
build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
}
func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
template = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
CreatedBy: user.ID,
})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: template.ID,
})
// First build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 1,
JobID: currentWorkspaceBuildJob.ID,
// Should not be overridden.
ProvisionerState: expectedWorkspaceBuildState,
})
)
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0])
// Check that the current provisioner job was updated.
job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
// Check that the provisioner state was NOT updated.
build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
}
func TestDetectorHungOtherJobTypes(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
// Template import job.
templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
// Template dry-run job.
templateDryRunJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: []byte("{}"),
})
)
t.Log("template import job ID: ", templateImportJob.ID)
t.Log("template dry-run job ID: ", templateDryRunJob.ID)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 2)
require.Contains(t, stats.TerminatedJobIDs, templateImportJob.ID)
require.Contains(t, stats.TerminatedJobIDs, templateDryRunJob.ID)
// Check that the template import job was updated.
job, err := db.GetProvisionerJobByID(ctx, templateImportJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
// Check that the template dry-run job was updated.
job, err = db.GetProvisionerJobByID(ctx, templateDryRunJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
detector.Close()
detector.Wait()
}
func TestDetectorPushesLogs(t *testing.T) {
t.Parallel()
cases := []struct {
name string
preLogCount int
preLogStage string
expectStage string
}{
{
name: "WithExistingLogs",
preLogCount: 10,
preLogStage: "Stage Name",
expectStage: "Stage Name",
},
{
name: "WithExistingLogsNoStage",
preLogCount: 10,
preLogStage: "",
expectStage: "Unknown",
},
{
name: "WithoutExistingLogs",
preLogCount: 0,
expectStage: "Unknown",
},
}
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
// Template import job.
templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
)
t.Log("template import job ID: ", templateImportJob.ID)
// Insert some logs at the start of the job.
if c.preLogCount > 0 {
insertParams := database.InsertProvisionerJobLogsParams{
JobID: templateImportJob.ID,
}
for i := 0; i < c.preLogCount; i++ {
insertParams.CreatedAt = append(insertParams.CreatedAt, tenMinAgo.Add(time.Millisecond*time.Duration(i)))
insertParams.Level = append(insertParams.Level, database.LogLevelInfo)
insertParams.Stage = append(insertParams.Stage, c.preLogStage)
insertParams.Source = append(insertParams.Source, database.LogSourceProvisioner)
insertParams.Output = append(insertParams.Output, fmt.Sprintf("Output %d", i))
}
logs, err := db.InsertProvisionerJobLogs(ctx, insertParams)
require.NoError(t, err)
require.Len(t, logs, 10)
}
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
// Create pubsub subscription to listen for new log events.
pubsubCalled := make(chan int64, 1)
pubsubCancel, err := pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(templateImportJob.ID), func(ctx context.Context, message []byte) {
defer close(pubsubCalled)
var event provisionersdk.ProvisionerJobLogsNotifyMessage
err := json.Unmarshal(message, &event)
if !assert.NoError(t, err) {
return
}
assert.True(t, event.EndOfLogs)
pubsubCalled <- event.CreatedAfter
})
require.NoError(t, err)
defer pubsubCancel()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
require.Contains(t, stats.TerminatedJobIDs, templateImportJob.ID)
after := <-pubsubCalled
// Get the jobs after the given time and check that they are what we
// expect.
logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
JobID: templateImportJob.ID,
CreatedAfter: after,
})
require.NoError(t, err)
require.Len(t, logs, len(unhanger.HungJobLogMessages))
for i, log := range logs {
assert.Equal(t, database.LogLevelError, log.Level)
assert.Equal(t, c.expectStage, log.Stage)
assert.Equal(t, database.LogSourceProvisionerDaemon, log.Source)
assert.Equal(t, unhanger.HungJobLogMessages[i], log.Output)
}
// Double check the full log count.
logs, err = db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
JobID: templateImportJob.ID,
CreatedAfter: 0,
})
require.NoError(t, err)
require.Len(t, logs, c.preLogCount+len(unhanger.HungJobLogMessages))
detector.Close()
detector.Wait()
})
}
}
func TestDetectorMaxJobsPerRun(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
)
// Create unhanger.MaxJobsPerRun + 1 hung jobs.
now := time.Now()
for i := 0; i < unhanger.MaxJobsPerRun+1; i++ {
dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: now.Add(-time.Hour),
UpdatedAt: now.Add(-time.Hour),
StartedAt: sql.NullTime{
Time: now.Add(-time.Hour),
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
}
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
// Make sure that only unhanger.MaxJobsPerRun jobs are terminated.
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, unhanger.MaxJobsPerRun)
// Run the detector again and make sure that only the remaining job is
// terminated.
tickCh <- now
stats = <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
detector.Close()
detector.Wait()
}


@ -124,6 +124,7 @@ type DeploymentValues struct {
// HTTPAddress is a string because it may be set to zero to disable.
HTTPAddress clibase.String `json:"http_address,omitempty" typescript:",notnull"`
AutobuildPollInterval clibase.Duration `json:"autobuild_poll_interval,omitempty"`
JobHangDetectorInterval clibase.Duration `json:"job_hang_detector_interval,omitempty"`
DERP DERP `json:"derp,omitempty" typescript:",notnull"`
Prometheus PrometheusConfig `json:"prometheus,omitempty" typescript:",notnull"`
Pprof PprofConfig `json:"pprof,omitempty" typescript:",notnull"`
@ -539,6 +540,16 @@ when required by your organization's security policy.`,
Value: &c.AutobuildPollInterval,
YAML: "autobuildPollInterval",
},
{
Name: "Job Hang Detector Interval",
Description: "Interval to poll for hung jobs and automatically terminate them.",
Flag: "job-hang-detector-interval",
Env: "CODER_JOB_HANG_DETECTOR_INTERVAL",
Hidden: true,
Default: time.Minute.String(),
Value: &c.JobHangDetectorInterval,
YAML: "jobHangDetectorInterval",
},
httpAddress,
tlsBindAddress,
{


@ -214,6 +214,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
},
"http_address": "string",
"in_memory_database": true,
"job_hang_detector_interval": 0,
"logging": {
"human": "string",
"json": "string",


@ -1891,6 +1891,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
},
"http_address": "string",
"in_memory_database": true,
"job_hang_detector_interval": 0,
"logging": {
"human": "string",
"json": "string",
@ -2221,6 +2222,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
},
"http_address": "string",
"in_memory_database": true,
"job_hang_detector_interval": 0,
"logging": {
"human": "string",
"json": "string",
@ -2400,6 +2402,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
| `git_auth` | [clibase.Struct-array_codersdk_GitAuthConfig](#clibasestruct-array_codersdk_gitauthconfig) | false | | |
| `http_address` | string | false | | Http address is a string because it may be set to zero to disable. |
| `in_memory_database` | boolean | false | | |
| `job_hang_detector_interval` | integer | false | | |
| `logging` | [codersdk.LoggingConfig](#codersdkloggingconfig) | false | | |
| `max_session_expiry` | integer | false | | |
| `max_token_lifetime` | integer | false | | |


@ -49,7 +49,7 @@ func (s *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Create a separate context for forcefull cancellation not tied to
// Create a separate context for forceful cancellation not tied to
// the stream so that we can control when to terminate the process.
killCtx, kill := context.WithCancel(context.Background())
defer kill()
@ -57,13 +57,15 @@ func (s *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
// Ensure processes are eventually cleaned up on graceful
// cancellation or disconnect.
go func() {
<-stream.Context().Done()
<-ctx.Done()
// TODO(mafredri): We should track this provision request as
// part of graceful server shutdown procedure. Waiting on a
// process here should delay provisioner/coder shutdown.
t := time.NewTimer(s.exitTimeout)
defer t.Stop()
select {
case <-time.After(s.exitTimeout):
case <-t.C:
kill()
case <-killCtx.Done():
}


@ -129,8 +129,7 @@ func TestProvision_Cancel(t *testing.T) {
require.NoError(t, err)
ctx, api := setupProvisioner(t, &provisionerServeOptions{
binaryPath: binPath,
exitTimeout: time.Nanosecond,
binaryPath: binPath,
})
response, err := api.Provision(ctx)
@ -186,6 +185,75 @@ func TestProvision_Cancel(t *testing.T) {
}
}
func TestProvision_CancelTimeout(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("This test uses interrupts and is not supported on Windows")
}
cwd, err := os.Getwd()
require.NoError(t, err)
fakeBin := filepath.Join(cwd, "testdata", "fake_cancel_hang.sh")
dir := t.TempDir()
binPath := filepath.Join(dir, "terraform")
// Example: exec /path/to/terraform_fake_cancel.sh 1.2.1 apply "$@"
content := fmt.Sprintf("#!/bin/sh\nexec %q %s \"$@\"\n", fakeBin, terraform.TerraformVersion.String())
err = os.WriteFile(binPath, []byte(content), 0o755) //#nosec
require.NoError(t, err)
ctx, api := setupProvisioner(t, &provisionerServeOptions{
binaryPath: binPath,
exitTimeout: time.Second,
})
response, err := api.Provision(ctx)
require.NoError(t, err)
err = response.Send(&proto.Provision_Request{
Type: &proto.Provision_Request_Apply{
Apply: &proto.Provision_Apply{
Config: &proto.Provision_Config{
Directory: dir,
Metadata: &proto.Provision_Metadata{},
},
},
},
})
require.NoError(t, err)
for _, line := range []string{"init", "apply_start"} {
LoopStart:
msg, err := response.Recv()
require.NoError(t, err)
t.Log(msg.Type)
log := msg.GetLog()
if log == nil {
goto LoopStart
}
require.Equal(t, line, log.Output)
}
err = response.Send(&proto.Provision_Request{
Type: &proto.Provision_Request_Cancel{
Cancel: &proto.Provision_Cancel{},
},
})
require.NoError(t, err)
for {
msg, err := response.Recv()
require.NoError(t, err)
if c := msg.GetComplete(); c != nil {
require.Contains(t, c.Error, "killed")
break
}
}
}
func TestProvision(t *testing.T) {
t.Parallel()


@ -12,13 +12,10 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/unhanger"
"github.com/coder/coder/provisionersdk"
)
const (
defaultExitTimeout = 5 * time.Minute
)
type ServeOptions struct {
*provisionersdk.ServeOptions
@ -31,14 +28,15 @@ type ServeOptions struct {
Tracer trace.Tracer
// ExitTimeout defines how long we will wait for a running Terraform
// command to exit (cleanly) if the provision was stopped. This only
// happens when the command is still running after the provision
// stream is closed. If the provision is canceled via RPC, this
// timeout will not be used.
// command to exit (cleanly) if the provision was stopped. This
// happens when the provision is canceled via RPC and when the command is
// still running after the provision stream is closed.
//
// This is a no-op on Windows where the process can't be interrupted.
//
// Default value: 5 minutes.
// Default value: 3 minutes (unhanger.HungJobExitTimeout). This value should
// be kept less than the value that Coder uses to mark hung jobs as failed,
// which is 5 minutes (see unhanger package).
ExitTimeout time.Duration
}
@ -96,7 +94,7 @@ func Serve(ctx context.Context, options *ServeOptions) error {
options.Tracer = trace.NewNoopTracerProvider().Tracer("noop")
}
if options.ExitTimeout == 0 {
options.ExitTimeout = defaultExitTimeout
options.ExitTimeout = unhanger.HungJobExitTimeout
}
return provisionersdk.Serve(ctx, &server{
execMut: &sync.Mutex{},
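
A small, hypothetical usage note: embedders of the terraform provisioner can still override the new default, but per the comment above the value should stay below the 5-minute hang cutoff enforced by the unhanger. Other ServeOptions fields are omitted for brevity.

package example

import (
	"github.com/coder/coder/coderd/unhanger"
	"github.com/coder/coder/provisioner/terraform"
)

// opts is illustrative only: half of the 3-minute default, still well under
// the 5-minute threshold after which the hang detector fails the job.
var opts = &terraform.ServeOptions{
	ExitTimeout: unhanger.HungJobExitTimeout / 2,
}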


@ -0,0 +1,41 @@
#!/bin/sh
VERSION=$1
shift 1
json_print() {
echo "{\"@level\":\"error\",\"@message\":\"$*\"}"
}
case "$1" in
version)
cat <<-EOF
{
"terraform_version": "${VERSION}",
"platform": "linux_amd64",
"provider_selections": {},
"terraform_outdated": false
}
EOF
exit 0
;;
init)
echo "init"
exit 0
;;
apply)
trap 'json_print interrupt' INT
json_print apply_start
sleep 10 2>/dev/null >/dev/null
json_print apply_end
exit 0
;;
plan)
echo "plan not supported"
exit 1
;;
esac
exit 0


@ -337,6 +337,9 @@ func (r *Runner) sendHeartbeat(ctx context.Context) (*proto.UpdateJobResponse, e
}
func (r *Runner) update(ctx context.Context, u *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
ctx, span := r.startTrace(ctx, tracing.FuncName())
defer span.End()
defer func() {
@ -537,6 +540,7 @@ func (r *Runner) heartbeatRoutine(ctx context.Context) {
resp, err := r.sendHeartbeat(ctx)
if err != nil {
// Calling Fail starts cancellation so the process will exit.
err = r.Fail(ctx, r.failedJobf("send periodic update: %s", err))
if err != nil {
r.logger.Error(ctx, "failed to call FailJob", slog.Error(err))
@ -547,9 +551,9 @@ func (r *Runner) heartbeatRoutine(ctx context.Context) {
ticker.Reset(r.updateInterval)
continue
}
r.logger.Info(ctx, "attempting graceful cancelation")
r.logger.Info(ctx, "attempting graceful cancellation")
r.Cancel()
// Hard-cancel the job after a minute of pending cancelation.
// Mark the job as failed after a minute of pending cancellation.
timer := time.NewTimer(r.forceCancelInterval)
select {
case <-timer.C:


@ -327,6 +327,7 @@ export interface DeploymentValues {
readonly redirect_to_access_url?: boolean
readonly http_address?: string
readonly autobuild_poll_interval?: number
readonly job_hang_detector_interval?: number
readonly derp?: DERP
readonly prometheus?: PrometheusConfig
readonly pprof?: PprofConfig