diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go
index f3a13e70bb..1896de7ef8 100644
--- a/coderd/provisionerdserver/provisionerdserver.go
+++ b/coderd/provisionerdserver/provisionerdserver.go
@@ -537,13 +537,13 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
 		// everything from that point.
 		lowestID := logs[0].ID
 		server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
-		data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{
+		data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
 			CreatedAfter: lowestID - 1,
 		})
 		if err != nil {
 			return nil, xerrors.Errorf("marshal: %w", err)
 		}
-		err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(parsedID), data)
+		err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data)
 		if err != nil {
 			server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err))
 			return nil, xerrors.Errorf("publish job log: %w", err)
@@ -846,11 +846,11 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
 		}
 	}
 
-	data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
+	data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
 	if err != nil {
 		return nil, xerrors.Errorf("marshal job log: %w", err)
 	}
-	err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(jobID), data)
+	err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
 	if err != nil {
 		server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
 		return nil, xerrors.Errorf("publish end of job logs: %w", err)
@@ -1236,11 +1236,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
 			reflect.TypeOf(completed.Type).String())
 	}
 
-	data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
+	data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
 	if err != nil {
 		return nil, xerrors.Errorf("marshal job log: %w", err)
 	}
-	err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(jobID), data)
+	err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
 	if err != nil {
 		server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
 		return nil, xerrors.Errorf("publish end of job logs: %w", err)
@@ -1704,19 +1704,6 @@ type TemplateVersionDryRunJob struct {
 	RichParameterValues []database.WorkspaceBuildParameter `json:"rich_parameter_values"`
 }
 
-// ProvisionerJobLogsNotifyMessage is the payload published on
-// the provisioner job logs notify channel.
-type ProvisionerJobLogsNotifyMessage struct {
-	CreatedAfter int64 `json:"created_after"`
-	EndOfLogs    bool  `json:"end_of_logs,omitempty"`
-}
-
-// ProvisionerJobLogsNotifyChannel is the PostgreSQL NOTIFY channel
-// to publish updates to job logs on.
-func ProvisionerJobLogsNotifyChannel(jobID uuid.UUID) string {
-	return fmt.Sprintf("provisioner-log-logs:%s", jobID)
-}
-
 func asVariableValues(templateVariables []database.TemplateVersionVariable) []*sdkproto.VariableValue {
 	var apiVariableValues []*sdkproto.VariableValue
 	for _, v := range templateVariables {
diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go
index c848913fb6..385ad03c2e 100644
--- a/coderd/provisionerdserver/provisionerdserver_test.go
+++ b/coderd/provisionerdserver/provisionerdserver_test.go
@@ -27,6 +27,7 @@ import (
 	"github.com/coder/coder/coderd/telemetry"
 	"github.com/coder/coder/codersdk"
 	"github.com/coder/coder/provisionerd/proto"
+	"github.com/coder/coder/provisionersdk"
 	sdkproto "github.com/coder/coder/provisionersdk/proto"
 	"github.com/coder/coder/testutil"
 )
@@ -528,7 +529,7 @@ func TestUpdateJob(t *testing.T) {
 
 		published := make(chan struct{})
 
-		closeListener, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) {
+		closeListener, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) {
 			close(published)
 		})
 		require.NoError(t, err)
@@ -776,7 +777,7 @@ func TestFailJob(t *testing.T) {
 		require.NoError(t, err)
 		defer closeWorkspaceSubscribe()
 		publishedLogs := make(chan struct{})
-		closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
+		closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
 			close(publishedLogs)
 		})
 		require.NoError(t, err)
@@ -1082,7 +1083,7 @@ func TestCompleteJob(t *testing.T) {
 		require.NoError(t, err)
 		defer closeWorkspaceSubscribe()
 		publishedLogs := make(chan struct{})
-		closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
+		closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
 			close(publishedLogs)
 		})
 		require.NoError(t, err)
diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go
index f92686c0db..ba6e8cb5ce 100644
--- a/coderd/provisionerjobs.go
+++ b/coderd/provisionerjobs.go
@@ -5,22 +5,23 @@ import (
 	"database/sql"
 	"encoding/json"
 	"errors"
-	"fmt"
+	"io"
 	"net/http"
 	"sort"
 	"strconv"
 
 	"github.com/google/uuid"
-	"go.uber.org/atomic"
+	"golang.org/x/xerrors"
 	"nhooyr.io/websocket"
 
 	"cdr.dev/slog"
+
 	"github.com/coder/coder/coderd/database"
 	"github.com/coder/coder/coderd/database/db2sdk"
 	"github.com/coder/coder/coderd/database/dbauthz"
 	"github.com/coder/coder/coderd/httpapi"
-	"github.com/coder/coder/coderd/rbac"
 	"github.com/coder/coder/codersdk"
+	"github.com/coder/coder/provisionersdk"
 )
 
 // Returns provisioner logs based on query parameters.
@@ -32,7 +33,6 @@ import (
 func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
 	var (
 		ctx      = r.Context()
-		actor, _ = dbauthz.ActorFromContext(ctx)
 		logger   = api.Logger.With(slog.F("job_id", job.ID))
 		follow   = r.URL.Query().Has("follow")
 		afterRaw = r.URL.Query().Get("after")
@@ -55,129 +55,16 @@
 	}
 
 	if !follow {
-		logs, err := api.Database.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
-			JobID:        job.ID,
-			CreatedAfter: after,
-		})
-		if errors.Is(err, sql.ErrNoRows) {
-			err = nil
-		}
-		if err != nil {
-			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
-				Message: "Internal error fetching provisioner logs.",
-				Detail:  err.Error(),
-			})
-			return
-		}
-		if logs == nil {
-			logs = []database.ProvisionerJobLog{}
-		}
-
-		logger.Debug(ctx, "Finished non-follow job logs")
-		httpapi.Write(ctx, rw, http.StatusOK, convertProvisionerJobLogs(logs))
+		fetchAndWriteLogs(ctx, logger, api.Database, job.ID, after, rw)
 		return
 	}
 
-	// if we are following logs, start the subscription before we query the database, so that we don't miss any logs
-	// between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track
-	// of processed IDs.
-	var bufferedLogs <-chan *database.ProvisionerJobLog
-	if follow {
-		bl, closeFollow, err := api.followProvisionerJobLogs(actor, job.ID)
-		if err != nil {
-			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
-				Message: "Internal error watching provisioner logs.",
-				Detail:  err.Error(),
-			})
-			return
-		}
-		defer closeFollow()
-		bufferedLogs = bl
-	}
-
-	logs, err := api.Database.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
-		JobID:        job.ID,
-		CreatedAfter: after,
-	})
-	if errors.Is(err, sql.ErrNoRows) {
-		err = nil
-	}
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
-			Message: "Internal error fetching provisioner logs.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	if logs == nil {
-		logs = []database.ProvisionerJobLog{}
-	}
-
+	follower := newLogFollower(ctx, logger, api.Database, api.Pubsub, rw, r, job, after)
 	api.WebsocketWaitMutex.Lock()
 	api.WebsocketWaitGroup.Add(1)
 	api.WebsocketWaitMutex.Unlock()
 	defer api.WebsocketWaitGroup.Done()
-	conn, err := websocket.Accept(rw, r, nil)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Failed to accept websocket.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	go httpapi.Heartbeat(ctx, conn)
-
-	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
-	defer wsNetConn.Close() // Also closes conn.
-
-	logIdsDone := make(map[int64]bool)
-
-	// The Go stdlib JSON encoder appends a newline character after message write.
-	encoder := json.NewEncoder(wsNetConn)
-	for _, provisionerJobLog := range logs {
-		logIdsDone[provisionerJobLog.ID] = true
-		err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
-		if err != nil {
-			return
-		}
-	}
-	job, err = api.Database.GetProvisionerJobByID(ctx, job.ID)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
-			Message: "Internal error fetching provisioner job.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-	if job.CompletedAt.Valid {
-		// job was complete before we queried the database for historical logs
-		return
-	}
-
-	for {
-		select {
-		case <-ctx.Done():
-			logger.Debug(context.Background(), "job logs context canceled")
-			return
-		case log, ok := <-bufferedLogs:
-			// A nil log is sent when complete!
-			if !ok || log == nil {
-				logger.Debug(context.Background(), "reached the end of published logs")
-				return
-			}
-			if logIdsDone[log.ID] {
-				logger.Debug(ctx, "subscribe duplicated log",
-					slog.F("stage", log.Stage))
-			} else {
-				logger.Debug(ctx, "subscribe encoding log",
-					slog.F("stage", log.Stage))
-				err = encoder.Encode(convertProvisionerJobLog(*log))
-				if err != nil {
-					return
-				}
-			}
-		}
-	}
+	follower.follow()
 }
 
 func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
@@ -334,98 +221,225 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
 	return job
 }
 
-func provisionerJobLogsChannel(jobID uuid.UUID) string {
-	return fmt.Sprintf("provisioner-log-logs:%s", jobID)
+func fetchAndWriteLogs(ctx context.Context, logger slog.Logger, db database.Store, jobID uuid.UUID, after int64, rw http.ResponseWriter) {
+	logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
+		JobID:        jobID,
+		CreatedAfter: after,
+	})
+	if err != nil && !errors.Is(err, sql.ErrNoRows) {
+		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
+			Message: "Internal error fetching provisioner logs.",
+			Detail:  err.Error(),
+		})
+		return
+	}
+	if logs == nil {
+		logs = []database.ProvisionerJobLog{}
+	}
+
+	logger.Debug(ctx, "Finished non-follow job logs")
+	httpapi.Write(ctx, rw, http.StatusOK, convertProvisionerJobLogs(logs))
 }
 
-// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel
-type provisionerJobLogsMessage struct {
-	CreatedAfter int64 `json:"created_after"`
-	EndOfLogs    bool  `json:"end_of_logs,omitempty"`
+func jobIsComplete(logger slog.Logger, job database.ProvisionerJob) bool {
+	status := db2sdk.ProvisionerJobStatus(job)
+	switch status {
+	case codersdk.ProvisionerJobCanceled:
+		return true
+	case codersdk.ProvisionerJobFailed:
+		return true
+	case codersdk.ProvisionerJobSucceeded:
+		return true
+	case codersdk.ProvisionerJobPending:
+		return false
+	case codersdk.ProvisionerJobCanceling:
+		return false
+	case codersdk.ProvisionerJobRunning:
+		return false
+	default:
+		logger.Error(context.Background(),
+			"unknown status",
+			slog.F("job_id", job.ID), slog.F("status", status))
+		return false
+	}
 }
 
-func (api *API) followProvisionerJobLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan *database.ProvisionerJobLog, func(), error) {
-	logger := api.Logger.With(slog.F("job_id", jobID))
+type logFollower struct {
+	ctx    context.Context
+	logger slog.Logger
+	db     database.Store
+	pubsub database.Pubsub
+	r      *http.Request
+	rw     http.ResponseWriter
+	conn   *websocket.Conn
 
-	var (
-		// With debug logging enabled length = 128 is insufficient
-		bufferedLogs  = make(chan *database.ProvisionerJobLog, 1024)
-		endOfLogs     atomic.Bool
-		lastSentLogID atomic.Int64
-	)
+	jobID         uuid.UUID
+	after         int64
+	complete      bool
+	notifications chan provisionersdk.ProvisionerJobLogsNotifyMessage
+	errors        chan error
+}
 
-	sendLog := func(log *database.ProvisionerJobLog) {
+func newLogFollower(
+	ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub,
+	rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob, after int64,
+) *logFollower {
+	return &logFollower{
+		ctx:           ctx,
+		logger:        logger,
+		db:            db,
+		pubsub:        pubsub,
+		r:             r,
+		rw:            rw,
+		jobID:         job.ID,
+		after:         after,
+		complete:      jobIsComplete(logger, job),
+		notifications: make(chan provisionersdk.ProvisionerJobLogsNotifyMessage),
+		errors:        make(chan error),
+	}
+}
+
+func (f *logFollower) follow() {
+	// note that we only need to subscribe to updates if the job is not yet
+	// complete.
+	if !f.complete {
+		subCancel, err := f.pubsub.SubscribeWithErr(
+			provisionersdk.ProvisionerJobLogsNotifyChannel(f.jobID),
+			f.listener,
+		)
+		if err != nil {
+			httpapi.Write(f.ctx, f.rw, http.StatusInternalServerError, codersdk.Response{
+				Message: "failed to subscribe to job updates",
+				Detail:  err.Error(),
+			})
+			return
+		}
+		defer subCancel()
+
+		// we were provided `complete` prior to starting this subscription, so
+		// we also need to check whether the job is now complete, in case the
+		// job completed between the last time we queried the job and the start
+		// of the subscription. If the job completes after this, we will get
+		// a notification on the subscription.
+		job, err := f.db.GetProvisionerJobByID(f.ctx, f.jobID)
+		if err != nil {
+			httpapi.Write(f.ctx, f.rw, http.StatusInternalServerError, codersdk.Response{
+				Message: "failed to query job",
+				Detail:  err.Error(),
+			})
+			return
+		}
+		f.complete = jobIsComplete(f.logger, job)
+		f.logger.Debug(f.ctx, "queried job after subscribe", slog.F("complete", f.complete))
+	}
+
+	var err error
+	f.conn, err = websocket.Accept(f.rw, f.r, nil)
+	if err != nil {
+		httpapi.Write(f.ctx, f.rw, http.StatusBadRequest, codersdk.Response{
+			Message: "Failed to accept websocket.",
+			Detail:  err.Error(),
+		})
+		return
+	}
+	defer f.conn.Close(websocket.StatusNormalClosure, "done")
+	go httpapi.Heartbeat(f.ctx, f.conn)
+
+	// query for logs once right away, so we can get historical data from before
+	// subscription
+	if err := f.query(); err != nil {
+		if f.ctx.Err() == nil && !xerrors.Is(err, io.EOF) {
+			// neither context expiry, nor EOF, close and log
+			f.logger.Error(f.ctx, "failed to query logs", slog.Error(err))
+			err = f.conn.Close(websocket.StatusInternalError, err.Error())
+			if err != nil {
+				f.logger.Warn(f.ctx, "failed to close websocket", slog.Error(err))
+			}
+		}
+		return
+	}
+
+	// no need to wait if the job is done
+	if f.complete {
+		return
+	}
+
+	for {
 		select {
-		case bufferedLogs <- log:
-			logger.Debug(context.Background(), "subscribe buffered log", slog.F("stage", log.Stage))
-			lastSentLogID.Store(log.ID)
-		default:
-			// If this overflows users could miss logs streaming. This can happen
-			// we get a lot of logs and consumer isn't keeping up. We don't want to block the pubsub,
-			// so just drop them.
-			logger.Warn(context.Background(), "provisioner job log overflowing channel")
+		case err := <-f.errors:
+			// we've dropped at least one notification. This can happen if we
+			// lose database connectivity. We don't know whether the job is
+			// now complete since we could have missed the end of logs message.
+			// We could soldier on and retry, but loss of database connectivity
+			// is fairly serious, so instead just 500 and bail out. Client
+			// can retry and hopefully find a healthier node.
+			f.logger.Error(f.ctx, "dropped or corrupted notification", slog.Error(err))
+			err = f.conn.Close(websocket.StatusInternalError, err.Error())
+			if err != nil {
+				f.logger.Warn(f.ctx, "failed to close websocket", slog.Error(err))
+			}
+			return
+		case <-f.ctx.Done():
+			// client disconnect
+			return
+		case n := <-f.notifications:
+			if n.EndOfLogs {
+				// safe to return here because we started the subscription,
+				// and then queried at least once, so we will have already
+				// gotten all logs prior to the start of our subscription.
+				return
+			}
+			err = f.query()
+			if err != nil {
+				if f.ctx.Err() == nil && !xerrors.Is(err, io.EOF) {
+					// neither context expiry, nor EOF, close and log
+					f.logger.Error(f.ctx, "failed to query logs", slog.Error(err))
+					err = f.conn.Close(websocket.StatusInternalError, err.Error())
+					if err != nil {
+						f.logger.Warn(f.ctx, "failed to close websocket", slog.Error(err))
+					}
+				}
+				return
+			}
 		}
 	}
-
-	closeSubscribe, err := api.Pubsub.Subscribe(
-		provisionerJobLogsChannel(jobID),
-		func(ctx context.Context, message []byte) {
-			if endOfLogs.Load() {
-				return
-			}
-			jlMsg := provisionerJobLogsMessage{}
-			err := json.Unmarshal(message, &jlMsg)
-			if err != nil {
-				logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err))
-				return
-			}
-
-			// CreatedAfter is sent when logs are streaming!
-			if jlMsg.CreatedAfter != 0 {
-				logs, err := api.Database.GetProvisionerLogsAfterID(dbauthz.As(ctx, actor), database.GetProvisionerLogsAfterIDParams{
-					JobID:        jobID,
-					CreatedAfter: jlMsg.CreatedAfter,
-				})
-				if err != nil {
-					logger.Warn(ctx, "get provisioner logs", slog.Error(err))
-					return
-				}
-				for _, log := range logs {
-					if endOfLogs.Load() {
-						// An end of logs message came in while we were fetching
-						// logs or processing them!
-						return
-					}
-					log := log
-					sendLog(&log)
-				}
-			}
-
-			// EndOfLogs is sent when logs are done streaming.
-			// We don't want to end the stream until we've sent all the logs,
-			// so we fetch logs after the last ID we've seen and send them!
-			if jlMsg.EndOfLogs {
-				endOfLogs.Store(true)
-				logs, err := api.Database.GetProvisionerLogsAfterID(dbauthz.As(ctx, actor), database.GetProvisionerLogsAfterIDParams{
-					JobID:        jobID,
-					CreatedAfter: lastSentLogID.Load(),
-				})
-				if err != nil {
-					logger.Warn(ctx, "get provisioner logs", slog.Error(err))
-					return
-				}
-				for _, log := range logs {
-					log := log
-					sendLog(&log)
-				}
-				logger.Debug(ctx, "got End of Logs")
-				bufferedLogs <- nil
-			}
-		},
-	)
-	if err != nil {
-		return nil, nil, err
-	}
-	// We don't need to close the bufferedLogs channel because it will be garbage collected!
-	return bufferedLogs, closeSubscribe, nil
+}
+
+func (f *logFollower) listener(_ context.Context, message []byte, err error) {
+	if err != nil {
+		f.errors <- err
+		return
+	}
+	var n provisionersdk.ProvisionerJobLogsNotifyMessage
+	err = json.Unmarshal(message, &n)
+	if err != nil {
+		f.errors <- err
+		return
+	}
+	f.notifications <- n
+}
+
+// query fetches the latest job logs from the database and writes them to the
+// connection.
+func (f *logFollower) query() error {
+	f.logger.Debug(f.ctx, "querying logs", slog.F("after", f.after))
+	logs, err := f.db.GetProvisionerLogsAfterID(f.ctx, database.GetProvisionerLogsAfterIDParams{
+		JobID:        f.jobID,
+		CreatedAfter: f.after,
+	})
+	if err != nil && !errors.Is(err, sql.ErrNoRows) {
+		return xerrors.Errorf("error fetching logs: %w", err)
+	}
+	for _, log := range logs {
+		logB, err := json.Marshal(convertProvisionerJobLog(log))
+		if err != nil {
+			return xerrors.Errorf("error marshaling log: %w", err)
+		}
+		err = f.conn.Write(f.ctx, websocket.MessageText, logB)
+		if err != nil {
+			return xerrors.Errorf("error writing to websocket: %w", err)
+		}
+		f.after = log.ID
+		f.logger.Debug(f.ctx, "wrote log to websocket", slog.F("id", log.ID))
+	}
+	return nil
 }
diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go
index 54512e5da5..ee34e45105 100644
--- a/coderd/provisionerjobs_internal_test.go
+++ b/coderd/provisionerjobs_internal_test.go
@@ -1,13 +1,28 @@
 package coderd
 
 import (
+	"context"
 	"database/sql"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
 	"testing"
+	"time"
 
+	"github.com/golang/mock/gomock"
+	"github.com/google/uuid"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"nhooyr.io/websocket"
+
+	"cdr.dev/slog/sloggers/slogtest"
 
 	"github.com/coder/coder/coderd/database"
+	"github.com/coder/coder/coderd/database/dbmock"
 	"github.com/coder/coder/codersdk"
+	"github.com/coder/coder/provisionersdk"
+	"github.com/coder/coder/testutil"
 )
 
 func TestConvertProvisionerJob_Unit(t *testing.T) {
@@ -115,3 +130,277 @@ func TestConvertProvisionerJob_Unit(t *testing.T) {
 		})
 	}
 }
+
+func Test_logFollower_completeBeforeFollow(t *testing.T) {
+	t.Parallel()
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+	defer cancel()
+	logger := slogtest.Make(t, nil)
+	ctrl := gomock.NewController(t)
+	mDB := dbmock.NewMockStore(ctrl)
+	pubsub := database.NewPubsubInMemory()
+	now := database.Now()
+	job := database.ProvisionerJob{
+		ID:        uuid.New(),
+		CreatedAt: now.Add(-10 * time.Second),
+		UpdatedAt: now.Add(-10 * time.Second),
+		StartedAt: sql.NullTime{
+			Time:  now.Add(-10 * time.Second),
+			Valid: true,
+		},
+		CompletedAt: sql.NullTime{
+			Time:  now.Add(-time.Second),
+			Valid: true,
+		},
+		Error: sql.NullString{},
+	}
+
+	// we need an HTTP server to get a websocket
+	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+		uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 10)
+		uut.follow()
+	}))
+	defer srv.Close()
+
+	// return some historical logs
+	mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 10)).
+		Times(1).
+		Return(
+			[]database.ProvisionerJobLog{
+				{Stage: "One", Output: "One", ID: 11},
+				{Stage: "One", Output: "Two", ID: 12},
+			},
+			nil,
+		)
+
+	// nolint: bodyclose
+	client, _, err := websocket.Dial(ctx, srv.URL, nil)
+	require.NoError(t, err)
+	mt, msg, err := client.Read(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, websocket.MessageText, mt)
+	assertLog(t, "One", "One", 11, msg)
+
+	mt, msg, err = client.Read(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, websocket.MessageText, mt)
+	assertLog(t, "One", "Two", 12, msg)
+
+	// server should now close
+	_, _, err = client.Read(ctx)
+	var closeErr websocket.CloseError
+	require.ErrorAs(t, err, &closeErr)
+	assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
+}
+
+func Test_logFollower_completeBeforeSubscribe(t *testing.T) {
+	t.Parallel()
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+	defer cancel()
+	logger := slogtest.Make(t, nil)
+	ctrl := gomock.NewController(t)
+	mDB := dbmock.NewMockStore(ctrl)
+	pubsub := database.NewPubsubInMemory()
+	now := database.Now()
+	job := database.ProvisionerJob{
+		ID:        uuid.New(),
+		CreatedAt: now.Add(-10 * time.Second),
+		UpdatedAt: now.Add(-10 * time.Second),
+		StartedAt: sql.NullTime{
+			Time:  now.Add(-10 * time.Second),
+			Valid: true,
+		},
+		CanceledAt:  sql.NullTime{},
+		CompletedAt: sql.NullTime{},
+		Error:       sql.NullString{},
+	}
+
+	// we need an HTTP server to get a websocket
+	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+		uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
+		uut.follow()
+	}))
+	defer srv.Close()
+
+	// job was incomplete when we create the logFollower, but is complete as soon
+	// as it queries again.
+	mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), job.ID).Times(1).Return(
+		database.ProvisionerJob{
+			ID:        job.ID,
+			CreatedAt: job.CreatedAt,
+			UpdatedAt: job.UpdatedAt,
+			StartedAt: job.StartedAt,
+			CompletedAt: sql.NullTime{
+				Time:  now.Add(-time.Millisecond),
+				Valid: true,
+			},
+		},
+		nil,
+	)
+
+	// return some historical logs
+	mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 0)).
+		Times(1).
+		Return(
+			[]database.ProvisionerJobLog{
+				{Stage: "One", Output: "One", ID: 1},
+				{Stage: "One", Output: "Two", ID: 2},
+			},
+			nil,
+		)
+
+	// nolint: bodyclose
+	client, _, err := websocket.Dial(ctx, srv.URL, nil)
+	require.NoError(t, err)
+	mt, msg, err := client.Read(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, websocket.MessageText, mt)
+	assertLog(t, "One", "One", 1, msg)
+
+	mt, msg, err = client.Read(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, websocket.MessageText, mt)
+	assertLog(t, "One", "Two", 2, msg)
+
+	// server should now close
+	_, _, err = client.Read(ctx)
+	var closeErr websocket.CloseError
+	require.ErrorAs(t, err, &closeErr)
+	assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
+}
+
+func Test_logFollower_EndOfLogs(t *testing.T) {
+	t.Parallel()
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+	defer cancel()
+	logger := slogtest.Make(t, nil)
+	ctrl := gomock.NewController(t)
+	mDB := dbmock.NewMockStore(ctrl)
+	pubsub := database.NewPubsubInMemory()
+	now := database.Now()
+	job := database.ProvisionerJob{
+		ID:        uuid.New(),
+		CreatedAt: now.Add(-10 * time.Second),
+		UpdatedAt: now.Add(-10 * time.Second),
+		StartedAt: sql.NullTime{
+			Time:  now.Add(-10 * time.Second),
+			Valid: true,
+		},
+		CanceledAt:  sql.NullTime{},
+		CompletedAt: sql.NullTime{},
+		Error:       sql.NullString{},
+	}
+
+	// we need an HTTP server to get a websocket
+	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+		uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
+		uut.follow()
+	}))
+	defer srv.Close()
+
+	// job was incomplete when we create the logFollower, and still incomplete when it queries
+	mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), job.ID).Times(1).Return(job, nil)
+
+	// return some historical logs
+	q0 := mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 0)).
+		Times(1).
+		Return(
+			[]database.ProvisionerJobLog{
+				{Stage: "One", Output: "One", ID: 1},
+				{Stage: "One", Output: "Two", ID: 2},
+			},
+			nil,
+		)
+	// return some logs from a kick.
+	mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 2)).
+		After(q0).
+		Times(1).
+		Return(
+			[]database.ProvisionerJobLog{
+				{Stage: "One", Output: "Three", ID: 3},
+				{Stage: "Two", Output: "One", ID: 4},
+			},
+			nil,
+		)
+
+	// nolint: bodyclose
+	client, _, err := websocket.Dial(ctx, srv.URL, nil)
+	require.NoError(t, err)
+	mt, msg, err := client.Read(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, websocket.MessageText, mt)
+	assertLog(t, "One", "One", 1, msg)
+
+	mt, msg, err = client.Read(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, websocket.MessageText, mt)
+	assertLog(t, "One", "Two", 2, msg)
+
+	// send in the kick so follower will query a second time
+	n := provisionersdk.ProvisionerJobLogsNotifyMessage{
+		CreatedAfter: 2,
+	}
+	msg, err = json.Marshal(&n)
+	require.NoError(t, err)
+	err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
+	require.NoError(t, err)
+
+	mt, msg, err = client.Read(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, websocket.MessageText, mt)
+	assertLog(t, "One", "Three", 3, msg)
+
+	mt, msg, err = client.Read(ctx)
+	require.NoError(t, err)
+	assert.Equal(t, websocket.MessageText, mt)
+	assertLog(t, "Two", "One", 4, msg)
+
+	// send EndOfLogs
+	n.EndOfLogs = true
+	n.CreatedAfter = 0
+	msg, err = json.Marshal(&n)
+	require.NoError(t, err)
+	err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
+	require.NoError(t, err)
+
+	// server should now close
+	_, _, err = client.Read(ctx)
+	var closeErr websocket.CloseError
+	require.ErrorAs(t, err, &closeErr)
+	assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
+}
+
+func assertLog(t *testing.T, stage, output string, id int64, msg []byte) {
+	t.Helper()
+	var log codersdk.ProvisionerJobLog
+	err := json.Unmarshal(msg, &log)
+	require.NoError(t, err)
+	assert.Equal(t, stage, log.Stage)
+	assert.Equal(t, output, log.Output)
+	assert.Equal(t, id, log.ID)
+}
+
+type logsAfterMatcher struct {
+	params database.GetProvisionerLogsAfterIDParams
+}
+
+func (m *logsAfterMatcher) Matches(x interface{}) bool {
+	p, ok := x.(database.GetProvisionerLogsAfterIDParams)
+	if !ok {
+		return false
+	}
+	return m.params == p
+}
+
+func (m *logsAfterMatcher) String() string {
+	return fmt.Sprintf("%+v", m.params)
+}
+
+func matchesJobAfter(jobID uuid.UUID, after int64) gomock.Matcher {
+	return &logsAfterMatcher{
+		params: database.GetProvisionerLogsAfterIDParams{
+			JobID:        jobID,
+			CreatedAfter: after,
+		},
+	}
+}
diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go
index 0479d05ae5..040a0318ab 100644
--- a/codersdk/provisionerdaemons.go
+++ b/codersdk/provisionerdaemons.go
@@ -132,15 +132,20 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
 	}
 	logs := make(chan ProvisionerJobLog)
 	closed := make(chan struct{})
-	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
-	decoder := json.NewDecoder(wsNetConn)
 	go func() {
 		defer close(closed)
 		defer close(logs)
 		defer conn.Close(websocket.StatusGoingAway, "")
 		var log ProvisionerJobLog
 		for {
-			err = decoder.Decode(&log)
+			msgType, msg, err := conn.Read(ctx)
+			if err != nil {
+				return
+			}
+			if msgType != websocket.MessageText {
+				return
+			}
+			err = json.Unmarshal(msg, &log)
 			if err != nil {
 				return
 			}
@@ -152,7 +157,6 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
 		}
 	}()
 	return logs, closeFunc(func() error {
-		_ = wsNetConn.Close()
 		<-closed
 		return nil
 	}), nil
diff --git a/provisionersdk/logs.go b/provisionersdk/logs.go
new file mode 100644
index 0000000000..e87a4ad3d1
--- /dev/null
+++ b/provisionersdk/logs.go
@@ -0,0 +1,20 @@
+package provisionersdk
+
+import (
+	"fmt"
+
+	"github.com/google/uuid"
+)
+
+// ProvisionerJobLogsNotifyMessage is the payload published on
+// the provisioner job logs notify channel.
+type ProvisionerJobLogsNotifyMessage struct {
+	CreatedAfter int64 `json:"created_after"`
+	EndOfLogs    bool  `json:"end_of_logs,omitempty"`
+}
+
+// ProvisionerJobLogsNotifyChannel is the PostgreSQL NOTIFY channel
+// to publish updates to job logs on.
+func ProvisionerJobLogsNotifyChannel(jobID uuid.UUID) string {
+	return fmt.Sprintf("provisioner-log-logs:%s", jobID)
+}
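
For reference, here is a minimal sketch (not part of the patch) of how the channel and message pair relocated into provisionersdk work together: a publisher marshals ProvisionerJobLogsNotifyMessage and publishes it on the channel name returned by ProvisionerJobLogsNotifyChannel, and a subscriber decodes each payload to decide whether to re-query logs (CreatedAfter) or stop following (EndOfLogs). The Publish/Subscribe signatures match those used in the diff above; the in-memory pubsub and the freshly generated job ID are assumptions for illustration only — production coderd wires in its PostgreSQL-backed pubsub.

package main

import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/google/uuid"

	"github.com/coder/coder/coderd/database"
	"github.com/coder/coder/provisionersdk"
)

func main() {
	// In-memory pubsub, as used by the tests above; production code uses a
	// PostgreSQL NOTIFY-backed implementation of the same interface.
	pubsub := database.NewPubsubInMemory()
	jobID := uuid.New() // hypothetical job ID for illustration
	done := make(chan struct{})

	// Subscriber: decode each notification; a follower would re-query logs on
	// CreatedAfter and stop following on EndOfLogs.
	cancel, err := pubsub.Subscribe(
		provisionersdk.ProvisionerJobLogsNotifyChannel(jobID),
		func(_ context.Context, message []byte) {
			var n provisionersdk.ProvisionerJobLogsNotifyMessage
			if err := json.Unmarshal(message, &n); err != nil {
				return
			}
			fmt.Printf("created_after=%d end_of_logs=%v\n", n.CreatedAfter, n.EndOfLogs)
			if n.EndOfLogs {
				close(done)
			}
		},
	)
	if err != nil {
		panic(err)
	}
	defer cancel()

	// Publisher: the same pattern provisionerdserver uses to signal that a
	// job's logs are complete.
	data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
	if err != nil {
		panic(err)
	}
	if err := pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data); err != nil {
		panic(err)
	}
	<-done
}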