chore: add agentapi tests (#11269)

This commit is contained in:
Dean Sheather 2024-01-26 17:04:19 +10:00 committed by GitHub
parent 541154b74b
commit 29707099d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 2504 additions and 157 deletions

View File

@ -41,13 +41,14 @@ func ActivityBumpWorkspace(ctx context.Context, log slog.Logger, db database.Sto
// low priority operations fail first.
ctx, cancel := context.WithTimeout(ctx, time.Second*15)
defer cancel()
if err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{
err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{
NextAutostart: nextAutostart.UTC(),
WorkspaceID: workspaceID,
}); err != nil {
})
if err != nil {
if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
// Bump will fail if the context is canceled, but this is ok.
log.Error(ctx, "bump failed", slog.Error(err),
log.Error(ctx, "activity bump failed", slog.Error(err),
slog.F("workspace_id", workspaceID),
)
}

View File

@ -17,7 +17,6 @@ import (
"cdr.dev/slog"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/batchstats"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth"
@ -61,19 +60,17 @@ type Options struct {
DerpMapFn func() *tailcfg.DERPMap
TailnetCoordinator *atomic.Pointer[tailnet.Coordinator]
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
StatsBatcher *batchstats.Batcher
StatsBatcher StatsBatcher
PublishWorkspaceUpdateFn func(ctx context.Context, workspaceID uuid.UUID)
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
AccessURL *url.URL
AppHostname string
AgentInactiveDisconnectTimeout time.Duration
AgentFallbackTroubleshootingURL string
AgentStatsRefreshInterval time.Duration
DisableDirectConnections bool
DerpForceWebSockets bool
DerpMapUpdateFrequency time.Duration
ExternalAuthConfigs []*externalauth.Config
AccessURL *url.URL
AppHostname string
AgentStatsRefreshInterval time.Duration
DisableDirectConnections bool
DerpForceWebSockets bool
DerpMapUpdateFrequency time.Duration
ExternalAuthConfigs []*externalauth.Config
// Optional:
// WorkspaceID avoids a future lookup to find the workspace ID by setting
@ -90,17 +87,14 @@ func New(opts Options) *API {
}
api.ManifestAPI = &ManifestAPI{
AccessURL: opts.AccessURL,
AppHostname: opts.AppHostname,
AgentInactiveDisconnectTimeout: opts.AgentInactiveDisconnectTimeout,
AgentFallbackTroubleshootingURL: opts.AgentFallbackTroubleshootingURL,
ExternalAuthConfigs: opts.ExternalAuthConfigs,
DisableDirectConnections: opts.DisableDirectConnections,
DerpForceWebSockets: opts.DerpForceWebSockets,
AgentFn: api.agent,
Database: opts.Database,
DerpMapFn: opts.DerpMapFn,
TailnetCoordinator: opts.TailnetCoordinator,
AccessURL: opts.AccessURL,
AppHostname: opts.AppHostname,
ExternalAuthConfigs: opts.ExternalAuthConfigs,
DisableDirectConnections: opts.DisableDirectConnections,
DerpForceWebSockets: opts.DerpForceWebSockets,
AgentFn: api.agent,
Database: opts.Database,
DerpMapFn: opts.DerpMapFn,
}
api.ServiceBannerAPI = &ServiceBannerAPI{
@ -214,20 +208,15 @@ func (a *API) workspaceID(ctx context.Context, agent *database.WorkspaceAgent) (
agent = &agnt
}
resource, err := a.opts.Database.GetWorkspaceResourceByID(ctx, agent.ResourceID)
getWorkspaceAgentByIDRow, err := a.opts.Database.GetWorkspaceByAgentID(ctx, agent.ID)
if err != nil {
return uuid.Nil, xerrors.Errorf("get workspace agent resource by id %q: %w", agent.ResourceID, err)
}
build, err := a.opts.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
if err != nil {
return uuid.Nil, xerrors.Errorf("get workspace build by job id %q: %w", resource.JobID, err)
return uuid.Nil, xerrors.Errorf("get workspace by agent id %q: %w", agent.ID, err)
}
a.mu.Lock()
a.cachedWorkspaceID = build.WorkspaceID
a.cachedWorkspaceID = getWorkspaceAgentByIDRow.Workspace.ID
a.mu.Unlock()
return build.WorkspaceID, nil
return getWorkspaceAgentByIDRow.Workspace.ID, nil
}
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent) error {

View File

@ -90,9 +90,11 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}
}
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
}
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}

View File

@ -0,0 +1,252 @@
package agentapi_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"cdr.dev/slog/sloggers/slogtest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
)
func TestBatchUpdateAppHealths(t *testing.T) {
t.Parallel()
var (
agent = database.WorkspaceAgent{
ID: uuid.New(),
}
app1 = database.WorkspaceApp{
ID: uuid.New(),
AgentID: agent.ID,
Slug: "code-server-1",
DisplayName: "code-server 1",
HealthcheckUrl: "http://localhost:3000",
Health: database.WorkspaceAppHealthInitializing,
}
app2 = database.WorkspaceApp{
ID: uuid.New(),
AgentID: agent.ID,
Slug: "code-server-2",
DisplayName: "code-server 2",
HealthcheckUrl: "http://localhost:3001",
Health: database.WorkspaceAppHealthHealthy,
}
)
t.Run("OK", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
dbM.EXPECT().UpdateWorkspaceAppHealthByID(gomock.Any(), database.UpdateWorkspaceAppHealthByIDParams{
ID: app1.ID,
Health: database.WorkspaceAppHealthHealthy,
}).Return(nil)
dbM.EXPECT().UpdateWorkspaceAppHealthByID(gomock.Any(), database.UpdateWorkspaceAppHealthByIDParams{
ID: app2.ID,
Health: database.WorkspaceAppHealthUnhealthy,
}).Return(nil)
publishCalled := false
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishCalled = true
return nil
},
}
// Set one to healthy, set another to unhealthy.
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
{
Id: app1.ID[:],
Health: agentproto.AppHealth_HEALTHY,
},
{
Id: app2.ID[:],
Health: agentproto.AppHealth_UNHEALTHY,
},
},
})
require.NoError(t, err)
require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp)
require.True(t, publishCalled)
})
t.Run("Unchanged", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
publishCalled := false
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishCalled = true
return nil
},
}
// Set both to their current status, neither should be updated in the
// DB.
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
{
Id: app1.ID[:],
Health: agentproto.AppHealth_INITIALIZING,
},
{
Id: app2.ID[:],
Health: agentproto.AppHealth_HEALTHY,
},
},
})
require.NoError(t, err)
require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp)
require.False(t, publishCalled)
})
t.Run("Empty", func(t *testing.T) {
t.Parallel()
// No DB queries are made if there are no updates to process.
dbM := dbmock.NewMockStore(gomock.NewController(t))
publishCalled := false
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishCalled = true
return nil
},
}
// Do nothing.
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{},
})
require.NoError(t, err)
require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp)
require.False(t, publishCalled)
})
t.Run("AppNoHealthcheck", func(t *testing.T) {
t.Parallel()
app3 := database.WorkspaceApp{
ID: uuid.New(),
AgentID: agent.ID,
Slug: "code-server-3",
DisplayName: "code-server 3",
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: nil,
}
// Set app3 to healthy, should error.
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
{
Id: app3.ID[:],
Health: agentproto.AppHealth_HEALTHY,
},
},
})
require.Error(t, err)
require.ErrorContains(t, err, "does not have healthchecks enabled")
require.Nil(t, resp)
})
t.Run("UnknownApp", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: nil,
}
// Set an unknown app to healthy, should error.
id := uuid.New()
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
{
Id: id[:],
Health: agentproto.AppHealth_HEALTHY,
},
},
})
require.Error(t, err)
require.ErrorContains(t, err, "not found")
require.Nil(t, resp)
})
t.Run("InvalidHealth", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: nil,
}
// Set an unknown app to healthy, should error.
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
{
Id: app1.ID[:],
Health: -999,
},
},
})
require.Error(t, err)
require.ErrorContains(t, err, "unknown health status")
require.Nil(t, resp)
})
}

View File

@ -3,6 +3,7 @@ package agentapi
import (
"context"
"database/sql"
"time"
"github.com/google/uuid"
"golang.org/x/mod/semver"
@ -21,6 +22,15 @@ type LifecycleAPI struct {
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error
TimeNowFn func() time.Time // defaults to dbtime.Now()
}
func (a *LifecycleAPI) now() time.Time {
if a.TimeNowFn != nil {
return a.TimeNowFn()
}
return dbtime.Now()
}
func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
@ -68,7 +78,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
changedAt := req.Lifecycle.ChangedAt.AsTime()
if changedAt.IsZero() {
changedAt = dbtime.Now()
changedAt = a.now()
req.Lifecycle.ChangedAt = timestamppb.New(changedAt)
}
dbChangedAt := sql.NullTime{Time: changedAt, Valid: true}
@ -78,8 +88,13 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
switch lifecycleState {
case database.WorkspaceAgentLifecycleStateStarting:
startedAt = dbChangedAt
readyAt.Valid = false // This agent is re-starting, so it's not ready yet.
// This agent is (re)starting, so it's not ready yet.
readyAt.Time = time.Time{}
readyAt.Valid = false
case database.WorkspaceAgentLifecycleStateReady, database.WorkspaceAgentLifecycleStateStartError:
if !startedAt.Valid {
startedAt = dbChangedAt
}
readyAt = dbChangedAt
}
@ -97,9 +112,11 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
return nil, xerrors.Errorf("update workspace agent lifecycle state: %w", err)
}
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
}
return req.Lifecycle, nil

View File

@ -0,0 +1,461 @@
package agentapi_test
import (
"context"
"database/sql"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/types/known/timestamppb"
"cdr.dev/slog/sloggers/slogtest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
)
func TestUpdateLifecycle(t *testing.T) {
t.Parallel()
someTime, err := time.Parse(time.RFC3339, "2023-01-01T00:00:00Z")
require.NoError(t, err)
someTime = dbtime.Time(someTime)
now := dbtime.Now()
var (
workspaceID = uuid.New()
agentCreated = database.WorkspaceAgent{
ID: uuid.New(),
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
StartedAt: sql.NullTime{Valid: false},
ReadyAt: sql.NullTime{Valid: false},
}
agentStarting = database.WorkspaceAgent{
ID: uuid.New(),
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
StartedAt: sql.NullTime{Valid: true, Time: someTime},
ReadyAt: sql.NullTime{Valid: false},
}
)
t.Run("OKStarting", func(t *testing.T) {
t.Parallel()
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_STARTING,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: agentCreated.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
StartedAt: sql.NullTime{
Time: now,
Valid: true,
},
ReadyAt: sql.NullTime{Valid: false},
}).Return(nil)
publishCalled := false
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentCreated, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
publishCalled = true
return nil
},
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
require.True(t, publishCalled)
})
t.Run("OKReadying", func(t *testing.T) {
t.Parallel()
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: agentStarting.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: agentStarting.StartedAt,
ReadyAt: sql.NullTime{
Time: now,
Valid: true,
},
}).Return(nil)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentStarting, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
// Test that nil publish fn works.
PublishWorkspaceUpdateFn: nil,
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
})
// This test jumps from CREATING to READY, skipping STARTED. Both the
// StartedAt and ReadyAt fields should be set.
t.Run("OKStraightToReady", func(t *testing.T) {
t.Parallel()
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: agentCreated.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: sql.NullTime{
Time: now,
Valid: true,
},
ReadyAt: sql.NullTime{
Time: now,
Valid: true,
},
}).Return(nil)
publishCalled := false
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentCreated, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
publishCalled = true
return nil
},
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
require.True(t, publishCalled)
})
t.Run("NoTimeSpecified", func(t *testing.T) {
t.Parallel()
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
// Zero time
ChangedAt: timestamppb.New(time.Time{}),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
now := dbtime.Now()
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: agentCreated.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: sql.NullTime{
Time: now,
Valid: true,
},
ReadyAt: sql.NullTime{
Time: now,
Valid: true,
},
})
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentCreated, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: nil,
TimeNowFn: func() time.Time {
return now
},
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
})
t.Run("AllStates", func(t *testing.T) {
t.Parallel()
agent := database.WorkspaceAgent{
ID: uuid.New(),
LifecycleState: database.WorkspaceAgentLifecycleState(""),
StartedAt: sql.NullTime{Valid: false},
ReadyAt: sql.NullTime{Valid: false},
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
var publishCalled int64
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
atomic.AddInt64(&publishCalled, 1)
return nil
},
}
states := []agentproto.Lifecycle_State{
agentproto.Lifecycle_CREATED,
agentproto.Lifecycle_STARTING,
agentproto.Lifecycle_START_TIMEOUT,
agentproto.Lifecycle_START_ERROR,
agentproto.Lifecycle_READY,
agentproto.Lifecycle_SHUTTING_DOWN,
agentproto.Lifecycle_SHUTDOWN_TIMEOUT,
agentproto.Lifecycle_SHUTDOWN_ERROR,
agentproto.Lifecycle_OFF,
}
for i, state := range states {
t.Log("state", state)
// Use a time after the last state change to ensure ordering.
stateNow := now.Add(time.Hour * time.Duration(i))
lifecycle := &agentproto.Lifecycle{
State: state,
ChangedAt: timestamppb.New(stateNow),
}
expectedStartedAt := agent.StartedAt
expectedReadyAt := agent.ReadyAt
if state == agentproto.Lifecycle_STARTING {
expectedStartedAt = sql.NullTime{Valid: true, Time: stateNow}
}
if state == agentproto.Lifecycle_READY || state == agentproto.Lifecycle_START_ERROR {
expectedReadyAt = sql.NullTime{Valid: true, Time: stateNow}
}
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: agent.ID,
LifecycleState: database.WorkspaceAgentLifecycleState(strings.ToLower(state.String())),
StartedAt: expectedStartedAt,
ReadyAt: expectedReadyAt,
}).Times(1).Return(nil)
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
require.Equal(t, int64(i+1), atomic.LoadInt64(&publishCalled))
// For future iterations:
agent.StartedAt = expectedStartedAt
agent.ReadyAt = expectedReadyAt
}
})
t.Run("UnknownLifecycleState", func(t *testing.T) {
t.Parallel()
lifecycle := &agentproto.Lifecycle{
State: -999,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
publishCalled := false
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agentCreated, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
publishCalled = true
return nil
},
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.Error(t, err)
require.ErrorContains(t, err, "unknown lifecycle state")
require.Nil(t, resp)
require.False(t, publishCalled)
})
}
func TestUpdateStartup(t *testing.T) {
t.Parallel()
var (
workspaceID = uuid.New()
agent = database.WorkspaceAgent{
ID: uuid.New(),
}
)
t.Run("OK", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
// Not used by UpdateStartup.
PublishWorkspaceUpdateFn: nil,
}
startup := &agentproto.Startup{
Version: "v1.2.3",
ExpandedDirectory: "/path/to/expanded/dir",
Subsystems: []agentproto.Startup_Subsystem{
agentproto.Startup_ENVBOX,
agentproto.Startup_ENVBUILDER,
agentproto.Startup_EXECTRACE,
},
}
dbM.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), database.UpdateWorkspaceAgentStartupByIDParams{
ID: agent.ID,
Version: startup.Version,
ExpandedDirectory: startup.ExpandedDirectory,
Subsystems: []database.WorkspaceAgentSubsystem{
database.WorkspaceAgentSubsystemEnvbox,
database.WorkspaceAgentSubsystemEnvbuilder,
database.WorkspaceAgentSubsystemExectrace,
},
APIVersion: agentapi.AgentAPIVersionDRPC,
}).Return(nil)
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.NoError(t, err)
require.Equal(t, startup, resp)
})
t.Run("BadVersion", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
// Not used by UpdateStartup.
PublishWorkspaceUpdateFn: nil,
}
startup := &agentproto.Startup{
Version: "asdf",
ExpandedDirectory: "/path/to/expanded/dir",
Subsystems: []agentproto.Startup_Subsystem{},
}
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.Error(t, err)
require.ErrorContains(t, err, "invalid agent semver version")
require.Nil(t, resp)
})
t.Run("BadSubsystem", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
// Not used by UpdateStartup.
PublishWorkspaceUpdateFn: nil,
}
startup := &agentproto.Startup{
Version: "v1.2.3",
ExpandedDirectory: "/path/to/expanded/dir",
Subsystems: []agentproto.Startup_Subsystem{
agentproto.Startup_ENVBOX,
-999,
},
}
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.Error(t, err)
require.ErrorContains(t, err, "invalid agent subsystem")
require.Nil(t, resp)
})
}

View File

@ -2,6 +2,7 @@ package agentapi
import (
"context"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
@ -19,6 +20,15 @@ type LogsAPI struct {
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
TimeNowFn func() time.Time // defaults to dbtime.Now()
}
func (a *LogsAPI) now() time.Time {
if a.TimeNowFn != nil {
return a.TimeNowFn()
}
return dbtime.Now()
}
func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
@ -26,6 +36,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
if err != nil {
return nil, err
}
if workspaceAgent.LogsOverflowed {
return nil, xerrors.New("workspace agent logs overflowed")
}
if len(req.Logs) == 0 {
return &agentproto.BatchCreateLogsResponse{}, nil
@ -42,7 +55,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
// Use the external log source
externalSources, err := a.Database.InsertWorkspaceAgentLogSources(ctx, database.InsertWorkspaceAgentLogSourcesParams{
WorkspaceAgentID: workspaceAgent.ID,
CreatedAt: dbtime.Now(),
CreatedAt: a.now(),
ID: []uuid.UUID{agentsdk.ExternalLogSourceID},
DisplayName: []string{"External"},
Icon: []string{"/emojis/1f310.png"},
@ -88,7 +101,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
logs, err := a.Database.InsertWorkspaceAgentLogs(ctx, database.InsertWorkspaceAgentLogsParams{
AgentID: workspaceAgent.ID,
CreatedAt: dbtime.Now(),
CreatedAt: a.now(),
Output: output,
Level: level,
LogSourceID: logSourceID,
@ -98,9 +111,6 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
if !database.IsWorkspaceAgentLogsLimitError(err) {
return nil, xerrors.Errorf("insert workspace agent logs: %w", err)
}
if workspaceAgent.LogsOverflowed {
return nil, xerrors.New("workspace agent logs overflowed")
}
err := a.Database.UpdateWorkspaceAgentLogOverflowByID(ctx, database.UpdateWorkspaceAgentLogOverflowByIDParams{
ID: workspaceAgent.ID,
LogsOverflowed: true,
@ -112,21 +122,25 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
a.Log.Warn(ctx, "failed to update workspace agent log overflow", slog.Error(err))
}
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
}
return nil, xerrors.New("workspace agent log limit exceeded")
}
// Publish by the lowest log ID inserted so the log stream will fetch
// everything from that point.
lowestLogID := logs[0].ID
a.PublishWorkspaceAgentLogsUpdateFn(ctx, workspaceAgent.ID, agentsdk.LogsNotifyMessage{
CreatedAfter: lowestLogID - 1,
})
if a.PublishWorkspaceAgentLogsUpdateFn != nil {
lowestLogID := logs[0].ID
a.PublishWorkspaceAgentLogsUpdateFn(ctx, workspaceAgent.ID, agentsdk.LogsNotifyMessage{
CreatedAfter: lowestLogID - 1,
})
}
if workspaceAgent.LogsLength == 0 {
if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil {
// If these are the first logs being appended, we publish a UI update
// to notify the UI that logs are now available.
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)

View File

@ -0,0 +1,427 @@
package agentapi_test
import (
"context"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/types/known/timestamppb"
"cdr.dev/slog/sloggers/slogtest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func TestBatchCreateLogs(t *testing.T) {
t.Parallel()
var (
agent = database.WorkspaceAgent{
ID: uuid.New(),
}
logSource = database.WorkspaceAgentLogSource{
WorkspaceAgentID: agent.ID,
CreatedAt: dbtime.Now(),
ID: uuid.New(),
}
)
t.Run("OK", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
publishWorkspaceUpdateCalled := false
publishWorkspaceAgentLogsUpdateCalled := false
now := dbtime.Now()
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishWorkspaceUpdateCalled = true
return nil
},
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
publishWorkspaceAgentLogsUpdateCalled = true
// Check the message content, should be for -1 since the lowest
// log we inserted was 0.
assert.Equal(t, agentsdk.LogsNotifyMessage{CreatedAfter: -1}, msg)
},
TimeNowFn: func() time.Time { return now },
}
req := &agentproto.BatchCreateLogsRequest{
LogSourceId: logSource.ID[:],
Logs: []*agentproto.Log{
{
CreatedAt: timestamppb.New(now),
Level: agentproto.Log_TRACE,
Output: "log line 1",
},
{
CreatedAt: timestamppb.New(now.Add(time.Hour)),
Level: agentproto.Log_DEBUG,
Output: "log line 2",
},
{
CreatedAt: timestamppb.New(now.Add(2 * time.Hour)),
Level: agentproto.Log_INFO,
Output: "log line 3",
},
{
CreatedAt: timestamppb.New(now.Add(3 * time.Hour)),
Level: agentproto.Log_WARN,
Output: "log line 4",
},
{
CreatedAt: timestamppb.New(now.Add(4 * time.Hour)),
Level: agentproto.Log_ERROR,
Output: "log line 5",
},
{
CreatedAt: timestamppb.New(now.Add(5 * time.Hour)),
Level: -999, // defaults to INFO
Output: "log line 6",
},
},
}
// Craft expected DB request and response dynamically.
insertWorkspaceAgentLogsParams := database.InsertWorkspaceAgentLogsParams{
AgentID: agent.ID,
LogSourceID: logSource.ID,
CreatedAt: now,
Output: make([]string, len(req.Logs)),
Level: make([]database.LogLevel, len(req.Logs)),
OutputLength: 0,
}
insertWorkspaceAgentLogsReturn := make([]database.WorkspaceAgentLog, len(req.Logs))
for i, logEntry := range req.Logs {
insertWorkspaceAgentLogsParams.Output[i] = logEntry.Output
level := database.LogLevelInfo
if logEntry.Level >= 0 {
level = database.LogLevel(strings.ToLower(logEntry.Level.String()))
}
insertWorkspaceAgentLogsParams.Level[i] = level
insertWorkspaceAgentLogsParams.OutputLength += int32(len(logEntry.Output))
insertWorkspaceAgentLogsReturn[i] = database.WorkspaceAgentLog{
AgentID: agent.ID,
CreatedAt: logEntry.CreatedAt.AsTime(),
ID: int64(i),
Output: logEntry.Output,
Level: insertWorkspaceAgentLogsParams.Level[i],
LogSourceID: logSource.ID,
}
}
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), insertWorkspaceAgentLogsParams).Return(insertWorkspaceAgentLogsReturn, nil)
resp, err := api.BatchCreateLogs(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
require.True(t, publishWorkspaceUpdateCalled)
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
})
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
t.Parallel()
agentWithLogs := agent
agentWithLogs.LogsLength = 1
dbM := dbmock.NewMockStore(gomock.NewController(t))
publishWorkspaceUpdateCalled := false
publishWorkspaceAgentLogsUpdateCalled := false
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agentWithLogs, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishWorkspaceUpdateCalled = true
return nil
},
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
publishWorkspaceAgentLogsUpdateCalled = true
},
}
// Don't really care about the DB call.
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), gomock.Any()).Return([]database.WorkspaceAgentLog{
{
ID: 1,
},
}, nil)
resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{
LogSourceId: logSource.ID[:],
Logs: []*agentproto.Log{
{
CreatedAt: timestamppb.New(dbtime.Now()),
Level: agentproto.Log_INFO,
Output: "hello world",
},
},
})
require.NoError(t, err)
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
require.False(t, publishWorkspaceUpdateCalled)
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
})
t.Run("AlreadyOverflowed", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
overflowedAgent := agent
overflowedAgent.LogsOverflowed = true
publishWorkspaceUpdateCalled := false
publishWorkspaceAgentLogsUpdateCalled := false
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return overflowedAgent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishWorkspaceUpdateCalled = true
return nil
},
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
publishWorkspaceAgentLogsUpdateCalled = true
},
}
resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{
LogSourceId: logSource.ID[:],
Logs: []*agentproto.Log{},
})
require.Error(t, err)
require.ErrorContains(t, err, "workspace agent logs overflowed")
require.Nil(t, resp)
require.False(t, publishWorkspaceUpdateCalled)
require.False(t, publishWorkspaceAgentLogsUpdateCalled)
})
t.Run("InvalidLogSourceID", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
// Test that they are ignored when nil.
PublishWorkspaceUpdateFn: nil,
PublishWorkspaceAgentLogsUpdateFn: nil,
}
resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{
LogSourceId: []byte("invalid"),
Logs: []*agentproto.Log{
{}, // need at least 1 log
},
})
require.Error(t, err)
require.ErrorContains(t, err, "parse log source ID")
require.Nil(t, resp)
})
t.Run("UseExternalLogSourceID", func(t *testing.T) {
t.Parallel()
now := dbtime.Now()
req := &agentproto.BatchCreateLogsRequest{
LogSourceId: uuid.Nil[:], // defaults to "external"
Logs: []*agentproto.Log{
{
CreatedAt: timestamppb.New(now),
Level: agentproto.Log_INFO,
Output: "hello world",
},
},
}
dbInsertParams := database.InsertWorkspaceAgentLogsParams{
AgentID: agent.ID,
LogSourceID: agentsdk.ExternalLogSourceID,
CreatedAt: now,
Output: []string{"hello world"},
Level: []database.LogLevel{database.LogLevelInfo},
OutputLength: int32(len(req.Logs[0].Output)),
}
dbInsertRes := []database.WorkspaceAgentLog{
{
AgentID: agent.ID,
CreatedAt: now,
ID: 1,
Output: "hello world",
Level: database.LogLevelInfo,
LogSourceID: agentsdk.ExternalLogSourceID,
},
}
t.Run("Create", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
publishWorkspaceUpdateCalled := false
publishWorkspaceAgentLogsUpdateCalled := false
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishWorkspaceUpdateCalled = true
return nil
},
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
publishWorkspaceAgentLogsUpdateCalled = true
},
TimeNowFn: func() time.Time { return now },
}
dbM.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), database.InsertWorkspaceAgentLogSourcesParams{
WorkspaceAgentID: agent.ID,
CreatedAt: now,
ID: []uuid.UUID{agentsdk.ExternalLogSourceID},
DisplayName: []string{"External"},
Icon: []string{"/emojis/1f310.png"},
}).Return([]database.WorkspaceAgentLogSource{
{
// only the ID field is used
ID: agentsdk.ExternalLogSourceID,
},
}, nil)
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), dbInsertParams).Return(dbInsertRes, nil)
resp, err := api.BatchCreateLogs(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
require.True(t, publishWorkspaceUpdateCalled)
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
})
t.Run("Exists", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
publishWorkspaceUpdateCalled := false
publishWorkspaceAgentLogsUpdateCalled := false
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishWorkspaceUpdateCalled = true
return nil
},
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
publishWorkspaceAgentLogsUpdateCalled = true
},
TimeNowFn: func() time.Time { return now },
}
// Return a unique violation error to simulate the log source
// already existing. This should be handled gracefully.
logSourceInsertErr := &pq.Error{
Code: pq.ErrorCode("23505"), // unique_violation
Constraint: string(database.UniqueWorkspaceAgentLogSourcesPkey),
}
dbM.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), database.InsertWorkspaceAgentLogSourcesParams{
WorkspaceAgentID: agent.ID,
CreatedAt: now,
ID: []uuid.UUID{agentsdk.ExternalLogSourceID},
DisplayName: []string{"External"},
Icon: []string{"/emojis/1f310.png"},
}).Return([]database.WorkspaceAgentLogSource{}, logSourceInsertErr)
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), dbInsertParams).Return(dbInsertRes, nil)
resp, err := api.BatchCreateLogs(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
require.True(t, publishWorkspaceUpdateCalled)
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
})
})
t.Run("Overflow", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
publishWorkspaceUpdateCalled := false
publishWorkspaceAgentLogsUpdateCalled := false
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
publishWorkspaceUpdateCalled = true
return nil
},
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
publishWorkspaceAgentLogsUpdateCalled = true
},
}
// Don't really care about the DB call params, just want to return an
// error.
dbErr := &pq.Error{
Constraint: "max_logs_length",
Table: "workspace_agents",
}
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), gomock.Any()).Return(nil, dbErr)
// Should also update the workspace agent.
dbM.EXPECT().UpdateWorkspaceAgentLogOverflowByID(gomock.Any(), database.UpdateWorkspaceAgentLogOverflowByIDParams{
ID: agent.ID,
LogsOverflowed: true,
}).Return(nil)
resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{
LogSourceId: logSource.ID[:],
Logs: []*agentproto.Log{
{
CreatedAt: timestamppb.New(dbtime.Now()),
Level: agentproto.Log_INFO,
Output: "hello world",
},
},
})
require.Error(t, err)
require.Nil(t, resp)
require.True(t, publishWorkspaceUpdateCalled)
require.False(t, publishWorkspaceAgentLogsUpdateCalled)
})
}

View File

@ -5,7 +5,6 @@ import (
"database/sql"
"net/url"
"strings"
"sync/atomic"
"time"
"github.com/google/uuid"
@ -25,18 +24,16 @@ import (
)
type ManifestAPI struct {
AccessURL *url.URL
AppHostname string
AgentInactiveDisconnectTimeout time.Duration
AgentFallbackTroubleshootingURL string
ExternalAuthConfigs []*externalauth.Config
DisableDirectConnections bool
DerpForceWebSockets bool
AccessURL *url.URL
AppHostname string
ExternalAuthConfigs []*externalauth.Config
DisableDirectConnections bool
DerpForceWebSockets bool
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store
DerpMapFn func() *tailcfg.DERPMap
TailnetCoordinator *atomic.Pointer[tailnet.Coordinator]
AgentFn func(context.Context) (database.WorkspaceAgent, error)
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
Database database.Store
DerpMapFn func() *tailcfg.DERPMap
}
func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
@ -44,21 +41,15 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
if err != nil {
return nil, err
}
apiAgent, err := db2sdk.WorkspaceAgent(
a.DerpMapFn(), *a.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, a.AgentInactiveDisconnectTimeout,
a.AgentFallbackTroubleshootingURL,
)
workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("converting workspace agent: %w", err)
return nil, err
}
var (
dbApps []database.WorkspaceApp
scripts []database.WorkspaceAgentScript
metadata []database.WorkspaceAgentMetadatum
resource database.WorkspaceResource
build database.WorkspaceBuild
workspace database.Workspace
owner database.User
)
@ -79,20 +70,12 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
eg.Go(func() (err error) {
metadata, err = a.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
WorkspaceAgentID: workspaceAgent.ID,
Keys: nil,
Keys: nil, // all
})
return err
})
eg.Go(func() (err error) {
resource, err = a.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
if err != nil {
return xerrors.Errorf("getting resource by id: %w", err)
}
build, err = a.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
if err != nil {
return xerrors.Errorf("getting workspace build by job id: %w", err)
}
workspace, err = a.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
workspace, err = a.Database.GetWorkspaceByID(ctx, workspaceID)
if err != nil {
return xerrors.Errorf("getting workspace by id: %w", err)
}
@ -116,6 +99,11 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
vscodeProxyURI := vscodeProxyURI(appSlug, a.AccessURL, a.AppHostname)
envs, err := db2sdk.WorkspaceAgentEnvironment(workspaceAgent)
if err != nil {
return nil, err
}
var gitAuthConfigs uint32
for _, cfg := range a.ExternalAuthConfigs {
if codersdk.EnhancedExternalAuthProvider(cfg.Type).Git() {
@ -135,8 +123,8 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
WorkspaceId: workspace.ID[:],
WorkspaceName: workspace.Name,
GitAuthConfigs: gitAuthConfigs,
EnvironmentVariables: apiAgent.EnvironmentVariables,
Directory: apiAgent.Directory,
EnvironmentVariables: envs,
Directory: workspaceAgent.Directory,
VsCodePortProxyUri: vscodeProxyURI,
MotdPath: workspaceAgent.MOTDFile,
DisableDirectConnections: a.DisableDirectConnections,

View File

@ -0,0 +1,396 @@
package agentapi_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/url"
"testing"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/types/known/durationpb"
"tailscale.com/tailcfg"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
)
func TestGetManifest(t *testing.T) {
t.Parallel()
someTime, err := time.Parse(time.RFC3339, "2023-01-01T00:00:00Z")
require.NoError(t, err)
someTime = dbtime.Time(someTime)
expectedEnvVars := map[string]string{
"FOO": "bar",
"COOL_ENV": "dean was here",
}
expectedEnvVarsJSON, err := json.Marshal(expectedEnvVars)
require.NoError(t, err)
var (
owner = database.User{
ID: uuid.New(),
Username: "cool-user",
}
workspace = database.Workspace{
ID: uuid.New(),
OwnerID: owner.ID,
Name: "cool-workspace",
}
agent = database.WorkspaceAgent{
ID: uuid.New(),
Name: "cool-agent",
EnvironmentVariables: pqtype.NullRawMessage{
RawMessage: expectedEnvVarsJSON,
Valid: true,
},
Directory: "/cool/dir",
MOTDFile: "/cool/motd",
}
apps = []database.WorkspaceApp{
{
ID: uuid.New(),
Url: sql.NullString{String: "http://localhost:1234", Valid: true},
External: false,
Slug: "cool-app-1",
DisplayName: "app 1",
Command: sql.NullString{String: "cool command", Valid: true},
Icon: "/icon.png",
Subdomain: true,
SharingLevel: database.AppSharingLevelAuthenticated,
Health: database.WorkspaceAppHealthHealthy,
HealthcheckUrl: "http://localhost:1234/health",
HealthcheckInterval: 10,
HealthcheckThreshold: 3,
},
{
ID: uuid.New(),
Url: sql.NullString{String: "http://google.com", Valid: true},
External: true,
Slug: "google",
DisplayName: "Literally Google",
Command: sql.NullString{Valid: false},
Icon: "/google.png",
Subdomain: false,
SharingLevel: database.AppSharingLevelPublic,
Health: database.WorkspaceAppHealthDisabled,
},
{
ID: uuid.New(),
Url: sql.NullString{String: "http://localhost:4321", Valid: true},
External: true,
Slug: "cool-app-2",
DisplayName: "another COOL app",
Command: sql.NullString{Valid: false},
Icon: "",
Subdomain: false,
SharingLevel: database.AppSharingLevelOwner,
Health: database.WorkspaceAppHealthUnhealthy,
HealthcheckUrl: "http://localhost:4321/health",
HealthcheckInterval: 20,
HealthcheckThreshold: 5,
},
}
scripts = []database.WorkspaceAgentScript{
{
WorkspaceAgentID: agent.ID,
LogSourceID: uuid.New(),
LogPath: "/cool/log/path/1",
Script: "cool script 1",
Cron: "30 2 * * *",
StartBlocksLogin: true,
RunOnStart: true,
RunOnStop: false,
TimeoutSeconds: 60,
},
{
WorkspaceAgentID: agent.ID,
LogSourceID: uuid.New(),
LogPath: "/cool/log/path/2",
Script: "cool script 2",
Cron: "",
StartBlocksLogin: false,
RunOnStart: false,
RunOnStop: true,
TimeoutSeconds: 30,
},
}
metadata = []database.WorkspaceAgentMetadatum{
{
WorkspaceAgentID: agent.ID,
DisplayName: "cool metadata 1",
Key: "cool-key-1",
Script: "cool script 1",
Value: "cool value 1",
Error: "",
Timeout: int64(time.Minute),
Interval: int64(time.Minute),
CollectedAt: someTime,
},
{
WorkspaceAgentID: agent.ID,
DisplayName: "cool metadata 2",
Key: "cool-key-2",
Script: "cool script 2",
Value: "cool value 2",
Error: "some uncool error",
Timeout: int64(5 * time.Second),
Interval: int64(20 * time.Minute),
CollectedAt: someTime.Add(time.Hour),
},
}
derpMapFn = func() *tailcfg.DERPMap {
return &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {RegionName: "cool region"},
},
}
}
)
// These are done manually to ensure the conversion logic matches what a
// human expects.
var (
protoApps = []*agentproto.WorkspaceApp{
{
Id: apps[0].ID[:],
Url: apps[0].Url.String,
External: apps[0].External,
Slug: apps[0].Slug,
DisplayName: apps[0].DisplayName,
Command: apps[0].Command.String,
Icon: apps[0].Icon,
Subdomain: apps[0].Subdomain,
SubdomainName: fmt.Sprintf("%s--%s--%s--%s", apps[0].Slug, agent.Name, workspace.Name, owner.Username),
SharingLevel: agentproto.WorkspaceApp_AUTHENTICATED,
Healthcheck: &agentproto.WorkspaceApp_Healthcheck{
Url: apps[0].HealthcheckUrl,
Interval: durationpb.New(time.Duration(apps[0].HealthcheckInterval) * time.Second),
Threshold: apps[0].HealthcheckThreshold,
},
Health: agentproto.WorkspaceApp_HEALTHY,
},
{
Id: apps[1].ID[:],
Url: apps[1].Url.String,
External: apps[1].External,
Slug: apps[1].Slug,
DisplayName: apps[1].DisplayName,
Command: apps[1].Command.String,
Icon: apps[1].Icon,
Subdomain: false,
SubdomainName: "",
SharingLevel: agentproto.WorkspaceApp_PUBLIC,
Healthcheck: &agentproto.WorkspaceApp_Healthcheck{
Url: "",
Interval: durationpb.New(0),
Threshold: 0,
},
Health: agentproto.WorkspaceApp_DISABLED,
},
{
Id: apps[2].ID[:],
Url: apps[2].Url.String,
External: apps[2].External,
Slug: apps[2].Slug,
DisplayName: apps[2].DisplayName,
Command: apps[2].Command.String,
Icon: apps[2].Icon,
Subdomain: false,
SubdomainName: "",
SharingLevel: agentproto.WorkspaceApp_OWNER,
Healthcheck: &agentproto.WorkspaceApp_Healthcheck{
Url: apps[2].HealthcheckUrl,
Interval: durationpb.New(time.Duration(apps[2].HealthcheckInterval) * time.Second),
Threshold: apps[2].HealthcheckThreshold,
},
Health: agentproto.WorkspaceApp_UNHEALTHY,
},
}
protoScripts = []*agentproto.WorkspaceAgentScript{
{
LogSourceId: scripts[0].LogSourceID[:],
LogPath: scripts[0].LogPath,
Script: scripts[0].Script,
Cron: scripts[0].Cron,
RunOnStart: scripts[0].RunOnStart,
RunOnStop: scripts[0].RunOnStop,
StartBlocksLogin: scripts[0].StartBlocksLogin,
Timeout: durationpb.New(time.Duration(scripts[0].TimeoutSeconds) * time.Second),
},
{
LogSourceId: scripts[1].LogSourceID[:],
LogPath: scripts[1].LogPath,
Script: scripts[1].Script,
Cron: scripts[1].Cron,
RunOnStart: scripts[1].RunOnStart,
RunOnStop: scripts[1].RunOnStop,
StartBlocksLogin: scripts[1].StartBlocksLogin,
Timeout: durationpb.New(time.Duration(scripts[1].TimeoutSeconds) * time.Second),
},
}
protoMetadata = []*agentproto.WorkspaceAgentMetadata_Description{
{
DisplayName: metadata[0].DisplayName,
Key: metadata[0].Key,
Script: metadata[0].Script,
Interval: durationpb.New(time.Duration(metadata[0].Interval)),
Timeout: durationpb.New(time.Duration(metadata[0].Timeout)),
},
{
DisplayName: metadata[1].DisplayName,
Key: metadata[1].Key,
Script: metadata[1].Script,
Interval: durationpb.New(time.Duration(metadata[1].Interval)),
Timeout: durationpb.New(time.Duration(metadata[1].Timeout)),
},
}
)
t.Run("OK", func(t *testing.T) {
t.Parallel()
mDB := dbmock.NewMockStore(gomock.NewController(t))
api := &agentapi.ManifestAPI{
AccessURL: &url.URL{Scheme: "https", Host: "example.com"},
AppHostname: "*--apps.example.com",
ExternalAuthConfigs: []*externalauth.Config{
{Type: string(codersdk.EnhancedExternalAuthProviderGitHub)},
{Type: "some-provider"},
{Type: string(codersdk.EnhancedExternalAuthProviderGitLab)},
},
DisableDirectConnections: true,
DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) {
return workspace.ID, nil
},
Database: mDB,
DerpMapFn: derpMapFn,
}
mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil)
mDB.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), []uuid.UUID{agent.ID}).Return(scripts, nil)
mDB.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), database.GetWorkspaceAgentMetadataParams{
WorkspaceAgentID: agent.ID,
Keys: nil, // all
}).Return(metadata, nil)
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
require.NoError(t, err)
expected := &agentproto.Manifest{
AgentId: agent.ID[:],
AgentName: agent.Name,
OwnerUsername: owner.Username,
WorkspaceId: workspace.ID[:],
WorkspaceName: workspace.Name,
GitAuthConfigs: 2, // two "enhanced" external auth configs
EnvironmentVariables: expectedEnvVars,
Directory: agent.Directory,
VsCodePortProxyUri: fmt.Sprintf("https://{{port}}--%s--%s--%s--apps.example.com", agent.Name, workspace.Name, owner.Username),
MotdPath: agent.MOTDFile,
DisableDirectConnections: true,
DerpForceWebsockets: true,
// tailnet.DERPMapToProto() is extensively tested elsewhere, so it's
// not necessary to manually recreate a big DERP map here like we
// did for apps and metadata.
DerpMap: tailnet.DERPMapToProto(derpMapFn()),
Scripts: protoScripts,
Apps: protoApps,
Metadata: protoMetadata,
}
// Log got and expected with spew.
// t.Log("got:\n" + spew.Sdump(got))
// t.Log("expected:\n" + spew.Sdump(expected))
require.Equal(t, expected, got)
})
t.Run("NoAppHostname", func(t *testing.T) {
t.Parallel()
mDB := dbmock.NewMockStore(gomock.NewController(t))
api := &agentapi.ManifestAPI{
AccessURL: &url.URL{Scheme: "https", Host: "example.com"},
AppHostname: "",
ExternalAuthConfigs: []*externalauth.Config{
{Type: string(codersdk.EnhancedExternalAuthProviderGitHub)},
{Type: "some-provider"},
{Type: string(codersdk.EnhancedExternalAuthProviderGitLab)},
},
DisableDirectConnections: true,
DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) {
return workspace.ID, nil
},
Database: mDB,
DerpMapFn: derpMapFn,
}
mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil)
mDB.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), []uuid.UUID{agent.ID}).Return(scripts, nil)
mDB.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), database.GetWorkspaceAgentMetadataParams{
WorkspaceAgentID: agent.ID,
Keys: nil, // all
}).Return(metadata, nil)
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
require.NoError(t, err)
expected := &agentproto.Manifest{
AgentId: agent.ID[:],
AgentName: agent.Name,
OwnerUsername: owner.Username,
WorkspaceId: workspace.ID[:],
WorkspaceName: workspace.Name,
GitAuthConfigs: 2, // two "enhanced" external auth configs
EnvironmentVariables: expectedEnvVars,
Directory: agent.Directory,
VsCodePortProxyUri: "", // empty with no AppHost
MotdPath: agent.MOTDFile,
DisableDirectConnections: true,
DerpForceWebsockets: true,
// tailnet.DERPMapToProto() is extensively tested elsewhere, so it's
// not necessary to manually recreate a big DERP map here like we
// did for apps and metadata.
DerpMap: tailnet.DERPMapToProto(derpMapFn()),
Scripts: protoScripts,
Apps: protoApps,
Metadata: protoMetadata,
}
// Log got and expected with spew.
// t.Log("got:\n" + spew.Sdump(got))
// t.Log("expected:\n" + spew.Sdump(expected))
require.Equal(t, expected, got)
})
}

View File

@ -12,6 +12,7 @@ import (
"cdr.dev/slog"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
)
@ -20,14 +21,26 @@ type MetadataAPI struct {
Database database.Store
Pubsub pubsub.Pubsub
Log slog.Logger
TimeNowFn func() time.Time // defaults to dbtime.Now()
}
func (a *MetadataAPI) now() time.Time {
if a.TimeNowFn != nil {
return a.TimeNowFn()
}
return dbtime.Now()
}
func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
const (
// maxValueLen is set to 2048 to stay under the 8000 byte Postgres
// NOTIFY limit. Since both value and error can be set, the real payload
// limit is 2 * 2048 * 4/3 <base64 expansion> = 5461 bytes + a few
// hundred bytes for JSON syntax, key names, and metadata.
// maxAllKeysLen is the maximum length of all metadata keys. This is
// 6144 to stay below the Postgres NOTIFY limit of 8000 bytes, with some
// headway for the timestamp and JSON encoding. Any values that would
// exceed this limit are discarded (the rest are still inserted) and an
// error is returned.
maxAllKeysLen = 6144 // 1024 * 6
maxValueLen = 2048
maxErrorLen = maxValueLen
)
@ -37,18 +50,36 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
return nil, err
}
collectedAt := time.Now()
dbUpdate := database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: workspaceAgent.ID,
Key: make([]string, 0, len(req.Metadata)),
Value: make([]string, 0, len(req.Metadata)),
Error: make([]string, 0, len(req.Metadata)),
CollectedAt: make([]time.Time, 0, len(req.Metadata)),
}
var (
collectedAt = a.now()
allKeysLen = 0
dbUpdate = database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: workspaceAgent.ID,
// These need to be `make(x, 0, len(req.Metadata))` instead of
// `make(x, len(req.Metadata))` because we may not insert all
// metadata if the keys are large.
Key: make([]string, 0, len(req.Metadata)),
Value: make([]string, 0, len(req.Metadata)),
Error: make([]string, 0, len(req.Metadata)),
CollectedAt: make([]time.Time, 0, len(req.Metadata)),
}
)
for _, md := range req.Metadata {
metadataError := md.Result.Error
allKeysLen += len(md.Key)
if allKeysLen > maxAllKeysLen {
// We still insert the rest of the metadata, and we return an error
// after the insert.
a.Log.Warn(
ctx, "discarded extra agent metadata due to excessive key length",
slog.F("collected_at", collectedAt),
slog.F("all_keys_len", allKeysLen),
slog.F("max_all_keys_len", maxAllKeysLen),
)
break
}
// We overwrite the error if the provided payload is too long.
if len(md.Result.Value) > maxValueLen {
metadataError = fmt.Sprintf("value of %d bytes exceeded %d bytes", len(md.Result.Value), maxValueLen)
@ -71,12 +102,16 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
a.Log.Debug(
ctx, "accepted metadata report",
slog.F("collected_at", collectedAt),
slog.F("original_collected_at", collectedAt),
slog.F("key", md.Key),
slog.F("value", ellipse(md.Result.Value, 16)),
)
}
err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate)
if err != nil {
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
}
payload, err := json.Marshal(WorkspaceAgentMetadataChannelPayload{
CollectedAt: collectedAt,
Keys: dbUpdate.Key,
@ -84,17 +119,17 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
if err != nil {
return nil, xerrors.Errorf("marshal workspace agent metadata channel payload: %w", err)
}
err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate)
if err != nil {
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
}
err = a.Pubsub.Publish(WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), payload)
if err != nil {
return nil, xerrors.Errorf("publish workspace agent metadata: %w", err)
}
// If the metadata keys were too large, we return an error so the agent can
// log it.
if allKeysLen > maxAllKeysLen {
return nil, xerrors.Errorf("metadata keys of %d bytes exceeded %d bytes", allKeysLen, maxAllKeysLen)
}
return &agentproto.BatchUpdateMetadataResponse{}, nil
}

View File

@ -0,0 +1,276 @@
package agentapi_test
import (
"context"
"encoding/json"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/types/known/timestamppb"
"cdr.dev/slog/sloggers/slogtest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
)
type fakePublisher struct {
// Nil pointer to pass interface check.
pubsub.Pubsub
publishes [][]byte
}
var _ pubsub.Pubsub = &fakePublisher{}
func (f *fakePublisher) Publish(_ string, message []byte) error {
f.publishes = append(f.publishes, message)
return nil
}
func TestBatchUpdateMetadata(t *testing.T) {
t.Parallel()
agent := database.WorkspaceAgent{
ID: uuid.New(),
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
pub := &fakePublisher{}
now := dbtime.Now()
req := &agentproto.BatchUpdateMetadataRequest{
Metadata: []*agentproto.Metadata{
{
Key: "awesome key",
Result: &agentproto.WorkspaceAgentMetadata_Result{
CollectedAt: timestamppb.New(now.Add(-10 * time.Second)),
Age: 10,
Value: "awesome value",
Error: "",
},
},
{
Key: "uncool key",
Result: &agentproto.WorkspaceAgentMetadata_Result{
CollectedAt: timestamppb.New(now.Add(-3 * time.Second)),
Age: 3,
Value: "",
Error: "uncool value",
},
},
},
}
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: agent.ID,
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key},
Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value},
Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error},
// The value from the agent is ignored.
CollectedAt: []time.Time{now, now},
}).Return(nil)
api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Pubsub: pub,
Log: slogtest.Make(t, nil),
TimeNowFn: func() time.Time {
return now
},
}
resp, err := api.BatchUpdateMetadata(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp)
require.Equal(t, 1, len(pub.publishes))
var gotEvent agentapi.WorkspaceAgentMetadataChannelPayload
require.NoError(t, json.Unmarshal(pub.publishes[0], &gotEvent))
require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{
CollectedAt: now,
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key},
}, gotEvent)
})
t.Run("ExceededLength", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
pub := pubsub.NewInMemory()
almostLongValue := ""
for i := 0; i < 2048; i++ {
almostLongValue += "a"
}
now := dbtime.Now()
req := &agentproto.BatchUpdateMetadataRequest{
Metadata: []*agentproto.Metadata{
{
Key: "almost long value",
Result: &agentproto.WorkspaceAgentMetadata_Result{
Value: almostLongValue,
},
},
{
Key: "too long value",
Result: &agentproto.WorkspaceAgentMetadata_Result{
Value: almostLongValue + "a",
},
},
{
Key: "almost long error",
Result: &agentproto.WorkspaceAgentMetadata_Result{
Error: almostLongValue,
},
},
{
Key: "too long error",
Result: &agentproto.WorkspaceAgentMetadata_Result{
Error: almostLongValue + "a",
},
},
},
}
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: agent.ID,
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key, req.Metadata[3].Key},
Value: []string{
almostLongValue,
almostLongValue, // truncated
"",
"",
},
Error: []string{
"",
"value of 2049 bytes exceeded 2048 bytes",
almostLongValue,
"error of 2049 bytes exceeded 2048 bytes", // replaced
},
// The value from the agent is ignored.
CollectedAt: []time.Time{now, now, now, now},
}).Return(nil)
api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Pubsub: pub,
Log: slogtest.Make(t, nil),
TimeNowFn: func() time.Time {
return now
},
}
resp, err := api.BatchUpdateMetadata(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp)
})
t.Run("KeysTooLong", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
pub := pubsub.NewInMemory()
now := dbtime.Now()
req := &agentproto.BatchUpdateMetadataRequest{
Metadata: []*agentproto.Metadata{
{
Key: "key 1",
Result: &agentproto.WorkspaceAgentMetadata_Result{
Value: "value 1",
},
},
{
Key: "key 2",
Result: &agentproto.WorkspaceAgentMetadata_Result{
Value: "value 2",
},
},
{
Key: func() string {
key := "key 3 "
for i := 0; i < (6144 - len("key 1") - len("key 2") - len("key 3") - 1); i++ {
key += "a"
}
return key
}(),
Result: &agentproto.WorkspaceAgentMetadata_Result{
Value: "value 3",
},
},
{
Key: "a", // should be ignored
Result: &agentproto.WorkspaceAgentMetadata_Result{
Value: "value 4",
},
},
},
}
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: agent.ID,
// No key 4.
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value, req.Metadata[2].Result.Value},
Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error, req.Metadata[2].Result.Error},
// The value from the agent is ignored.
CollectedAt: []time.Time{now, now, now},
}).Return(nil)
api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Pubsub: pub,
Log: slogtest.Make(t, nil),
TimeNowFn: func() time.Time {
return now
},
}
// Watch the pubsub for events.
var (
eventCount int64
gotEvent agentapi.WorkspaceAgentMetadataChannelPayload
)
cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) {
if atomic.AddInt64(&eventCount, 1) > 1 {
return
}
require.NoError(t, json.Unmarshal(message, &gotEvent))
})
require.NoError(t, err)
defer cancel()
resp, err := api.BatchUpdateMetadata(context.Background(), req)
require.Error(t, err)
require.Equal(t, "metadata keys of 6145 bytes exceeded 6144 bytes", err.Error())
require.Nil(t, resp)
require.Equal(t, int64(1), atomic.LoadInt64(&eventCount))
require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{
CollectedAt: now,
// No key 4.
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
}, gotEvent)
})
}

View File

@ -0,0 +1,84 @@
package agentapi_test
import (
"context"
"database/sql"
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/codersdk"
)
func TestGetServiceBanner(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
cfg := codersdk.ServiceBannerConfig{
Enabled: true,
Message: "hello world",
BackgroundColor: "#000000",
}
cfgJSON, err := json.Marshal(cfg)
require.NoError(t, err)
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetServiceBanner(gomock.Any()).Return(string(cfgJSON), nil)
api := &agentapi.ServiceBannerAPI{
Database: dbM,
}
resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{})
require.NoError(t, err)
require.Equal(t, &agentproto.ServiceBanner{
Enabled: cfg.Enabled,
Message: cfg.Message,
BackgroundColor: cfg.BackgroundColor,
}, resp)
})
t.Run("None", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetServiceBanner(gomock.Any()).Return("", sql.ErrNoRows)
api := &agentapi.ServiceBannerAPI{
Database: dbM,
}
resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{})
require.NoError(t, err)
require.Equal(t, &agentproto.ServiceBanner{
Enabled: false,
Message: "",
BackgroundColor: "",
}, resp)
})
t.Run("BadJSON", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetServiceBanner(gomock.Any()).Return("hi", nil)
api := &agentapi.ServiceBannerAPI{
Database: dbM,
}
resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{})
require.Error(t, err)
require.ErrorContains(t, err, "unmarshal json")
require.Nil(t, resp)
})
}

View File

@ -9,57 +9,71 @@ import (
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/google/uuid"
"cdr.dev/slog"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/autobuild"
"github.com/coder/coder/v2/coderd/batchstats"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/prometheusmetrics"
"github.com/coder/coder/v2/coderd/schedule"
)
type StatsBatcher interface {
Add(now time.Time, agentID uuid.UUID, templateID uuid.UUID, userID uuid.UUID, workspaceID uuid.UUID, st *agentproto.Stats) error
}
type StatsAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store
Log slog.Logger
StatsBatcher *batchstats.Batcher
StatsBatcher StatsBatcher
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
AgentStatsRefreshInterval time.Duration
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
TimeNowFn func() time.Time // defaults to dbtime.Now()
}
func (a *StatsAPI) now() time.Time {
if a.TimeNowFn != nil {
return a.TimeNowFn()
}
return dbtime.Now()
}
func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
// An empty stat means it's just looking for the report interval.
res := &agentproto.UpdateStatsResponse{
ReportInterval: durationpb.New(a.AgentStatsRefreshInterval),
}
if req.Stats == nil || len(req.Stats.ConnectionsByProto) == 0 {
return res, nil
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
row, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
getWorkspaceAgentByIDRow, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
}
workspace := row.Workspace
res := &agentproto.UpdateStatsResponse{
ReportInterval: durationpb.New(a.AgentStatsRefreshInterval),
}
// An empty stat means it's just looking for the report interval.
if len(req.Stats.ConnectionsByProto) == 0 {
return res, nil
}
workspace := getWorkspaceAgentByIDRow.Workspace
a.Log.Debug(ctx, "read stats report",
slog.F("interval", a.AgentStatsRefreshInterval),
slog.F("workspace_id", workspace.ID),
slog.F("payload", req),
)
now := a.now()
if req.Stats.ConnectionCount > 0 {
var nextAutostart time.Time
if workspace.AutostartSchedule.String != "" {
templateSchedule, err := (*(a.TemplateScheduleStore.Load())).Get(ctx, a.Database, workspace.TemplateID)
// If the template schedule fails to load, just default to bumping without the next trasition and log it.
// If the template schedule fails to load, just default to bumping
// without the next transition and log it.
if err != nil {
a.Log.Error(ctx, "failed to load template schedule bumping activity, defaulting to bumping by 60min",
slog.F("workspace_id", workspace.ID),
@ -67,7 +81,7 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
slog.Error(err),
)
} else {
next, allowed := autobuild.NextAutostartSchedule(time.Now(), workspace.AutostartSchedule.String, templateSchedule)
next, allowed := autobuild.NextAutostartSchedule(now, workspace.AutostartSchedule.String, templateSchedule)
if allowed {
nextAutostart = next
}
@ -76,13 +90,12 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
ActivityBumpWorkspace(ctx, a.Log.Named("activity_bump"), a.Database, workspace.ID, nextAutostart)
}
now := dbtime.Now()
var errGroup errgroup.Group
errGroup.Go(func() error {
if err := a.StatsBatcher.Add(time.Now(), workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, req.Stats); err != nil {
a.Log.Error(ctx, "failed to add stats to batcher", slog.Error(err))
return xerrors.Errorf("can't insert workspace agent stat: %w", err)
err := a.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, req.Stats)
if err != nil {
a.Log.Error(ctx, "add agent stats to batcher", slog.Error(err))
return xerrors.Errorf("insert workspace agent stats batch: %w", err)
}
return nil
})
@ -92,7 +105,7 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
LastUsedAt: now,
})
if err != nil {
return xerrors.Errorf("can't update workspace LastUsedAt: %w", err)
return xerrors.Errorf("update workspace LastUsedAt: %w", err)
}
return nil
})
@ -100,14 +113,14 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
errGroup.Go(func() error {
user, err := a.Database.GetUserByID(ctx, workspace.OwnerID)
if err != nil {
return xerrors.Errorf("can't get user: %w", err)
return xerrors.Errorf("get user: %w", err)
}
a.UpdateAgentMetricsFn(ctx, prometheusmetrics.AgentMetricLabels{
Username: user.Username,
WorkspaceName: workspace.Name,
AgentName: workspaceAgent.Name,
TemplateName: row.TemplateName,
TemplateName: getWorkspaceAgentByIDRow.TemplateName,
}, req.Stats.Metrics)
return nil
})

View File

@ -0,0 +1,379 @@
package agentapi_test
import (
"context"
"database/sql"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/types/known/durationpb"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/prometheusmetrics"
"github.com/coder/coder/v2/coderd/schedule"
)
type statsBatcher struct {
mu sync.Mutex
called int64
lastTime time.Time
lastAgentID uuid.UUID
lastTemplateID uuid.UUID
lastUserID uuid.UUID
lastWorkspaceID uuid.UUID
lastStats *agentproto.Stats
}
var _ agentapi.StatsBatcher = &statsBatcher{}
func (b *statsBatcher) Add(now time.Time, agentID uuid.UUID, templateID uuid.UUID, userID uuid.UUID, workspaceID uuid.UUID, st *agentproto.Stats) error {
b.mu.Lock()
defer b.mu.Unlock()
b.called++
b.lastTime = now
b.lastAgentID = agentID
b.lastTemplateID = templateID
b.lastUserID = userID
b.lastWorkspaceID = workspaceID
b.lastStats = st
return nil
}
func TestUpdateStates(t *testing.T) {
t.Parallel()
var (
user = database.User{
ID: uuid.New(),
Username: "bill",
}
template = database.Template{
ID: uuid.New(),
Name: "tpl",
}
workspace = database.Workspace{
ID: uuid.New(),
OwnerID: user.ID,
TemplateID: template.ID,
Name: "xyz",
}
agent = database.WorkspaceAgent{
ID: uuid.New(),
Name: "abc",
}
)
t.Run("OK", func(t *testing.T) {
t.Parallel()
var (
now = dbtime.Now()
dbM = dbmock.NewMockStore(gomock.NewController(t))
templateScheduleStore = schedule.MockTemplateScheduleStore{
GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) {
panic("should not be called")
},
SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) {
panic("not implemented")
},
}
batcher = &statsBatcher{}
updateAgentMetricsFnCalled = false
req = &agentproto.UpdateStatsRequest{
Stats: &agentproto.Stats{
ConnectionsByProto: map[string]int64{
"tcp": 1,
"dean": 2,
},
ConnectionCount: 3,
ConnectionMedianLatencyMs: 23,
RxPackets: 120,
RxBytes: 1000,
TxPackets: 130,
TxBytes: 2000,
SessionCountVscode: 1,
SessionCountJetbrains: 2,
SessionCountReconnectingPty: 3,
SessionCountSsh: 4,
Metrics: []*agentproto.Stats_Metric{
{
Name: "awesome metric",
Value: 42,
},
{
Name: "uncool metric",
Value: 0,
},
},
},
}
)
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
StatsBatcher: batcher,
TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore),
AgentStatsRefreshInterval: 10 * time.Second,
UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) {
updateAgentMetricsFnCalled = true
assert.Equal(t, prometheusmetrics.AgentMetricLabels{
Username: user.Username,
WorkspaceName: workspace.Name,
AgentName: agent.Name,
TemplateName: template.Name,
}, labels)
assert.Equal(t, req.Stats.Metrics, metrics)
},
TimeNowFn: func() time.Time {
return now
},
}
// Workspace gets fetched.
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{
Workspace: workspace,
TemplateName: template.Name,
}, nil)
// We expect an activity bump because ConnectionCount > 0.
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
WorkspaceID: workspace.ID,
NextAutostart: time.Time{}.UTC(),
}).Return(nil)
// Workspace last used at gets bumped.
dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{
ID: workspace.ID,
LastUsedAt: now,
}).Return(nil)
// User gets fetched to hit the UpdateAgentMetricsFn.
dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil)
resp, err := api.UpdateStats(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.UpdateStatsResponse{
ReportInterval: durationpb.New(10 * time.Second),
}, resp)
batcher.mu.Lock()
defer batcher.mu.Unlock()
require.Equal(t, int64(1), batcher.called)
require.Equal(t, now, batcher.lastTime)
require.Equal(t, agent.ID, batcher.lastAgentID)
require.Equal(t, template.ID, batcher.lastTemplateID)
require.Equal(t, user.ID, batcher.lastUserID)
require.Equal(t, workspace.ID, batcher.lastWorkspaceID)
require.Equal(t, req.Stats, batcher.lastStats)
require.True(t, updateAgentMetricsFnCalled)
})
t.Run("ConnectionCountZero", func(t *testing.T) {
t.Parallel()
var (
now = dbtime.Now()
dbM = dbmock.NewMockStore(gomock.NewController(t))
templateScheduleStore = schedule.MockTemplateScheduleStore{
GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) {
panic("should not be called")
},
SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) {
panic("not implemented")
},
}
batcher = &statsBatcher{}
req = &agentproto.UpdateStatsRequest{
Stats: &agentproto.Stats{
ConnectionsByProto: map[string]int64{
"tcp": 1,
},
ConnectionCount: 0,
ConnectionMedianLatencyMs: 23,
},
}
)
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
StatsBatcher: batcher,
TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore),
AgentStatsRefreshInterval: 10 * time.Second,
// Ignored when nil.
UpdateAgentMetricsFn: nil,
TimeNowFn: func() time.Time {
return now
},
}
// Workspace gets fetched.
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{
Workspace: workspace,
TemplateName: template.Name,
}, nil)
// Workspace last used at gets bumped.
dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{
ID: workspace.ID,
LastUsedAt: now,
}).Return(nil)
_, err := api.UpdateStats(context.Background(), req)
require.NoError(t, err)
})
t.Run("NoConnectionsByProto", func(t *testing.T) {
t.Parallel()
var (
dbM = dbmock.NewMockStore(gomock.NewController(t))
req = &agentproto.UpdateStatsRequest{
Stats: &agentproto.Stats{
ConnectionsByProto: map[string]int64{}, // len() == 0
},
}
)
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
StatsBatcher: nil, // should not be called
TemplateScheduleStore: nil, // should not be called
AgentStatsRefreshInterval: 10 * time.Second,
UpdateAgentMetricsFn: nil, // should not be called
TimeNowFn: func() time.Time {
panic("should not be called")
},
}
resp, err := api.UpdateStats(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.UpdateStatsResponse{
ReportInterval: durationpb.New(10 * time.Second),
}, resp)
})
t.Run("AutostartAwareBump", func(t *testing.T) {
t.Parallel()
// Use a workspace with an autostart schedule.
workspace := workspace
workspace.AutostartSchedule = sql.NullString{
String: "CRON_TZ=Australia/Sydney 0 8 * * *",
Valid: true,
}
// Use a custom time for now which would trigger the autostart aware
// bump.
now, err := time.Parse("2006-01-02 15:04:05 -0700 MST", "2023-12-19 07:30:00 +1100 AEDT")
require.NoError(t, err)
now = dbtime.Time(now)
nextAutostart := now.Add(30 * time.Minute).UTC() // always sent to DB as UTC
var (
dbM = dbmock.NewMockStore(gomock.NewController(t))
templateScheduleStore = schedule.MockTemplateScheduleStore{
GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) {
return schedule.TemplateScheduleOptions{
UserAutostartEnabled: true,
AutostartRequirement: schedule.TemplateAutostartRequirement{
DaysOfWeek: 0b01111111, // every day
},
}, nil
},
SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) {
panic("not implemented")
},
}
batcher = &statsBatcher{}
updateAgentMetricsFnCalled = false
req = &agentproto.UpdateStatsRequest{
Stats: &agentproto.Stats{
ConnectionsByProto: map[string]int64{
"tcp": 1,
"dean": 2,
},
ConnectionCount: 3,
},
}
)
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
StatsBatcher: batcher,
TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore),
AgentStatsRefreshInterval: 15 * time.Second,
UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) {
updateAgentMetricsFnCalled = true
assert.Equal(t, prometheusmetrics.AgentMetricLabels{
Username: user.Username,
WorkspaceName: workspace.Name,
AgentName: agent.Name,
TemplateName: template.Name,
}, labels)
assert.Equal(t, req.Stats.Metrics, metrics)
},
TimeNowFn: func() time.Time {
return now
},
}
// Workspace gets fetched.
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{
Workspace: workspace,
TemplateName: template.Name,
}, nil)
// We expect an activity bump because ConnectionCount > 0. However, the
// next autostart time will be set on the bump.
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
WorkspaceID: workspace.ID,
NextAutostart: nextAutostart,
}).Return(nil)
// Workspace last used at gets bumped.
dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{
ID: workspace.ID,
LastUsedAt: now,
}).Return(nil)
// User gets fetched to hit the UpdateAgentMetricsFn.
dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil)
resp, err := api.UpdateStats(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.UpdateStatsResponse{
ReportInterval: durationpb.New(15 * time.Second),
}, resp)
require.True(t, updateAgentMetricsFnCalled)
})
}
func templateScheduleStorePtr(store schedule.TemplateScheduleStore) *atomic.Pointer[schedule.TemplateScheduleStore] {
var ptr atomic.Pointer[schedule.TemplateScheduleStore]
ptr.Store(&store)
return &ptr
}

View File

@ -266,16 +266,25 @@ func convertDisplayApps(apps []database.DisplayApp) []codersdk.DisplayApp {
return dapps
}
func WorkspaceAgentEnvironment(workspaceAgent database.WorkspaceAgent) (map[string]string, error) {
var envs map[string]string
if workspaceAgent.EnvironmentVariables.Valid {
err := json.Unmarshal(workspaceAgent.EnvironmentVariables.RawMessage, &envs)
if err != nil {
return nil, xerrors.Errorf("unmarshal environment variables: %w", err)
}
}
return envs, nil
}
func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, scripts []codersdk.WorkspaceAgentScript, logSources []codersdk.WorkspaceAgentLogSource,
agentInactiveDisconnectTimeout time.Duration, agentFallbackTroubleshootingURL string,
) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal env vars: %w", err)
}
envs, err := WorkspaceAgentEnvironment(dbAgent)
if err != nil {
return codersdk.WorkspaceAgent{}, err
}
troubleshootingURL := agentFallbackTroubleshootingURL
if dbAgent.TroubleshootingURL != "" {

View File

@ -154,18 +154,24 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request)
// As this API becomes deprecated, use the new protobuf API and convert the
// types back to the SDK types.
manifestAPI := &agentapi.ManifestAPI{
AccessURL: api.AccessURL,
AppHostname: api.AppHostname,
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
ExternalAuthConfigs: api.ExternalAuthConfigs,
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
AccessURL: api.AccessURL,
AppHostname: api.AppHostname,
ExternalAuthConfigs: api.ExternalAuthConfigs,
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil },
Database: api.Database,
DerpMapFn: api.DERPMap,
TailnetCoordinator: &api.TailnetCoordinator,
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil },
WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) {
// Sadly this results in a double query, but it's only temporary for
// now.
ws, err := api.Database.GetWorkspaceByAgentID(ctx, wa.ID)
if err != nil {
return uuid.Nil, err
}
return ws.Workspace.ID, nil
},
Database: api.Database,
DerpMapFn: api.DERPMap,
}
manifest, err := manifestAPI.GetManifest(ctx, &agentproto.GetManifestRequest{})
if err != nil {

View File

@ -134,15 +134,13 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
PublishWorkspaceAgentLogsUpdateFn: api.publishWorkspaceAgentLogsUpdate,
AccessURL: api.AccessURL,
AppHostname: api.AppHostname,
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
AgentStatsRefreshInterval: api.AgentStatsRefreshInterval,
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
ExternalAuthConfigs: api.ExternalAuthConfigs,
AccessURL: api.AccessURL,
AppHostname: api.AppHostname,
AgentStatsRefreshInterval: api.AgentStatsRefreshInterval,
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
ExternalAuthConfigs: api.ExternalAuthConfigs,
// Optional:
WorkspaceID: build.WorkspaceID, // saves the extra lookup later