mirror of https://github.com/coder/coder.git
fix: stream provisioner logs (#7712)
* stream provisioner logs Signed-off-by: Spike Curtis <spike@coder.com> * Fix imports Signed-off-by: Spike Curtis <spike@coder.com> * Better logging, naming, arg order Signed-off-by: Spike Curtis <spike@coder.com> --------- Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent
583b777251
commit
7c3dbbbe93
|
@ -537,13 +537,13 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
|
|||
// everything from that point.
|
||||
lowestID := logs[0].ID
|
||||
server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
|
||||
data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{
|
||||
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
|
||||
CreatedAfter: lowestID - 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal: %w", err)
|
||||
}
|
||||
err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(parsedID), data)
|
||||
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("publish job log: %w", err)
|
||||
|
@ -846,11 +846,11 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
|
|||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
|
||||
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal job log: %w", err)
|
||||
}
|
||||
err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("publish end of job logs: %w", err)
|
||||
|
@ -1236,11 +1236,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
|||
reflect.TypeOf(completed.Type).String())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
|
||||
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal job log: %w", err)
|
||||
}
|
||||
err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
if err != nil {
|
||||
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("publish end of job logs: %w", err)
|
||||
|
@ -1704,19 +1704,6 @@ type TemplateVersionDryRunJob struct {
|
|||
RichParameterValues []database.WorkspaceBuildParameter `json:"rich_parameter_values"`
|
||||
}
|
||||
|
||||
// ProvisionerJobLogsNotifyMessage is the payload published on
|
||||
// the provisioner job logs notify channel.
|
||||
type ProvisionerJobLogsNotifyMessage struct {
|
||||
CreatedAfter int64 `json:"created_after"`
|
||||
EndOfLogs bool `json:"end_of_logs,omitempty"`
|
||||
}
|
||||
|
||||
// ProvisionerJobLogsNotifyChannel is the PostgreSQL NOTIFY channel
|
||||
// to publish updates to job logs on.
|
||||
func ProvisionerJobLogsNotifyChannel(jobID uuid.UUID) string {
|
||||
return fmt.Sprintf("provisioner-log-logs:%s", jobID)
|
||||
}
|
||||
|
||||
func asVariableValues(templateVariables []database.TemplateVersionVariable) []*sdkproto.VariableValue {
|
||||
var apiVariableValues []*sdkproto.VariableValue
|
||||
for _, v := range templateVariables {
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisionerd/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
sdkproto "github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
@ -528,7 +529,7 @@ func TestUpdateJob(t *testing.T) {
|
|||
|
||||
published := make(chan struct{})
|
||||
|
||||
closeListener, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) {
|
||||
closeListener, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) {
|
||||
close(published)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -776,7 +777,7 @@ func TestFailJob(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer closeWorkspaceSubscribe()
|
||||
publishedLogs := make(chan struct{})
|
||||
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
|
||||
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
|
||||
close(publishedLogs)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -1082,7 +1083,7 @@ func TestCompleteJob(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer closeWorkspaceSubscribe()
|
||||
publishedLogs := make(chan struct{})
|
||||
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
|
||||
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
|
||||
close(publishedLogs)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -5,22 +5,23 @@ import (
|
|||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/atomic"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
)
|
||||
|
||||
// Returns provisioner logs based on query parameters.
|
||||
|
@ -32,7 +33,6 @@ import (
|
|||
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
actor, _ = dbauthz.ActorFromContext(ctx)
|
||||
logger = api.Logger.With(slog.F("job_id", job.ID))
|
||||
follow = r.URL.Query().Has("follow")
|
||||
afterRaw = r.URL.Query().Get("after")
|
||||
|
@ -55,129 +55,16 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
|
|||
}
|
||||
|
||||
if !follow {
|
||||
logs, err := api.Database.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: job.ID,
|
||||
CreatedAfter: after,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner logs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if logs == nil {
|
||||
logs = []database.ProvisionerJobLog{}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "Finished non-follow job logs")
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertProvisionerJobLogs(logs))
|
||||
fetchAndWriteLogs(ctx, logger, api.Database, job.ID, after, rw)
|
||||
return
|
||||
}
|
||||
|
||||
// if we are following logs, start the subscription before we query the database, so that we don't miss any logs
|
||||
// between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track
|
||||
// of processed IDs.
|
||||
var bufferedLogs <-chan *database.ProvisionerJobLog
|
||||
if follow {
|
||||
bl, closeFollow, err := api.followProvisionerJobLogs(actor, job.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error watching provisioner logs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer closeFollow()
|
||||
bufferedLogs = bl
|
||||
}
|
||||
|
||||
logs, err := api.Database.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: job.ID,
|
||||
CreatedAfter: after,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner logs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if logs == nil {
|
||||
logs = []database.ProvisionerJobLog{}
|
||||
}
|
||||
|
||||
follower := newLogFollower(ctx, logger, api.Database, api.Pubsub, rw, r, job, after)
|
||||
api.WebsocketWaitMutex.Lock()
|
||||
api.WebsocketWaitGroup.Add(1)
|
||||
api.WebsocketWaitMutex.Unlock()
|
||||
defer api.WebsocketWaitGroup.Done()
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to accept websocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
go httpapi.Heartbeat(ctx, conn)
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
logIdsDone := make(map[int64]bool)
|
||||
|
||||
// The Go stdlib JSON encoder appends a newline character after message write.
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
for _, provisionerJobLog := range logs {
|
||||
logIdsDone[provisionerJobLog.ID] = true
|
||||
err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
job, err = api.Database.GetProvisionerJobByID(ctx, job.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner job.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if job.CompletedAt.Valid {
|
||||
// job was complete before we queried the database for historical logs
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debug(context.Background(), "job logs context canceled")
|
||||
return
|
||||
case log, ok := <-bufferedLogs:
|
||||
// A nil log is sent when complete!
|
||||
if !ok || log == nil {
|
||||
logger.Debug(context.Background(), "reached the end of published logs")
|
||||
return
|
||||
}
|
||||
if logIdsDone[log.ID] {
|
||||
logger.Debug(ctx, "subscribe duplicated log",
|
||||
slog.F("stage", log.Stage))
|
||||
} else {
|
||||
logger.Debug(ctx, "subscribe encoding log",
|
||||
slog.F("stage", log.Stage))
|
||||
err = encoder.Encode(convertProvisionerJobLog(*log))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
follower.follow()
|
||||
}
|
||||
|
||||
func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
|
||||
|
@ -334,98 +221,225 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
|
|||
return job
|
||||
}
|
||||
|
||||
func provisionerJobLogsChannel(jobID uuid.UUID) string {
|
||||
return fmt.Sprintf("provisioner-log-logs:%s", jobID)
|
||||
func fetchAndWriteLogs(ctx context.Context, logger slog.Logger, db database.Store, jobID uuid.UUID, after int64, rw http.ResponseWriter) {
|
||||
logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: jobID,
|
||||
CreatedAfter: after,
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner logs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if logs == nil {
|
||||
logs = []database.ProvisionerJobLog{}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "Finished non-follow job logs")
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertProvisionerJobLogs(logs))
|
||||
}
|
||||
|
||||
// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel
|
||||
type provisionerJobLogsMessage struct {
|
||||
CreatedAfter int64 `json:"created_after"`
|
||||
EndOfLogs bool `json:"end_of_logs,omitempty"`
|
||||
func jobIsComplete(logger slog.Logger, job database.ProvisionerJob) bool {
|
||||
status := db2sdk.ProvisionerJobStatus(job)
|
||||
switch status {
|
||||
case codersdk.ProvisionerJobCanceled:
|
||||
return true
|
||||
case codersdk.ProvisionerJobFailed:
|
||||
return true
|
||||
case codersdk.ProvisionerJobSucceeded:
|
||||
return true
|
||||
case codersdk.ProvisionerJobPending:
|
||||
return false
|
||||
case codersdk.ProvisionerJobCanceling:
|
||||
return false
|
||||
case codersdk.ProvisionerJobRunning:
|
||||
return false
|
||||
default:
|
||||
logger.Error(context.Background(),
|
||||
"unknown status",
|
||||
slog.F("job_id", job.ID), slog.F("status", status))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) followProvisionerJobLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan *database.ProvisionerJobLog, func(), error) {
|
||||
logger := api.Logger.With(slog.F("job_id", jobID))
|
||||
type logFollower struct {
|
||||
ctx context.Context
|
||||
logger slog.Logger
|
||||
db database.Store
|
||||
pubsub database.Pubsub
|
||||
r *http.Request
|
||||
rw http.ResponseWriter
|
||||
conn *websocket.Conn
|
||||
|
||||
var (
|
||||
// With debug logging enabled length = 128 is insufficient
|
||||
bufferedLogs = make(chan *database.ProvisionerJobLog, 1024)
|
||||
endOfLogs atomic.Bool
|
||||
lastSentLogID atomic.Int64
|
||||
)
|
||||
jobID uuid.UUID
|
||||
after int64
|
||||
complete bool
|
||||
notifications chan provisionersdk.ProvisionerJobLogsNotifyMessage
|
||||
errors chan error
|
||||
}
|
||||
|
||||
sendLog := func(log *database.ProvisionerJobLog) {
|
||||
func newLogFollower(
|
||||
ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub,
|
||||
rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob, after int64,
|
||||
) *logFollower {
|
||||
return &logFollower{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
db: db,
|
||||
pubsub: pubsub,
|
||||
r: r,
|
||||
rw: rw,
|
||||
jobID: job.ID,
|
||||
after: after,
|
||||
complete: jobIsComplete(logger, job),
|
||||
notifications: make(chan provisionersdk.ProvisionerJobLogsNotifyMessage),
|
||||
errors: make(chan error),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *logFollower) follow() {
|
||||
// note that we only need to subscribe to updates if the job is not yet
|
||||
// complete.
|
||||
if !f.complete {
|
||||
subCancel, err := f.pubsub.SubscribeWithErr(
|
||||
provisionersdk.ProvisionerJobLogsNotifyChannel(f.jobID),
|
||||
f.listener,
|
||||
)
|
||||
if err != nil {
|
||||
httpapi.Write(f.ctx, f.rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "failed to subscribe to job updates",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer subCancel()
|
||||
|
||||
// we were provided `complete` prior to starting this subscription, so
|
||||
// we also need to check whether the job is now complete, in case the
|
||||
// job completed between the last time we queried the job and the start
|
||||
// of the subscription. If the job completes after this, we will get
|
||||
// a notification on the subscription.
|
||||
job, err := f.db.GetProvisionerJobByID(f.ctx, f.jobID)
|
||||
if err != nil {
|
||||
httpapi.Write(f.ctx, f.rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "failed to query job",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
f.complete = jobIsComplete(f.logger, job)
|
||||
f.logger.Debug(f.ctx, "queried job after subscribe", slog.F("complete", f.complete))
|
||||
}
|
||||
|
||||
var err error
|
||||
f.conn, err = websocket.Accept(f.rw, f.r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(f.ctx, f.rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to accept websocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer f.conn.Close(websocket.StatusNormalClosure, "done")
|
||||
go httpapi.Heartbeat(f.ctx, f.conn)
|
||||
|
||||
// query for logs once right away, so we can get historical data from before
|
||||
// subscription
|
||||
if err := f.query(); err != nil {
|
||||
if f.ctx.Err() == nil && !xerrors.Is(err, io.EOF) {
|
||||
// neither context expiry, nor EOF, close and log
|
||||
f.logger.Error(f.ctx, "failed to query logs", slog.Error(err))
|
||||
err = f.conn.Close(websocket.StatusInternalError, err.Error())
|
||||
if err != nil {
|
||||
f.logger.Warn(f.ctx, "failed to close webscoket", slog.Error(err))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// no need to wait if the job is done
|
||||
if f.complete {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case bufferedLogs <- log:
|
||||
logger.Debug(context.Background(), "subscribe buffered log", slog.F("stage", log.Stage))
|
||||
lastSentLogID.Store(log.ID)
|
||||
default:
|
||||
// If this overflows users could miss logs streaming. This can happen
|
||||
// we get a lot of logs and consumer isn't keeping up. We don't want to block the pubsub,
|
||||
// so just drop them.
|
||||
logger.Warn(context.Background(), "provisioner job log overflowing channel")
|
||||
case err := <-f.errors:
|
||||
// we've dropped at least one notification. This can happen if we
|
||||
// lose database connectivity. We don't know whether the job is
|
||||
// now complete since we could have missed the end of logs message.
|
||||
// We could soldier on and retry, but loss of database connectivity
|
||||
// is fairly serious, so instead just 500 and bail out. Client
|
||||
// can retry and hopefully find a healthier node.
|
||||
f.logger.Error(f.ctx, "dropped or corrupted notification", slog.Error(err))
|
||||
err = f.conn.Close(websocket.StatusInternalError, err.Error())
|
||||
if err != nil {
|
||||
f.logger.Warn(f.ctx, "failed to close webscoket", slog.Error(err))
|
||||
}
|
||||
return
|
||||
case <-f.ctx.Done():
|
||||
// client disconnect
|
||||
return
|
||||
case n := <-f.notifications:
|
||||
if n.EndOfLogs {
|
||||
// safe to return here because we started the subscription,
|
||||
// and then queried at least once, so we will have already
|
||||
// gotten all logs prior to the start of our subscription.
|
||||
return
|
||||
}
|
||||
err = f.query()
|
||||
if err != nil {
|
||||
if f.ctx.Err() == nil && !xerrors.Is(err, io.EOF) {
|
||||
// neither context expiry, nor EOF, close and log
|
||||
f.logger.Error(f.ctx, "failed to query logs", slog.Error(err))
|
||||
err = f.conn.Close(websocket.StatusInternalError, err.Error())
|
||||
if err != nil {
|
||||
f.logger.Warn(f.ctx, "failed to close webscoket", slog.Error(err))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
closeSubscribe, err := api.Pubsub.Subscribe(
|
||||
provisionerJobLogsChannel(jobID),
|
||||
func(ctx context.Context, message []byte) {
|
||||
if endOfLogs.Load() {
|
||||
return
|
||||
}
|
||||
jlMsg := provisionerJobLogsMessage{}
|
||||
err := json.Unmarshal(message, &jlMsg)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// CreatedAfter is sent when logs are streaming!
|
||||
if jlMsg.CreatedAfter != 0 {
|
||||
logs, err := api.Database.GetProvisionerLogsAfterID(dbauthz.As(ctx, actor), database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: jobID,
|
||||
CreatedAfter: jlMsg.CreatedAfter,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "get provisioner logs", slog.Error(err))
|
||||
return
|
||||
}
|
||||
for _, log := range logs {
|
||||
if endOfLogs.Load() {
|
||||
// An end of logs message came in while we were fetching
|
||||
// logs or processing them!
|
||||
return
|
||||
}
|
||||
log := log
|
||||
sendLog(&log)
|
||||
}
|
||||
}
|
||||
|
||||
// EndOfLogs is sent when logs are done streaming.
|
||||
// We don't want to end the stream until we've sent all the logs,
|
||||
// so we fetch logs after the last ID we've seen and send them!
|
||||
if jlMsg.EndOfLogs {
|
||||
endOfLogs.Store(true)
|
||||
logs, err := api.Database.GetProvisionerLogsAfterID(dbauthz.As(ctx, actor), database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: jobID,
|
||||
CreatedAfter: lastSentLogID.Load(),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "get provisioner logs", slog.Error(err))
|
||||
return
|
||||
}
|
||||
for _, log := range logs {
|
||||
log := log
|
||||
sendLog(&log)
|
||||
}
|
||||
logger.Debug(ctx, "got End of Logs")
|
||||
bufferedLogs <- nil
|
||||
}
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// We don't need to close the bufferedLogs channel because it will be garbage collected!
|
||||
return bufferedLogs, closeSubscribe, nil
|
||||
}
|
||||
|
||||
func (f *logFollower) listener(_ context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
f.errors <- err
|
||||
return
|
||||
}
|
||||
var n provisionersdk.ProvisionerJobLogsNotifyMessage
|
||||
err = json.Unmarshal(message, &n)
|
||||
if err != nil {
|
||||
f.errors <- err
|
||||
return
|
||||
}
|
||||
f.notifications <- n
|
||||
}
|
||||
|
||||
// query fetches the latest job logs from the database and writes them to the
|
||||
// connection.
|
||||
func (f *logFollower) query() error {
|
||||
f.logger.Debug(f.ctx, "querying logs", slog.F("after", f.after))
|
||||
logs, err := f.db.GetProvisionerLogsAfterID(f.ctx, database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: f.jobID,
|
||||
CreatedAfter: f.after,
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("error fetching logs: %w", err)
|
||||
}
|
||||
for _, log := range logs {
|
||||
logB, err := json.Marshal(convertProvisionerJobLog(log))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("error marshaling log: %w", err)
|
||||
}
|
||||
err = f.conn.Write(f.ctx, websocket.MessageText, logB)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("error writing to websocket: %w", err)
|
||||
}
|
||||
f.after = log.ID
|
||||
f.logger.Debug(f.ctx, "wrote log to websocket", slog.F("id", log.ID))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,13 +1,28 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbmock"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestConvertProvisionerJob_Unit(t *testing.T) {
|
||||
|
@ -115,3 +130,277 @@ func TestConvertProvisionerJob_Unit(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logFollower_completeBeforeFollow(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
now := database.Now()
|
||||
job := database.ProvisionerJob{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: now.Add(-10 * time.Second),
|
||||
UpdatedAt: now.Add(-10 * time.Second),
|
||||
StartedAt: sql.NullTime{
|
||||
Time: now.Add(-10 * time.Second),
|
||||
Valid: true,
|
||||
},
|
||||
CompletedAt: sql.NullTime{
|
||||
Time: now.Add(-time.Second),
|
||||
Valid: true,
|
||||
},
|
||||
Error: sql.NullString{},
|
||||
}
|
||||
|
||||
// we need an HTTP server to get a websocket
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 10)
|
||||
uut.follow()
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// return some historical logs
|
||||
mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 10)).
|
||||
Times(1).
|
||||
Return(
|
||||
[]database.ProvisionerJobLog{
|
||||
{Stage: "One", Output: "One", ID: 11},
|
||||
{Stage: "One", Output: "Two", ID: 12},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// nolint: bodyclose
|
||||
client, _, err := websocket.Dial(ctx, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
mt, msg, err := client.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, websocket.MessageText, mt)
|
||||
assertLog(t, "One", "One", 11, msg)
|
||||
|
||||
mt, msg, err = client.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, websocket.MessageText, mt)
|
||||
assertLog(t, "One", "Two", 12, msg)
|
||||
|
||||
// server should now close
|
||||
_, _, err = client.Read(ctx)
|
||||
var closeErr websocket.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
|
||||
}
|
||||
|
||||
func Test_logFollower_completeBeforeSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
now := database.Now()
|
||||
job := database.ProvisionerJob{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: now.Add(-10 * time.Second),
|
||||
UpdatedAt: now.Add(-10 * time.Second),
|
||||
StartedAt: sql.NullTime{
|
||||
Time: now.Add(-10 * time.Second),
|
||||
Valid: true,
|
||||
},
|
||||
CanceledAt: sql.NullTime{},
|
||||
CompletedAt: sql.NullTime{},
|
||||
Error: sql.NullString{},
|
||||
}
|
||||
|
||||
// we need an HTTP server to get a websocket
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
|
||||
uut.follow()
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// job was incomplete when we create the logFollower, but is complete as soon
|
||||
// as it queries again.
|
||||
mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), job.ID).Times(1).Return(
|
||||
database.ProvisionerJob{
|
||||
ID: job.ID,
|
||||
CreatedAt: job.CreatedAt,
|
||||
UpdatedAt: job.UpdatedAt,
|
||||
StartedAt: job.StartedAt,
|
||||
CompletedAt: sql.NullTime{
|
||||
Time: now.Add(-time.Millisecond),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// return some historical logs
|
||||
mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 0)).
|
||||
Times(1).
|
||||
Return(
|
||||
[]database.ProvisionerJobLog{
|
||||
{Stage: "One", Output: "One", ID: 1},
|
||||
{Stage: "One", Output: "Two", ID: 2},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// nolint: bodyclose
|
||||
client, _, err := websocket.Dial(ctx, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
mt, msg, err := client.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, websocket.MessageText, mt)
|
||||
assertLog(t, "One", "One", 1, msg)
|
||||
|
||||
mt, msg, err = client.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, websocket.MessageText, mt)
|
||||
assertLog(t, "One", "Two", 2, msg)
|
||||
|
||||
// server should now close
|
||||
_, _, err = client.Read(ctx)
|
||||
var closeErr websocket.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
|
||||
}
|
||||
|
||||
func Test_logFollower_EndOfLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
now := database.Now()
|
||||
job := database.ProvisionerJob{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: now.Add(-10 * time.Second),
|
||||
UpdatedAt: now.Add(-10 * time.Second),
|
||||
StartedAt: sql.NullTime{
|
||||
Time: now.Add(-10 * time.Second),
|
||||
Valid: true,
|
||||
},
|
||||
CanceledAt: sql.NullTime{},
|
||||
CompletedAt: sql.NullTime{},
|
||||
Error: sql.NullString{},
|
||||
}
|
||||
|
||||
// we need an HTTP server to get a websocket
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
|
||||
uut.follow()
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// job was incomplete when we create the logFollower, and still incomplete when it queries
|
||||
mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), job.ID).Times(1).Return(job, nil)
|
||||
|
||||
// return some historical logs
|
||||
q0 := mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 0)).
|
||||
Times(1).
|
||||
Return(
|
||||
[]database.ProvisionerJobLog{
|
||||
{Stage: "One", Output: "One", ID: 1},
|
||||
{Stage: "One", Output: "Two", ID: 2},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
// return some logs from a kick.
|
||||
mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 2)).
|
||||
After(q0).
|
||||
Times(1).
|
||||
Return(
|
||||
[]database.ProvisionerJobLog{
|
||||
{Stage: "One", Output: "Three", ID: 3},
|
||||
{Stage: "Two", Output: "One", ID: 4},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// nolint: bodyclose
|
||||
client, _, err := websocket.Dial(ctx, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
mt, msg, err := client.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, websocket.MessageText, mt)
|
||||
assertLog(t, "One", "One", 1, msg)
|
||||
|
||||
mt, msg, err = client.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, websocket.MessageText, mt)
|
||||
assertLog(t, "One", "Two", 2, msg)
|
||||
|
||||
// send in the kick so follower will query a second time
|
||||
n := provisionersdk.ProvisionerJobLogsNotifyMessage{
|
||||
CreatedAfter: 2,
|
||||
}
|
||||
msg, err = json.Marshal(&n)
|
||||
require.NoError(t, err)
|
||||
err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
mt, msg, err = client.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, websocket.MessageText, mt)
|
||||
assertLog(t, "One", "Three", 3, msg)
|
||||
|
||||
mt, msg, err = client.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, websocket.MessageText, mt)
|
||||
assertLog(t, "Two", "One", 4, msg)
|
||||
|
||||
// send EndOfLogs
|
||||
n.EndOfLogs = true
|
||||
n.CreatedAfter = 0
|
||||
msg, err = json.Marshal(&n)
|
||||
require.NoError(t, err)
|
||||
err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// server should now close
|
||||
_, _, err = client.Read(ctx)
|
||||
var closeErr websocket.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
|
||||
}
|
||||
|
||||
func assertLog(t *testing.T, stage, output string, id int64, msg []byte) {
|
||||
t.Helper()
|
||||
var log codersdk.ProvisionerJobLog
|
||||
err := json.Unmarshal(msg, &log)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, stage, log.Stage)
|
||||
assert.Equal(t, output, log.Output)
|
||||
assert.Equal(t, id, log.ID)
|
||||
}
|
||||
|
||||
type logsAfterMatcher struct {
|
||||
params database.GetProvisionerLogsAfterIDParams
|
||||
}
|
||||
|
||||
func (m *logsAfterMatcher) Matches(x interface{}) bool {
|
||||
p, ok := x.(database.GetProvisionerLogsAfterIDParams)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m.params == p
|
||||
}
|
||||
|
||||
func (m *logsAfterMatcher) String() string {
|
||||
return fmt.Sprintf("%+v", m.params)
|
||||
}
|
||||
|
||||
func matchesJobAfter(jobID uuid.UUID, after int64) gomock.Matcher {
|
||||
return &logsAfterMatcher{
|
||||
params: database.GetProvisionerLogsAfterIDParams{
|
||||
JobID: jobID,
|
||||
CreatedAfter: after,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -132,15 +132,20 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
|
|||
}
|
||||
logs := make(chan ProvisionerJobLog)
|
||||
closed := make(chan struct{})
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
|
||||
decoder := json.NewDecoder(wsNetConn)
|
||||
go func() {
|
||||
defer close(closed)
|
||||
defer close(logs)
|
||||
defer conn.Close(websocket.StatusGoingAway, "")
|
||||
var log ProvisionerJobLog
|
||||
for {
|
||||
err = decoder.Decode(&log)
|
||||
msgType, msg, err := conn.Read(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if msgType != websocket.MessageText {
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(msg, &log)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -152,7 +157,6 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
|
|||
}
|
||||
}()
|
||||
return logs, closeFunc(func() error {
|
||||
_ = wsNetConn.Close()
|
||||
<-closed
|
||||
return nil
|
||||
}), nil
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package provisionersdk
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ProvisionerJobLogsNotifyMessage is the JSON payload sent on a job's log
// notification channel to tell subscribers that the job's logs have changed.
type ProvisionerJobLogsNotifyMessage struct {
	// CreatedAfter is a log ID cursor; it is always present in the encoded
	// message, even when zero.
	CreatedAfter int64 `json:"created_after"`
	// EndOfLogs marks the final notification for a job. It is omitted from
	// the encoded message when false.
	EndOfLogs bool `json:"end_of_logs,omitempty"`
}
|
||||
|
||||
// ProvisionerJobLogsNotifyChannel is the PostgreSQL NOTIFY channel
|
||||
// to publish updates to job logs on.
|
||||
func ProvisionerJobLogsNotifyChannel(jobID uuid.UUID) string {
|
||||
return fmt.Sprintf("provisioner-log-logs:%s", jobID)
|
||||
}
|
Loading…
Reference in New Issue