fix: stream provisioner logs (#7712)

* stream provisioner logs

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix imports

Signed-off-by: Spike Curtis <spike@coder.com>

* Better logging, naming, arg order

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis 2023-05-31 10:15:58 +04:00 committed by GitHub
parent 583b777251
commit 7c3dbbbe93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 546 additions and 231 deletions

View File

@ -537,13 +537,13 @@ func (server *Server) UpdateJob(ctx context.Context, request *proto.UpdateJobReq
// everything from that point.
lowestID := logs[0].ID
server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
CreatedAfter: lowestID - 1,
})
if err != nil {
return nil, xerrors.Errorf("marshal: %w", err)
}
err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(parsedID), data)
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data)
if err != nil {
server.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err))
return nil, xerrors.Errorf("publish job log: %w", err)
@ -846,11 +846,11 @@ func (server *Server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*p
}
}
data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
if err != nil {
return nil, xerrors.Errorf("marshal job log: %w", err)
}
err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(jobID), data)
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
if err != nil {
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
return nil, xerrors.Errorf("publish end of job logs: %w", err)
@ -1236,11 +1236,11 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
reflect.TypeOf(completed.Type).String())
}
data, err := json.Marshal(ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
if err != nil {
return nil, xerrors.Errorf("marshal job log: %w", err)
}
err = server.Pubsub.Publish(ProvisionerJobLogsNotifyChannel(jobID), data)
err = server.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
if err != nil {
server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
return nil, xerrors.Errorf("publish end of job logs: %w", err)
@ -1704,19 +1704,6 @@ type TemplateVersionDryRunJob struct {
RichParameterValues []database.WorkspaceBuildParameter `json:"rich_parameter_values"`
}
// ProvisionerJobLogsNotifyMessage is the payload published on
// the provisioner job logs notify channel.
type ProvisionerJobLogsNotifyMessage struct {
CreatedAfter int64 `json:"created_after"`
EndOfLogs bool `json:"end_of_logs,omitempty"`
}
// ProvisionerJobLogsNotifyChannel is the PostgreSQL NOTIFY channel
// to publish updates to job logs on.
func ProvisionerJobLogsNotifyChannel(jobID uuid.UUID) string {
return fmt.Sprintf("provisioner-log-logs:%s", jobID)
}
func asVariableValues(templateVariables []database.TemplateVersionVariable) []*sdkproto.VariableValue {
var apiVariableValues []*sdkproto.VariableValue
for _, v := range templateVariables {

View File

@ -27,6 +27,7 @@ import (
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
)
@ -528,7 +529,7 @@ func TestUpdateJob(t *testing.T) {
published := make(chan struct{})
closeListener, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) {
closeListener, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job), func(_ context.Context, _ []byte) {
close(published)
})
require.NoError(t, err)
@ -776,7 +777,7 @@ func TestFailJob(t *testing.T) {
require.NoError(t, err)
defer closeWorkspaceSubscribe()
publishedLogs := make(chan struct{})
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
close(publishedLogs)
})
require.NoError(t, err)
@ -1082,7 +1083,7 @@ func TestCompleteJob(t *testing.T) {
require.NoError(t, err)
defer closeWorkspaceSubscribe()
publishedLogs := make(chan struct{})
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionerdserver.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
closeLogsSubscribe, err := srv.Pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), func(_ context.Context, _ []byte) {
close(publishedLogs)
})
require.NoError(t, err)

View File

@ -5,22 +5,23 @@ import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sort"
"strconv"
"github.com/google/uuid"
"go.uber.org/atomic"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/db2sdk"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
)
// Returns provisioner logs based on query parameters.
@ -32,7 +33,6 @@ import (
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
var (
ctx = r.Context()
actor, _ = dbauthz.ActorFromContext(ctx)
logger = api.Logger.With(slog.F("job_id", job.ID))
follow = r.URL.Query().Has("follow")
afterRaw = r.URL.Query().Get("after")
@ -55,129 +55,16 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
}
if !follow {
logs, err := api.Database.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
JobID: job.ID,
CreatedAfter: after,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner logs.",
Detail: err.Error(),
})
return
}
if logs == nil {
logs = []database.ProvisionerJobLog{}
}
logger.Debug(ctx, "Finished non-follow job logs")
httpapi.Write(ctx, rw, http.StatusOK, convertProvisionerJobLogs(logs))
fetchAndWriteLogs(ctx, logger, api.Database, job.ID, after, rw)
return
}
// if we are following logs, start the subscription before we query the database, so that we don't miss any logs
// between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track
// of processed IDs.
var bufferedLogs <-chan *database.ProvisionerJobLog
if follow {
bl, closeFollow, err := api.followProvisionerJobLogs(actor, job.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error watching provisioner logs.",
Detail: err.Error(),
})
return
}
defer closeFollow()
bufferedLogs = bl
}
logs, err := api.Database.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
JobID: job.ID,
CreatedAfter: after,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner logs.",
Detail: err.Error(),
})
return
}
if logs == nil {
logs = []database.ProvisionerJobLog{}
}
follower := newLogFollower(ctx, logger, api.Database, api.Pubsub, rw, r, job, after)
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
go httpapi.Heartbeat(ctx, conn)
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
defer wsNetConn.Close() // Also closes conn.
logIdsDone := make(map[int64]bool)
// The Go stdlib JSON encoder appends a newline character after message write.
encoder := json.NewEncoder(wsNetConn)
for _, provisionerJobLog := range logs {
logIdsDone[provisionerJobLog.ID] = true
err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
if err != nil {
return
}
}
job, err = api.Database.GetProvisionerJobByID(ctx, job.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
Detail: err.Error(),
})
return
}
if job.CompletedAt.Valid {
// job was complete before we queried the database for historical logs
return
}
for {
select {
case <-ctx.Done():
logger.Debug(context.Background(), "job logs context canceled")
return
case log, ok := <-bufferedLogs:
// A nil log is sent when complete!
if !ok || log == nil {
logger.Debug(context.Background(), "reached the end of published logs")
return
}
if logIdsDone[log.ID] {
logger.Debug(ctx, "subscribe duplicated log",
slog.F("stage", log.Stage))
} else {
logger.Debug(ctx, "subscribe encoding log",
slog.F("stage", log.Stage))
err = encoder.Encode(convertProvisionerJobLog(*log))
if err != nil {
return
}
}
}
}
follower.follow()
}
func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
@ -334,98 +221,225 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
return job
}
func provisionerJobLogsChannel(jobID uuid.UUID) string {
return fmt.Sprintf("provisioner-log-logs:%s", jobID)
// fetchAndWriteLogs queries all provisioner job logs for jobID with IDs
// greater than `after` and writes them to rw as a single JSON response.
// A sql.ErrNoRows result is treated identically to an empty log set, and a
// nil slice is normalized to an empty one so the client always receives a
// JSON array, never null.
func fetchAndWriteLogs(ctx context.Context, logger slog.Logger, db database.Store, jobID uuid.UUID, after int64, rw http.ResponseWriter) {
	params := database.GetProvisionerLogsAfterIDParams{
		JobID:        jobID,
		CreatedAfter: after,
	}
	jobLogs, err := db.GetProvisionerLogsAfterID(ctx, params)
	switch {
	case err == nil, errors.Is(err, sql.ErrNoRows):
		// no rows simply means there are no logs yet; not an error
	default:
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching provisioner logs.",
			Detail:  err.Error(),
		})
		return
	}
	if jobLogs == nil {
		// normalize to an empty slice so we encode [] instead of null
		jobLogs = []database.ProvisionerJobLog{}
	}
	logger.Debug(ctx, "Finished non-follow job logs")
	httpapi.Write(ctx, rw, http.StatusOK, convertProvisionerJobLogs(jobLogs))
}
// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel
type provisionerJobLogsMessage struct {
CreatedAfter int64 `json:"created_after"`
EndOfLogs bool `json:"end_of_logs,omitempty"`
// jobIsComplete reports whether the provisioner job has reached a terminal
// state (canceled, failed, or succeeded). Pending, canceling, and running
// jobs are not complete. An unrecognized status is logged and conservatively
// treated as not complete.
func jobIsComplete(logger slog.Logger, job database.ProvisionerJob) bool {
	switch status := db2sdk.ProvisionerJobStatus(job); status {
	case codersdk.ProvisionerJobCanceled,
		codersdk.ProvisionerJobFailed,
		codersdk.ProvisionerJobSucceeded:
		return true
	case codersdk.ProvisionerJobPending,
		codersdk.ProvisionerJobCanceling,
		codersdk.ProvisionerJobRunning:
		return false
	default:
		logger.Error(context.Background(),
			"unknown status",
			slog.F("job_id", job.ID), slog.F("status", status))
		return false
	}
}
func (api *API) followProvisionerJobLogs(actor rbac.Subject, jobID uuid.UUID) (<-chan *database.ProvisionerJobLog, func(), error) {
logger := api.Logger.With(slog.F("job_id", jobID))
type logFollower struct {
ctx context.Context
logger slog.Logger
db database.Store
pubsub database.Pubsub
r *http.Request
rw http.ResponseWriter
conn *websocket.Conn
var (
// With debug logging enabled length = 128 is insufficient
bufferedLogs = make(chan *database.ProvisionerJobLog, 1024)
endOfLogs atomic.Bool
lastSentLogID atomic.Int64
)
jobID uuid.UUID
after int64
complete bool
notifications chan provisionersdk.ProvisionerJobLogsNotifyMessage
errors chan error
}
sendLog := func(log *database.ProvisionerJobLog) {
// newLogFollower constructs a logFollower that streams logs for the given
// job over a websocket upgraded from rw/r. `after` is the log-ID cursor:
// only logs with IDs greater than it will be sent. The job's completeness
// is snapshotted here via jobIsComplete; follow() re-checks it after
// subscribing to close the race between this snapshot and the subscription
// start. The notifications and errors channels are unbuffered: the pubsub
// listener hands each message directly to follow().
func newLogFollower(
	ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub,
	rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob, after int64,
) *logFollower {
	return &logFollower{
		ctx:           ctx,
		logger:        logger,
		db:            db,
		pubsub:        pubsub,
		r:             r,
		rw:            rw,
		jobID:         job.ID,
		after:         after,
		complete:      jobIsComplete(logger, job),
		notifications: make(chan provisionersdk.ProvisionerJobLogsNotifyMessage),
		errors:        make(chan error),
	}
}
func (f *logFollower) follow() {
// note that we only need to subscribe to updates if the job is not yet
// complete.
if !f.complete {
subCancel, err := f.pubsub.SubscribeWithErr(
provisionersdk.ProvisionerJobLogsNotifyChannel(f.jobID),
f.listener,
)
if err != nil {
httpapi.Write(f.ctx, f.rw, http.StatusInternalServerError, codersdk.Response{
Message: "failed to subscribe to job updates",
Detail: err.Error(),
})
return
}
defer subCancel()
// we were provided `complete` prior to starting this subscription, so
// we also need to check whether the job is now complete, in case the
// job completed between the last time we queried the job and the start
// of the subscription. If the job completes after this, we will get
// a notification on the subscription.
job, err := f.db.GetProvisionerJobByID(f.ctx, f.jobID)
if err != nil {
httpapi.Write(f.ctx, f.rw, http.StatusInternalServerError, codersdk.Response{
Message: "failed to query job",
Detail: err.Error(),
})
return
}
f.complete = jobIsComplete(f.logger, job)
f.logger.Debug(f.ctx, "queried job after subscribe", slog.F("complete", f.complete))
}
var err error
f.conn, err = websocket.Accept(f.rw, f.r, nil)
if err != nil {
httpapi.Write(f.ctx, f.rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
defer f.conn.Close(websocket.StatusNormalClosure, "done")
go httpapi.Heartbeat(f.ctx, f.conn)
// query for logs once right away, so we can get historical data from before
// subscription
if err := f.query(); err != nil {
if f.ctx.Err() == nil && !xerrors.Is(err, io.EOF) {
// neither context expiry, nor EOF, close and log
f.logger.Error(f.ctx, "failed to query logs", slog.Error(err))
err = f.conn.Close(websocket.StatusInternalError, err.Error())
if err != nil {
f.logger.Warn(f.ctx, "failed to close webscoket", slog.Error(err))
}
}
return
}
// no need to wait if the job is done
if f.complete {
return
}
for {
select {
case bufferedLogs <- log:
logger.Debug(context.Background(), "subscribe buffered log", slog.F("stage", log.Stage))
lastSentLogID.Store(log.ID)
default:
// If this overflows users could miss logs streaming. This can happen
// we get a lot of logs and consumer isn't keeping up. We don't want to block the pubsub,
// so just drop them.
logger.Warn(context.Background(), "provisioner job log overflowing channel")
case err := <-f.errors:
// we've dropped at least one notification. This can happen if we
// lose database connectivity. We don't know whether the job is
// now complete since we could have missed the end of logs message.
// We could soldier on and retry, but loss of database connectivity
// is fairly serious, so instead just 500 and bail out. Client
// can retry and hopefully find a healthier node.
f.logger.Error(f.ctx, "dropped or corrupted notification", slog.Error(err))
err = f.conn.Close(websocket.StatusInternalError, err.Error())
if err != nil {
f.logger.Warn(f.ctx, "failed to close webscoket", slog.Error(err))
}
return
case <-f.ctx.Done():
// client disconnect
return
case n := <-f.notifications:
if n.EndOfLogs {
// safe to return here because we started the subscription,
// and then queried at least once, so we will have already
// gotten all logs prior to the start of our subscription.
return
}
err = f.query()
if err != nil {
if f.ctx.Err() == nil && !xerrors.Is(err, io.EOF) {
// neither context expiry, nor EOF, close and log
f.logger.Error(f.ctx, "failed to query logs", slog.Error(err))
err = f.conn.Close(websocket.StatusInternalError, err.Error())
if err != nil {
f.logger.Warn(f.ctx, "failed to close webscoket", slog.Error(err))
}
}
return
}
}
}
closeSubscribe, err := api.Pubsub.Subscribe(
provisionerJobLogsChannel(jobID),
func(ctx context.Context, message []byte) {
if endOfLogs.Load() {
return
}
jlMsg := provisionerJobLogsMessage{}
err := json.Unmarshal(message, &jlMsg)
if err != nil {
logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err))
return
}
// CreatedAfter is sent when logs are streaming!
if jlMsg.CreatedAfter != 0 {
logs, err := api.Database.GetProvisionerLogsAfterID(dbauthz.As(ctx, actor), database.GetProvisionerLogsAfterIDParams{
JobID: jobID,
CreatedAfter: jlMsg.CreatedAfter,
})
if err != nil {
logger.Warn(ctx, "get provisioner logs", slog.Error(err))
return
}
for _, log := range logs {
if endOfLogs.Load() {
// An end of logs message came in while we were fetching
// logs or processing them!
return
}
log := log
sendLog(&log)
}
}
// EndOfLogs is sent when logs are done streaming.
// We don't want to end the stream until we've sent all the logs,
// so we fetch logs after the last ID we've seen and send them!
if jlMsg.EndOfLogs {
endOfLogs.Store(true)
logs, err := api.Database.GetProvisionerLogsAfterID(dbauthz.As(ctx, actor), database.GetProvisionerLogsAfterIDParams{
JobID: jobID,
CreatedAfter: lastSentLogID.Load(),
})
if err != nil {
logger.Warn(ctx, "get provisioner logs", slog.Error(err))
return
}
for _, log := range logs {
log := log
sendLog(&log)
}
logger.Debug(ctx, "got End of Logs")
bufferedLogs <- nil
}
},
)
if err != nil {
return nil, nil, err
}
// We don't need to close the bufferedLogs channel because it will be garbage collected!
return bufferedLogs, closeSubscribe, nil
}
// listener is the pubsub callback for job-log notifications. Pubsub-level
// errors and payloads that fail to decode are forwarded on f.errors;
// well-formed notifications are forwarded on f.notifications for follow()
// to act on. Sends block until follow() receives them.
func (f *logFollower) listener(_ context.Context, message []byte, err error) {
	if err == nil {
		var msg provisionersdk.ProvisionerJobLogsNotifyMessage
		err = json.Unmarshal(message, &msg)
		if err == nil {
			f.notifications <- msg
			return
		}
	}
	f.errors <- err
}
// query fetches any job logs newer than f.after from the database, writes
// each one to the websocket as a JSON text message, and advances f.after to
// the ID of the last log sent so subsequent calls resume where this one
// stopped. A sql.ErrNoRows result is treated as an empty set, not an error.
func (f *logFollower) query() error {
	f.logger.Debug(f.ctx, "querying logs", slog.F("after", f.after))
	params := database.GetProvisionerLogsAfterIDParams{
		JobID:        f.jobID,
		CreatedAfter: f.after,
	}
	jobLogs, err := f.db.GetProvisionerLogsAfterID(f.ctx, params)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return xerrors.Errorf("error fetching logs: %w", err)
	}
	for i := range jobLogs {
		payload, err := json.Marshal(convertProvisionerJobLog(jobLogs[i]))
		if err != nil {
			return xerrors.Errorf("error marshaling log: %w", err)
		}
		if err := f.conn.Write(f.ctx, websocket.MessageText, payload); err != nil {
			return xerrors.Errorf("error writing to websocket: %w", err)
		}
		// advance the cursor only after a successful write, so a failed
		// write is retried from the same point on the next query
		f.after = jobLogs[i].ID
		f.logger.Debug(f.ctx, "wrote log to websocket", slog.F("id", jobLogs[i].ID))
	}
	return nil
}

View File

@ -1,13 +1,28 @@
package coderd
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nhooyr.io/websocket"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbmock"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/testutil"
)
func TestConvertProvisionerJob_Unit(t *testing.T) {
@ -115,3 +130,277 @@ func TestConvertProvisionerJob_Unit(t *testing.T) {
})
}
}
// Test_logFollower_completeBeforeFollow verifies that when the job is
// already complete before follow() starts, the follower skips subscribing,
// streams the historical logs after the cursor (10), and then closes the
// websocket with a normal closure.
func Test_logFollower_completeBeforeFollow(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	logger := slogtest.Make(t, nil)
	ctrl := gomock.NewController(t)
	mDB := dbmock.NewMockStore(ctrl)
	pubsub := database.NewPubsubInMemory()
	now := database.Now()
	// CompletedAt is valid, so the follower sees the job as complete from
	// the start and never needs to subscribe or re-query the job row.
	job := database.ProvisionerJob{
		ID:        uuid.New(),
		CreatedAt: now.Add(-10 * time.Second),
		UpdatedAt: now.Add(-10 * time.Second),
		StartedAt: sql.NullTime{
			Time:  now.Add(-10 * time.Second),
			Valid: true,
		},
		CompletedAt: sql.NullTime{
			Time:  now.Add(-time.Second),
			Valid: true,
		},
		Error: sql.NullString{},
	}

	// we need an HTTP server to get a websocket
	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 10)
		uut.follow()
	}))
	defer srv.Close()

	// return some historical logs
	mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 10)).
		Times(1).
		Return(
			[]database.ProvisionerJobLog{
				{Stage: "One", Output: "One", ID: 11},
				{Stage: "One", Output: "Two", ID: 12},
			},
			nil,
		)

	// nolint: bodyclose
	client, _, err := websocket.Dial(ctx, srv.URL, nil)
	require.NoError(t, err)
	// both historical logs arrive in order as text messages
	mt, msg, err := client.Read(ctx)
	require.NoError(t, err)
	assert.Equal(t, websocket.MessageText, mt)
	assertLog(t, "One", "One", 11, msg)
	mt, msg, err = client.Read(ctx)
	require.NoError(t, err)
	assert.Equal(t, websocket.MessageText, mt)
	assertLog(t, "One", "Two", 12, msg)

	// server should now close
	_, _, err = client.Read(ctx)
	var closeErr websocket.CloseError
	require.ErrorAs(t, err, &closeErr)
	assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
}
// Test_logFollower_completeBeforeSubscribe verifies the race-closing
// re-check: the job is incomplete when the follower is constructed, but the
// post-subscription GetProvisionerJobByID query returns a completed job, so
// the follower streams the historical logs and closes normally without
// waiting for an end-of-logs notification.
func Test_logFollower_completeBeforeSubscribe(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	logger := slogtest.Make(t, nil)
	ctrl := gomock.NewController(t)
	mDB := dbmock.NewMockStore(ctrl)
	pubsub := database.NewPubsubInMemory()
	now := database.Now()
	// CompletedAt is not valid here, so the follower initially believes the
	// job is still running and subscribes for updates.
	job := database.ProvisionerJob{
		ID:        uuid.New(),
		CreatedAt: now.Add(-10 * time.Second),
		UpdatedAt: now.Add(-10 * time.Second),
		StartedAt: sql.NullTime{
			Time:  now.Add(-10 * time.Second),
			Valid: true,
		},
		CanceledAt:  sql.NullTime{},
		CompletedAt: sql.NullTime{},
		Error:       sql.NullString{},
	}

	// we need an HTTP server to get a websocket
	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
		uut.follow()
	}))
	defer srv.Close()

	// job was incomplete when we create the logFollower, but is complete as soon
	// as it queries again.
	mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), job.ID).Times(1).Return(
		database.ProvisionerJob{
			ID:        job.ID,
			CreatedAt: job.CreatedAt,
			UpdatedAt: job.UpdatedAt,
			StartedAt: job.StartedAt,
			CompletedAt: sql.NullTime{
				Time:  now.Add(-time.Millisecond),
				Valid: true,
			},
		},
		nil,
	)

	// return some historical logs
	mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 0)).
		Times(1).
		Return(
			[]database.ProvisionerJobLog{
				{Stage: "One", Output: "One", ID: 1},
				{Stage: "One", Output: "Two", ID: 2},
			},
			nil,
		)

	// nolint: bodyclose
	client, _, err := websocket.Dial(ctx, srv.URL, nil)
	require.NoError(t, err)
	// historical logs arrive in order as text messages
	mt, msg, err := client.Read(ctx)
	require.NoError(t, err)
	assert.Equal(t, websocket.MessageText, mt)
	assertLog(t, "One", "One", 1, msg)
	mt, msg, err = client.Read(ctx)
	require.NoError(t, err)
	assert.Equal(t, websocket.MessageText, mt)
	assertLog(t, "One", "Two", 2, msg)

	// server should now close
	_, _, err = client.Read(ctx)
	var closeErr websocket.CloseError
	require.ErrorAs(t, err, &closeErr)
	assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
}
// Test_logFollower_EndOfLogs exercises the full streaming flow for a
// running job: the follower sends the historical logs, a CreatedAfter
// notification "kick" triggers a second query that streams two more logs,
// and an EndOfLogs notification makes it close the websocket normally.
func Test_logFollower_EndOfLogs(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	logger := slogtest.Make(t, nil)
	ctrl := gomock.NewController(t)
	mDB := dbmock.NewMockStore(ctrl)
	pubsub := database.NewPubsubInMemory()
	now := database.Now()
	// job has started but has no CanceledAt/CompletedAt, so it is running
	job := database.ProvisionerJob{
		ID:        uuid.New(),
		CreatedAt: now.Add(-10 * time.Second),
		UpdatedAt: now.Add(-10 * time.Second),
		StartedAt: sql.NullTime{
			Time:  now.Add(-10 * time.Second),
			Valid: true,
		},
		CanceledAt:  sql.NullTime{},
		CompletedAt: sql.NullTime{},
		Error:       sql.NullString{},
	}

	// we need an HTTP server to get a websocket
	srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
		uut.follow()
	}))
	defer srv.Close()

	// job was incomplete when we create the logFollower, and still incomplete when it queries
	mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), job.ID).Times(1).Return(job, nil)

	// return some historical logs
	q0 := mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 0)).
		Times(1).
		Return(
			[]database.ProvisionerJobLog{
				{Stage: "One", Output: "One", ID: 1},
				{Stage: "One", Output: "Two", ID: 2},
			},
			nil,
		)
	// return some logs from a kick. The After(q0) ordering constraint and
	// the after==2 cursor together assert the follower resumed from the
	// last ID it already sent.
	mDB.EXPECT().GetProvisionerLogsAfterID(gomock.Any(), matchesJobAfter(job.ID, 2)).
		After(q0).
		Times(1).
		Return(
			[]database.ProvisionerJobLog{
				{Stage: "One", Output: "Three", ID: 3},
				{Stage: "Two", Output: "One", ID: 4},
			},
			nil,
		)

	// nolint: bodyclose
	client, _, err := websocket.Dial(ctx, srv.URL, nil)
	require.NoError(t, err)
	mt, msg, err := client.Read(ctx)
	require.NoError(t, err)
	assert.Equal(t, websocket.MessageText, mt)
	assertLog(t, "One", "One", 1, msg)
	mt, msg, err = client.Read(ctx)
	require.NoError(t, err)
	assert.Equal(t, websocket.MessageText, mt)
	assertLog(t, "One", "Two", 2, msg)

	// send in the kick so follower will query a second time
	n := provisionersdk.ProvisionerJobLogsNotifyMessage{
		CreatedAfter: 2,
	}
	msg, err = json.Marshal(&n)
	require.NoError(t, err)
	err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
	require.NoError(t, err)
	mt, msg, err = client.Read(ctx)
	require.NoError(t, err)
	assert.Equal(t, websocket.MessageText, mt)
	assertLog(t, "One", "Three", 3, msg)
	mt, msg, err = client.Read(ctx)
	require.NoError(t, err)
	assert.Equal(t, websocket.MessageText, mt)
	assertLog(t, "Two", "One", 4, msg)

	// send EndOfLogs
	n.EndOfLogs = true
	n.CreatedAfter = 0
	msg, err = json.Marshal(&n)
	require.NoError(t, err)
	err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
	require.NoError(t, err)

	// server should now close
	_, _, err = client.Read(ctx)
	var closeErr websocket.CloseError
	require.ErrorAs(t, err, &closeErr)
	assert.Equal(t, websocket.StatusNormalClosure, closeErr.Code)
}
// assertLog decodes msg as a codersdk.ProvisionerJobLog and asserts that its
// stage, output, and ID match the expected values.
func assertLog(t *testing.T, stage, output string, id int64, msg []byte) {
	t.Helper()
	var got codersdk.ProvisionerJobLog
	require.NoError(t, json.Unmarshal(msg, &got))
	assert.Equal(t, stage, got.Stage)
	assert.Equal(t, output, got.Output)
	assert.Equal(t, id, got.ID)
}
// logsAfterMatcher is a gomock matcher that matches a
// GetProvisionerLogsAfterID call whose params are exactly equal to `params`.
type logsAfterMatcher struct {
	// params is the exact expected argument, compared with ==
	params database.GetProvisionerLogsAfterIDParams
}
// Matches implements gomock.Matcher. It returns true only when x is a
// database.GetProvisionerLogsAfterIDParams equal to the expected params.
func (m *logsAfterMatcher) Matches(x interface{}) bool {
	if p, ok := x.(database.GetProvisionerLogsAfterIDParams); ok {
		return p == m.params
	}
	return false
}
// String implements gomock.Matcher, describing the expected params for
// failure messages.
func (m *logsAfterMatcher) String() string {
	return fmt.Sprintf("%+v", m.params)
}
// matchesJobAfter builds a gomock matcher for a GetProvisionerLogsAfterID
// call with the given job ID and CreatedAfter cursor.
func matchesJobAfter(jobID uuid.UUID, after int64) gomock.Matcher {
	return &logsAfterMatcher{
		params: database.GetProvisionerLogsAfterIDParams{
			JobID:        jobID,
			CreatedAfter: after,
		},
	}
}

View File

@ -132,15 +132,20 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
}
logs := make(chan ProvisionerJobLog)
closed := make(chan struct{})
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
decoder := json.NewDecoder(wsNetConn)
go func() {
defer close(closed)
defer close(logs)
defer conn.Close(websocket.StatusGoingAway, "")
var log ProvisionerJobLog
for {
err = decoder.Decode(&log)
msgType, msg, err := conn.Read(ctx)
if err != nil {
return
}
if msgType != websocket.MessageText {
return
}
err = json.Unmarshal(msg, &log)
if err != nil {
return
}
@ -152,7 +157,6 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
}
}()
return logs, closeFunc(func() error {
_ = wsNetConn.Close()
<-closed
return nil
}), nil

20
provisionersdk/logs.go Normal file
View File

@ -0,0 +1,20 @@
package provisionersdk
import (
"fmt"
"github.com/google/uuid"
)
// ProvisionerJobLogsNotifyMessage is the payload published on
// the provisioner job logs notify channel.
type ProvisionerJobLogsNotifyMessage struct {
	// CreatedAfter, when non-zero, tells subscribers that new logs exist
	// with IDs greater than this value.
	CreatedAfter int64 `json:"created_after"`
	// EndOfLogs, when true, signals that no further logs will be published
	// for this job.
	EndOfLogs bool `json:"end_of_logs,omitempty"`
}
// ProvisionerJobLogsNotifyChannel is the PostgreSQL NOTIFY channel
// to publish updates to job logs on.
//
// NOTE(review): the doubled "provisioner-log-logs" prefix looks like a
// historical typo, but it is the wire-level channel name shared by
// publishers and subscribers — it must not be "fixed" or mixed-version
// deployments would stop matching each other's notifications.
func ProvisionerJobLogsNotifyChannel(jobID uuid.UUID) string {
	return fmt.Sprintf("provisioner-log-logs:%s", jobID)
}