mirror of https://github.com/coder/coder.git
chore: add agentapi tests (#11269)
This commit is contained in:
parent
541154b74b
commit
29707099d7
|
@ -41,13 +41,14 @@ func ActivityBumpWorkspace(ctx context.Context, log slog.Logger, db database.Sto
|
|||
// low priority operations fail first.
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*15)
|
||||
defer cancel()
|
||||
if err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{
|
||||
err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{
|
||||
NextAutostart: nextAutostart.UTC(),
|
||||
WorkspaceID: workspaceID,
|
||||
}); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
|
||||
// Bump will fail if the context is canceled, but this is ok.
|
||||
log.Error(ctx, "bump failed", slog.Error(err),
|
||||
log.Error(ctx, "activity bump failed", slog.Error(err),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/batchstats"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
|
@ -61,19 +60,17 @@ type Options struct {
|
|||
DerpMapFn func() *tailcfg.DERPMap
|
||||
TailnetCoordinator *atomic.Pointer[tailnet.Coordinator]
|
||||
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
||||
StatsBatcher *batchstats.Batcher
|
||||
StatsBatcher StatsBatcher
|
||||
PublishWorkspaceUpdateFn func(ctx context.Context, workspaceID uuid.UUID)
|
||||
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
|
||||
|
||||
AccessURL *url.URL
|
||||
AppHostname string
|
||||
AgentInactiveDisconnectTimeout time.Duration
|
||||
AgentFallbackTroubleshootingURL string
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
DisableDirectConnections bool
|
||||
DerpForceWebSockets bool
|
||||
DerpMapUpdateFrequency time.Duration
|
||||
ExternalAuthConfigs []*externalauth.Config
|
||||
AccessURL *url.URL
|
||||
AppHostname string
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
DisableDirectConnections bool
|
||||
DerpForceWebSockets bool
|
||||
DerpMapUpdateFrequency time.Duration
|
||||
ExternalAuthConfigs []*externalauth.Config
|
||||
|
||||
// Optional:
|
||||
// WorkspaceID avoids a future lookup to find the workspace ID by setting
|
||||
|
@ -90,17 +87,14 @@ func New(opts Options) *API {
|
|||
}
|
||||
|
||||
api.ManifestAPI = &ManifestAPI{
|
||||
AccessURL: opts.AccessURL,
|
||||
AppHostname: opts.AppHostname,
|
||||
AgentInactiveDisconnectTimeout: opts.AgentInactiveDisconnectTimeout,
|
||||
AgentFallbackTroubleshootingURL: opts.AgentFallbackTroubleshootingURL,
|
||||
ExternalAuthConfigs: opts.ExternalAuthConfigs,
|
||||
DisableDirectConnections: opts.DisableDirectConnections,
|
||||
DerpForceWebSockets: opts.DerpForceWebSockets,
|
||||
AgentFn: api.agent,
|
||||
Database: opts.Database,
|
||||
DerpMapFn: opts.DerpMapFn,
|
||||
TailnetCoordinator: opts.TailnetCoordinator,
|
||||
AccessURL: opts.AccessURL,
|
||||
AppHostname: opts.AppHostname,
|
||||
ExternalAuthConfigs: opts.ExternalAuthConfigs,
|
||||
DisableDirectConnections: opts.DisableDirectConnections,
|
||||
DerpForceWebSockets: opts.DerpForceWebSockets,
|
||||
AgentFn: api.agent,
|
||||
Database: opts.Database,
|
||||
DerpMapFn: opts.DerpMapFn,
|
||||
}
|
||||
|
||||
api.ServiceBannerAPI = &ServiceBannerAPI{
|
||||
|
@ -214,20 +208,15 @@ func (a *API) workspaceID(ctx context.Context, agent *database.WorkspaceAgent) (
|
|||
agent = &agnt
|
||||
}
|
||||
|
||||
resource, err := a.opts.Database.GetWorkspaceResourceByID(ctx, agent.ResourceID)
|
||||
getWorkspaceAgentByIDRow, err := a.opts.Database.GetWorkspaceByAgentID(ctx, agent.ID)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("get workspace agent resource by id %q: %w", agent.ResourceID, err)
|
||||
}
|
||||
|
||||
build, err := a.opts.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("get workspace build by job id %q: %w", resource.JobID, err)
|
||||
return uuid.Nil, xerrors.Errorf("get workspace by agent id %q: %w", agent.ID, err)
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
a.cachedWorkspaceID = build.WorkspaceID
|
||||
a.cachedWorkspaceID = getWorkspaceAgentByIDRow.Workspace.ID
|
||||
a.mu.Unlock()
|
||||
return build.WorkspaceID, nil
|
||||
return getWorkspaceAgentByIDRow.Workspace.ID, nil
|
||||
}
|
||||
|
||||
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent) error {
|
||||
|
|
|
@ -90,9 +90,11 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
|
|||
}
|
||||
}
|
||||
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
}
|
||||
}
|
||||
return &agentproto.BatchUpdateAppHealthResponse{}, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,252 @@
|
|||
package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
)
|
||||
|
||||
func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
agent = database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
}
|
||||
app1 = database.WorkspaceApp{
|
||||
ID: uuid.New(),
|
||||
AgentID: agent.ID,
|
||||
Slug: "code-server-1",
|
||||
DisplayName: "code-server 1",
|
||||
HealthcheckUrl: "http://localhost:3000",
|
||||
Health: database.WorkspaceAppHealthInitializing,
|
||||
}
|
||||
app2 = database.WorkspaceApp{
|
||||
ID: uuid.New(),
|
||||
AgentID: agent.ID,
|
||||
Slug: "code-server-2",
|
||||
DisplayName: "code-server 2",
|
||||
HealthcheckUrl: "http://localhost:3001",
|
||||
Health: database.WorkspaceAppHealthHealthy,
|
||||
}
|
||||
)
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
|
||||
dbM.EXPECT().UpdateWorkspaceAppHealthByID(gomock.Any(), database.UpdateWorkspaceAppHealthByIDParams{
|
||||
ID: app1.ID,
|
||||
Health: database.WorkspaceAppHealthHealthy,
|
||||
}).Return(nil)
|
||||
dbM.EXPECT().UpdateWorkspaceAppHealthByID(gomock.Any(), database.UpdateWorkspaceAppHealthByIDParams{
|
||||
ID: app2.ID,
|
||||
Health: database.WorkspaceAppHealthUnhealthy,
|
||||
}).Return(nil)
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Set one to healthy, set another to unhealthy.
|
||||
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
|
||||
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
|
||||
{
|
||||
Id: app1.ID[:],
|
||||
Health: agentproto.AppHealth_HEALTHY,
|
||||
},
|
||||
{
|
||||
Id: app2.ID[:],
|
||||
Health: agentproto.AppHealth_UNHEALTHY,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp)
|
||||
|
||||
require.True(t, publishCalled)
|
||||
})
|
||||
|
||||
t.Run("Unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Set both to their current status, neither should be updated in the
|
||||
// DB.
|
||||
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
|
||||
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
|
||||
{
|
||||
Id: app1.ID[:],
|
||||
Health: agentproto.AppHealth_INITIALIZING,
|
||||
},
|
||||
{
|
||||
Id: app2.ID[:],
|
||||
Health: agentproto.AppHealth_HEALTHY,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp)
|
||||
|
||||
require.False(t, publishCalled)
|
||||
})
|
||||
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// No DB queries are made if there are no updates to process.
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Do nothing.
|
||||
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
|
||||
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp)
|
||||
|
||||
require.False(t, publishCalled)
|
||||
})
|
||||
|
||||
t.Run("AppNoHealthcheck", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app3 := database.WorkspaceApp{
|
||||
ID: uuid.New(),
|
||||
AgentID: agent.ID,
|
||||
Slug: "code-server-3",
|
||||
DisplayName: "code-server 3",
|
||||
}
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
|
||||
// Set app3 to healthy, should error.
|
||||
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
|
||||
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
|
||||
{
|
||||
Id: app3.ID[:],
|
||||
Health: agentproto.AppHealth_HEALTHY,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "does not have healthchecks enabled")
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
|
||||
t.Run("UnknownApp", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
|
||||
// Set an unknown app to healthy, should error.
|
||||
id := uuid.New()
|
||||
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
|
||||
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
|
||||
{
|
||||
Id: id[:],
|
||||
Health: agentproto.AppHealth_HEALTHY,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "not found")
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
|
||||
t.Run("InvalidHealth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
|
||||
// Set an unknown app to healthy, should error.
|
||||
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{
|
||||
Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{
|
||||
{
|
||||
Id: app1.ID[:],
|
||||
Health: -999,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "unknown health status")
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
}
|
|
@ -3,6 +3,7 @@ package agentapi
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/mod/semver"
|
||||
|
@ -21,6 +22,15 @@ type LifecycleAPI struct {
|
|||
Database database.Store
|
||||
Log slog.Logger
|
||||
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
}
|
||||
|
||||
func (a *LifecycleAPI) now() time.Time {
|
||||
if a.TimeNowFn != nil {
|
||||
return a.TimeNowFn()
|
||||
}
|
||||
return dbtime.Now()
|
||||
}
|
||||
|
||||
func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
|
||||
|
@ -68,7 +78,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
|
|||
|
||||
changedAt := req.Lifecycle.ChangedAt.AsTime()
|
||||
if changedAt.IsZero() {
|
||||
changedAt = dbtime.Now()
|
||||
changedAt = a.now()
|
||||
req.Lifecycle.ChangedAt = timestamppb.New(changedAt)
|
||||
}
|
||||
dbChangedAt := sql.NullTime{Time: changedAt, Valid: true}
|
||||
|
@ -78,8 +88,13 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
|
|||
switch lifecycleState {
|
||||
case database.WorkspaceAgentLifecycleStateStarting:
|
||||
startedAt = dbChangedAt
|
||||
readyAt.Valid = false // This agent is re-starting, so it's not ready yet.
|
||||
// This agent is (re)starting, so it's not ready yet.
|
||||
readyAt.Time = time.Time{}
|
||||
readyAt.Valid = false
|
||||
case database.WorkspaceAgentLifecycleStateReady, database.WorkspaceAgentLifecycleStateStartError:
|
||||
if !startedAt.Valid {
|
||||
startedAt = dbChangedAt
|
||||
}
|
||||
readyAt = dbChangedAt
|
||||
}
|
||||
|
||||
|
@ -97,9 +112,11 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
|
|||
return nil, xerrors.Errorf("update workspace agent lifecycle state: %w", err)
|
||||
}
|
||||
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
if a.PublishWorkspaceUpdateFn != nil {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return req.Lifecycle, nil
|
||||
|
|
|
@ -0,0 +1,461 @@
|
|||
package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
)
|
||||
|
||||
func TestUpdateLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
someTime, err := time.Parse(time.RFC3339, "2023-01-01T00:00:00Z")
|
||||
require.NoError(t, err)
|
||||
someTime = dbtime.Time(someTime)
|
||||
now := dbtime.Now()
|
||||
|
||||
var (
|
||||
workspaceID = uuid.New()
|
||||
agentCreated = database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
|
||||
StartedAt: sql.NullTime{Valid: false},
|
||||
ReadyAt: sql.NullTime{Valid: false},
|
||||
}
|
||||
agentStarting = database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
|
||||
StartedAt: sql.NullTime{Valid: true, Time: someTime},
|
||||
ReadyAt: sql.NullTime{Valid: false},
|
||||
}
|
||||
)
|
||||
|
||||
t.Run("OKStarting", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lifecycle := &agentproto.Lifecycle{
|
||||
State: agentproto.Lifecycle_STARTING,
|
||||
ChangedAt: timestamppb.New(now),
|
||||
}
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: agentCreated.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
|
||||
StartedAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
ReadyAt: sql.NullTime{Valid: false},
|
||||
}).Return(nil)
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agentCreated, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
|
||||
Lifecycle: lifecycle,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lifecycle, resp)
|
||||
require.True(t, publishCalled)
|
||||
})
|
||||
|
||||
t.Run("OKReadying", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lifecycle := &agentproto.Lifecycle{
|
||||
State: agentproto.Lifecycle_READY,
|
||||
ChangedAt: timestamppb.New(now),
|
||||
}
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: agentStarting.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
StartedAt: agentStarting.StartedAt,
|
||||
ReadyAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
}).Return(nil)
|
||||
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agentStarting, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
// Test that nil publish fn works.
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
|
||||
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
|
||||
Lifecycle: lifecycle,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lifecycle, resp)
|
||||
})
|
||||
|
||||
// This test jumps from CREATING to READY, skipping STARTED. Both the
|
||||
// StartedAt and ReadyAt fields should be set.
|
||||
t.Run("OKStraightToReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lifecycle := &agentproto.Lifecycle{
|
||||
State: agentproto.Lifecycle_READY,
|
||||
ChangedAt: timestamppb.New(now),
|
||||
}
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: agentCreated.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
StartedAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
ReadyAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
}).Return(nil)
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agentCreated, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
|
||||
Lifecycle: lifecycle,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lifecycle, resp)
|
||||
require.True(t, publishCalled)
|
||||
})
|
||||
|
||||
t.Run("NoTimeSpecified", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lifecycle := &agentproto.Lifecycle{
|
||||
State: agentproto.Lifecycle_READY,
|
||||
// Zero time
|
||||
ChangedAt: timestamppb.New(time.Time{}),
|
||||
}
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
now := dbtime.Now()
|
||||
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: agentCreated.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
StartedAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
ReadyAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agentCreated, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
|
||||
Lifecycle: lifecycle,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lifecycle, resp)
|
||||
})
|
||||
|
||||
t.Run("AllStates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
LifecycleState: database.WorkspaceAgentLifecycleState(""),
|
||||
StartedAt: sql.NullTime{Valid: false},
|
||||
ReadyAt: sql.NullTime{Valid: false},
|
||||
}
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
var publishCalled int64
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
|
||||
atomic.AddInt64(&publishCalled, 1)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
states := []agentproto.Lifecycle_State{
|
||||
agentproto.Lifecycle_CREATED,
|
||||
agentproto.Lifecycle_STARTING,
|
||||
agentproto.Lifecycle_START_TIMEOUT,
|
||||
agentproto.Lifecycle_START_ERROR,
|
||||
agentproto.Lifecycle_READY,
|
||||
agentproto.Lifecycle_SHUTTING_DOWN,
|
||||
agentproto.Lifecycle_SHUTDOWN_TIMEOUT,
|
||||
agentproto.Lifecycle_SHUTDOWN_ERROR,
|
||||
agentproto.Lifecycle_OFF,
|
||||
}
|
||||
for i, state := range states {
|
||||
t.Log("state", state)
|
||||
// Use a time after the last state change to ensure ordering.
|
||||
stateNow := now.Add(time.Hour * time.Duration(i))
|
||||
lifecycle := &agentproto.Lifecycle{
|
||||
State: state,
|
||||
ChangedAt: timestamppb.New(stateNow),
|
||||
}
|
||||
|
||||
expectedStartedAt := agent.StartedAt
|
||||
expectedReadyAt := agent.ReadyAt
|
||||
if state == agentproto.Lifecycle_STARTING {
|
||||
expectedStartedAt = sql.NullTime{Valid: true, Time: stateNow}
|
||||
}
|
||||
if state == agentproto.Lifecycle_READY || state == agentproto.Lifecycle_START_ERROR {
|
||||
expectedReadyAt = sql.NullTime{Valid: true, Time: stateNow}
|
||||
}
|
||||
|
||||
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: agent.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleState(strings.ToLower(state.String())),
|
||||
StartedAt: expectedStartedAt,
|
||||
ReadyAt: expectedReadyAt,
|
||||
}).Times(1).Return(nil)
|
||||
|
||||
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
|
||||
Lifecycle: lifecycle,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lifecycle, resp)
|
||||
require.Equal(t, int64(i+1), atomic.LoadInt64(&publishCalled))
|
||||
|
||||
// For future iterations:
|
||||
agent.StartedAt = expectedStartedAt
|
||||
agent.ReadyAt = expectedReadyAt
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UnknownLifecycleState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lifecycle := &agentproto.Lifecycle{
|
||||
State: -999,
|
||||
ChangedAt: timestamppb.New(now),
|
||||
}
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agentCreated, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
|
||||
Lifecycle: lifecycle,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "unknown lifecycle state")
|
||||
require.Nil(t, resp)
|
||||
require.False(t, publishCalled)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateStartup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
workspaceID = uuid.New()
|
||||
agent = database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
}
|
||||
)
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
// Not used by UpdateStartup.
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
|
||||
startup := &agentproto.Startup{
|
||||
Version: "v1.2.3",
|
||||
ExpandedDirectory: "/path/to/expanded/dir",
|
||||
Subsystems: []agentproto.Startup_Subsystem{
|
||||
agentproto.Startup_ENVBOX,
|
||||
agentproto.Startup_ENVBUILDER,
|
||||
agentproto.Startup_EXECTRACE,
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), database.UpdateWorkspaceAgentStartupByIDParams{
|
||||
ID: agent.ID,
|
||||
Version: startup.Version,
|
||||
ExpandedDirectory: startup.ExpandedDirectory,
|
||||
Subsystems: []database.WorkspaceAgentSubsystem{
|
||||
database.WorkspaceAgentSubsystemEnvbox,
|
||||
database.WorkspaceAgentSubsystemEnvbuilder,
|
||||
database.WorkspaceAgentSubsystemExectrace,
|
||||
},
|
||||
APIVersion: agentapi.AgentAPIVersionDRPC,
|
||||
}).Return(nil)
|
||||
|
||||
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
|
||||
Startup: startup,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, startup, resp)
|
||||
})
|
||||
|
||||
t.Run("BadVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
// Not used by UpdateStartup.
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
|
||||
startup := &agentproto.Startup{
|
||||
Version: "asdf",
|
||||
ExpandedDirectory: "/path/to/expanded/dir",
|
||||
Subsystems: []agentproto.Startup_Subsystem{},
|
||||
}
|
||||
|
||||
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
|
||||
Startup: startup,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "invalid agent semver version")
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
|
||||
t.Run("BadSubsystem", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspaceID, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
// Not used by UpdateStartup.
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
|
||||
startup := &agentproto.Startup{
|
||||
Version: "v1.2.3",
|
||||
ExpandedDirectory: "/path/to/expanded/dir",
|
||||
Subsystems: []agentproto.Startup_Subsystem{
|
||||
agentproto.Startup_ENVBOX,
|
||||
-999,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
|
||||
Startup: startup,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "invalid agent subsystem")
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
}
|
|
@ -2,6 +2,7 @@ package agentapi
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -19,6 +20,15 @@ type LogsAPI struct {
|
|||
Log slog.Logger
|
||||
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error
|
||||
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
}
|
||||
|
||||
func (a *LogsAPI) now() time.Time {
|
||||
if a.TimeNowFn != nil {
|
||||
return a.TimeNowFn()
|
||||
}
|
||||
return dbtime.Now()
|
||||
}
|
||||
|
||||
func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
|
||||
|
@ -26,6 +36,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if workspaceAgent.LogsOverflowed {
|
||||
return nil, xerrors.New("workspace agent logs overflowed")
|
||||
}
|
||||
|
||||
if len(req.Logs) == 0 {
|
||||
return &agentproto.BatchCreateLogsResponse{}, nil
|
||||
|
@ -42,7 +55,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
|||
// Use the external log source
|
||||
externalSources, err := a.Database.InsertWorkspaceAgentLogSources(ctx, database.InsertWorkspaceAgentLogSourcesParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
CreatedAt: dbtime.Now(),
|
||||
CreatedAt: a.now(),
|
||||
ID: []uuid.UUID{agentsdk.ExternalLogSourceID},
|
||||
DisplayName: []string{"External"},
|
||||
Icon: []string{"/emojis/1f310.png"},
|
||||
|
@ -88,7 +101,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
|||
|
||||
logs, err := a.Database.InsertWorkspaceAgentLogs(ctx, database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: workspaceAgent.ID,
|
||||
CreatedAt: dbtime.Now(),
|
||||
CreatedAt: a.now(),
|
||||
Output: output,
|
||||
Level: level,
|
||||
LogSourceID: logSourceID,
|
||||
|
@ -98,9 +111,6 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
|||
if !database.IsWorkspaceAgentLogsLimitError(err) {
|
||||
return nil, xerrors.Errorf("insert workspace agent logs: %w", err)
|
||||
}
|
||||
if workspaceAgent.LogsOverflowed {
|
||||
return nil, xerrors.New("workspace agent logs overflowed")
|
||||
}
|
||||
err := a.Database.UpdateWorkspaceAgentLogOverflowByID(ctx, database.UpdateWorkspaceAgentLogOverflowByIDParams{
|
||||
ID: workspaceAgent.ID,
|
||||
LogsOverflowed: true,
|
||||
|
@ -112,21 +122,25 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
|||
a.Log.Warn(ctx, "failed to update workspace agent log overflow", slog.Error(err))
|
||||
}
|
||||
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
if a.PublishWorkspaceUpdateFn != nil {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
}
|
||||
}
|
||||
return nil, xerrors.New("workspace agent log limit exceeded")
|
||||
}
|
||||
|
||||
// Publish by the lowest log ID inserted so the log stream will fetch
|
||||
// everything from that point.
|
||||
lowestLogID := logs[0].ID
|
||||
a.PublishWorkspaceAgentLogsUpdateFn(ctx, workspaceAgent.ID, agentsdk.LogsNotifyMessage{
|
||||
CreatedAfter: lowestLogID - 1,
|
||||
})
|
||||
if a.PublishWorkspaceAgentLogsUpdateFn != nil {
|
||||
lowestLogID := logs[0].ID
|
||||
a.PublishWorkspaceAgentLogsUpdateFn(ctx, workspaceAgent.ID, agentsdk.LogsNotifyMessage{
|
||||
CreatedAfter: lowestLogID - 1,
|
||||
})
|
||||
}
|
||||
|
||||
if workspaceAgent.LogsLength == 0 {
|
||||
if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil {
|
||||
// If these are the first logs being appended, we publish a UI update
|
||||
// to notify the UI that logs are now available.
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
|
||||
|
|
|
@ -0,0 +1,427 @@
|
|||
package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func TestBatchCreateLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
agent = database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
}
|
||||
logSource = database.WorkspaceAgentLogSource{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
CreatedAt: dbtime.Now(),
|
||||
ID: uuid.New(),
|
||||
}
|
||||
)
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
publishWorkspaceUpdateCalled := false
|
||||
publishWorkspaceAgentLogsUpdateCalled := false
|
||||
now := dbtime.Now()
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
|
||||
publishWorkspaceAgentLogsUpdateCalled = true
|
||||
|
||||
// Check the message content, should be for -1 since the lowest
|
||||
// log we inserted was 0.
|
||||
assert.Equal(t, agentsdk.LogsNotifyMessage{CreatedAfter: -1}, msg)
|
||||
},
|
||||
TimeNowFn: func() time.Time { return now },
|
||||
}
|
||||
|
||||
req := &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(now),
|
||||
Level: agentproto.Log_TRACE,
|
||||
Output: "log line 1",
|
||||
},
|
||||
{
|
||||
CreatedAt: timestamppb.New(now.Add(time.Hour)),
|
||||
Level: agentproto.Log_DEBUG,
|
||||
Output: "log line 2",
|
||||
},
|
||||
{
|
||||
CreatedAt: timestamppb.New(now.Add(2 * time.Hour)),
|
||||
Level: agentproto.Log_INFO,
|
||||
Output: "log line 3",
|
||||
},
|
||||
{
|
||||
CreatedAt: timestamppb.New(now.Add(3 * time.Hour)),
|
||||
Level: agentproto.Log_WARN,
|
||||
Output: "log line 4",
|
||||
},
|
||||
{
|
||||
CreatedAt: timestamppb.New(now.Add(4 * time.Hour)),
|
||||
Level: agentproto.Log_ERROR,
|
||||
Output: "log line 5",
|
||||
},
|
||||
{
|
||||
CreatedAt: timestamppb.New(now.Add(5 * time.Hour)),
|
||||
Level: -999, // defaults to INFO
|
||||
Output: "log line 6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Craft expected DB request and response dynamically.
|
||||
insertWorkspaceAgentLogsParams := database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: agent.ID,
|
||||
LogSourceID: logSource.ID,
|
||||
CreatedAt: now,
|
||||
Output: make([]string, len(req.Logs)),
|
||||
Level: make([]database.LogLevel, len(req.Logs)),
|
||||
OutputLength: 0,
|
||||
}
|
||||
insertWorkspaceAgentLogsReturn := make([]database.WorkspaceAgentLog, len(req.Logs))
|
||||
for i, logEntry := range req.Logs {
|
||||
insertWorkspaceAgentLogsParams.Output[i] = logEntry.Output
|
||||
level := database.LogLevelInfo
|
||||
if logEntry.Level >= 0 {
|
||||
level = database.LogLevel(strings.ToLower(logEntry.Level.String()))
|
||||
}
|
||||
insertWorkspaceAgentLogsParams.Level[i] = level
|
||||
insertWorkspaceAgentLogsParams.OutputLength += int32(len(logEntry.Output))
|
||||
|
||||
insertWorkspaceAgentLogsReturn[i] = database.WorkspaceAgentLog{
|
||||
AgentID: agent.ID,
|
||||
CreatedAt: logEntry.CreatedAt.AsTime(),
|
||||
ID: int64(i),
|
||||
Output: logEntry.Output,
|
||||
Level: insertWorkspaceAgentLogsParams.Level[i],
|
||||
LogSourceID: logSource.ID,
|
||||
}
|
||||
}
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), insertWorkspaceAgentLogsParams).Return(insertWorkspaceAgentLogsReturn, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
require.True(t, publishWorkspaceUpdateCalled)
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
agentWithLogs := agent
|
||||
agentWithLogs.LogsLength = 1
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
publishWorkspaceUpdateCalled := false
|
||||
publishWorkspaceAgentLogsUpdateCalled := false
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agentWithLogs, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
|
||||
publishWorkspaceAgentLogsUpdateCalled = true
|
||||
},
|
||||
}
|
||||
|
||||
// Don't really care about the DB call.
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), gomock.Any()).Return([]database.WorkspaceAgentLog{
|
||||
{
|
||||
ID: 1,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(dbtime.Now()),
|
||||
Level: agentproto.Log_INFO,
|
||||
Output: "hello world",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
require.False(t, publishWorkspaceUpdateCalled)
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("AlreadyOverflowed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
overflowedAgent := agent
|
||||
overflowedAgent.LogsOverflowed = true
|
||||
|
||||
publishWorkspaceUpdateCalled := false
|
||||
publishWorkspaceAgentLogsUpdateCalled := false
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return overflowedAgent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
|
||||
publishWorkspaceAgentLogsUpdateCalled = true
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "workspace agent logs overflowed")
|
||||
require.Nil(t, resp)
|
||||
require.False(t, publishWorkspaceUpdateCalled)
|
||||
require.False(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("InvalidLogSourceID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
// Test that they are ignored when nil.
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
PublishWorkspaceAgentLogsUpdateFn: nil,
|
||||
}
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: []byte("invalid"),
|
||||
Logs: []*agentproto.Log{
|
||||
{}, // need at least 1 log
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "parse log source ID")
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
|
||||
t.Run("UseExternalLogSourceID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := dbtime.Now()
|
||||
req := &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: uuid.Nil[:], // defaults to "external"
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(now),
|
||||
Level: agentproto.Log_INFO,
|
||||
Output: "hello world",
|
||||
},
|
||||
},
|
||||
}
|
||||
dbInsertParams := database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: agent.ID,
|
||||
LogSourceID: agentsdk.ExternalLogSourceID,
|
||||
CreatedAt: now,
|
||||
Output: []string{"hello world"},
|
||||
Level: []database.LogLevel{database.LogLevelInfo},
|
||||
OutputLength: int32(len(req.Logs[0].Output)),
|
||||
}
|
||||
dbInsertRes := []database.WorkspaceAgentLog{
|
||||
{
|
||||
AgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: 1,
|
||||
Output: "hello world",
|
||||
Level: database.LogLevelInfo,
|
||||
LogSourceID: agentsdk.ExternalLogSourceID,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
publishWorkspaceUpdateCalled := false
|
||||
publishWorkspaceAgentLogsUpdateCalled := false
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
|
||||
publishWorkspaceAgentLogsUpdateCalled = true
|
||||
},
|
||||
TimeNowFn: func() time.Time { return now },
|
||||
}
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), database.InsertWorkspaceAgentLogSourcesParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: []uuid.UUID{agentsdk.ExternalLogSourceID},
|
||||
DisplayName: []string{"External"},
|
||||
Icon: []string{"/emojis/1f310.png"},
|
||||
}).Return([]database.WorkspaceAgentLogSource{
|
||||
{
|
||||
// only the ID field is used
|
||||
ID: agentsdk.ExternalLogSourceID,
|
||||
},
|
||||
}, nil)
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), dbInsertParams).Return(dbInsertRes, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
require.True(t, publishWorkspaceUpdateCalled)
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
publishWorkspaceUpdateCalled := false
|
||||
publishWorkspaceAgentLogsUpdateCalled := false
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
|
||||
publishWorkspaceAgentLogsUpdateCalled = true
|
||||
},
|
||||
TimeNowFn: func() time.Time { return now },
|
||||
}
|
||||
|
||||
// Return a unique violation error to simulate the log source
|
||||
// already existing. This should be handled gracefully.
|
||||
logSourceInsertErr := &pq.Error{
|
||||
Code: pq.ErrorCode("23505"), // unique_violation
|
||||
Constraint: string(database.UniqueWorkspaceAgentLogSourcesPkey),
|
||||
}
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), database.InsertWorkspaceAgentLogSourcesParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: []uuid.UUID{agentsdk.ExternalLogSourceID},
|
||||
DisplayName: []string{"External"},
|
||||
Icon: []string{"/emojis/1f310.png"},
|
||||
}).Return([]database.WorkspaceAgentLogSource{}, logSourceInsertErr)
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), dbInsertParams).Return(dbInsertRes, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
require.True(t, publishWorkspaceUpdateCalled)
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Overflow", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
publishWorkspaceUpdateCalled := false
|
||||
publishWorkspaceAgentLogsUpdateCalled := false
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: slogtest.Make(t, nil),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) {
|
||||
publishWorkspaceAgentLogsUpdateCalled = true
|
||||
},
|
||||
}
|
||||
|
||||
// Don't really care about the DB call params, just want to return an
|
||||
// error.
|
||||
dbErr := &pq.Error{
|
||||
Constraint: "max_logs_length",
|
||||
Table: "workspace_agents",
|
||||
}
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), gomock.Any()).Return(nil, dbErr)
|
||||
|
||||
// Should also update the workspace agent.
|
||||
dbM.EXPECT().UpdateWorkspaceAgentLogOverflowByID(gomock.Any(), database.UpdateWorkspaceAgentLogOverflowByIDParams{
|
||||
ID: agent.ID,
|
||||
LogsOverflowed: true,
|
||||
}).Return(nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(dbtime.Now()),
|
||||
Level: agentproto.Log_INFO,
|
||||
Output: "hello world",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, resp)
|
||||
require.True(t, publishWorkspaceUpdateCalled)
|
||||
require.False(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
}
|
|
@ -5,7 +5,6 @@ import (
|
|||
"database/sql"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -25,18 +24,16 @@ import (
|
|||
)
|
||||
|
||||
type ManifestAPI struct {
|
||||
AccessURL *url.URL
|
||||
AppHostname string
|
||||
AgentInactiveDisconnectTimeout time.Duration
|
||||
AgentFallbackTroubleshootingURL string
|
||||
ExternalAuthConfigs []*externalauth.Config
|
||||
DisableDirectConnections bool
|
||||
DerpForceWebSockets bool
|
||||
AccessURL *url.URL
|
||||
AppHostname string
|
||||
ExternalAuthConfigs []*externalauth.Config
|
||||
DisableDirectConnections bool
|
||||
DerpForceWebSockets bool
|
||||
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Database database.Store
|
||||
DerpMapFn func() *tailcfg.DERPMap
|
||||
TailnetCoordinator *atomic.Pointer[tailnet.Coordinator]
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
|
||||
Database database.Store
|
||||
DerpMapFn func() *tailcfg.DERPMap
|
||||
}
|
||||
|
||||
func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
|
||||
|
@ -44,21 +41,15 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
apiAgent, err := db2sdk.WorkspaceAgent(
|
||||
a.DerpMapFn(), *a.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, a.AgentInactiveDisconnectTimeout,
|
||||
a.AgentFallbackTroubleshootingURL,
|
||||
)
|
||||
workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("converting workspace agent: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
dbApps []database.WorkspaceApp
|
||||
scripts []database.WorkspaceAgentScript
|
||||
metadata []database.WorkspaceAgentMetadatum
|
||||
resource database.WorkspaceResource
|
||||
build database.WorkspaceBuild
|
||||
workspace database.Workspace
|
||||
owner database.User
|
||||
)
|
||||
|
@ -79,20 +70,12 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
|
|||
eg.Go(func() (err error) {
|
||||
metadata, err = a.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
Keys: nil,
|
||||
Keys: nil, // all
|
||||
})
|
||||
return err
|
||||
})
|
||||
eg.Go(func() (err error) {
|
||||
resource, err = a.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("getting resource by id: %w", err)
|
||||
}
|
||||
build, err = a.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("getting workspace build by job id: %w", err)
|
||||
}
|
||||
workspace, err = a.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
|
||||
workspace, err = a.Database.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("getting workspace by id: %w", err)
|
||||
}
|
||||
|
@ -116,6 +99,11 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
|
|||
|
||||
vscodeProxyURI := vscodeProxyURI(appSlug, a.AccessURL, a.AppHostname)
|
||||
|
||||
envs, err := db2sdk.WorkspaceAgentEnvironment(workspaceAgent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var gitAuthConfigs uint32
|
||||
for _, cfg := range a.ExternalAuthConfigs {
|
||||
if codersdk.EnhancedExternalAuthProvider(cfg.Type).Git() {
|
||||
|
@ -135,8 +123,8 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
|
|||
WorkspaceId: workspace.ID[:],
|
||||
WorkspaceName: workspace.Name,
|
||||
GitAuthConfigs: gitAuthConfigs,
|
||||
EnvironmentVariables: apiAgent.EnvironmentVariables,
|
||||
Directory: apiAgent.Directory,
|
||||
EnvironmentVariables: envs,
|
||||
Directory: workspaceAgent.Directory,
|
||||
VsCodePortProxyUri: vscodeProxyURI,
|
||||
MotdPath: workspaceAgent.MOTDFile,
|
||||
DisableDirectConnections: a.DisableDirectConnections,
|
||||
|
|
|
@ -0,0 +1,396 @@
|
|||
package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
)
|
||||
|
||||
func TestGetManifest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
someTime, err := time.Parse(time.RFC3339, "2023-01-01T00:00:00Z")
|
||||
require.NoError(t, err)
|
||||
someTime = dbtime.Time(someTime)
|
||||
|
||||
expectedEnvVars := map[string]string{
|
||||
"FOO": "bar",
|
||||
"COOL_ENV": "dean was here",
|
||||
}
|
||||
expectedEnvVarsJSON, err := json.Marshal(expectedEnvVars)
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
owner = database.User{
|
||||
ID: uuid.New(),
|
||||
Username: "cool-user",
|
||||
}
|
||||
workspace = database.Workspace{
|
||||
ID: uuid.New(),
|
||||
OwnerID: owner.ID,
|
||||
Name: "cool-workspace",
|
||||
}
|
||||
agent = database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
Name: "cool-agent",
|
||||
EnvironmentVariables: pqtype.NullRawMessage{
|
||||
RawMessage: expectedEnvVarsJSON,
|
||||
Valid: true,
|
||||
},
|
||||
Directory: "/cool/dir",
|
||||
MOTDFile: "/cool/motd",
|
||||
}
|
||||
apps = []database.WorkspaceApp{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Url: sql.NullString{String: "http://localhost:1234", Valid: true},
|
||||
External: false,
|
||||
Slug: "cool-app-1",
|
||||
DisplayName: "app 1",
|
||||
Command: sql.NullString{String: "cool command", Valid: true},
|
||||
Icon: "/icon.png",
|
||||
Subdomain: true,
|
||||
SharingLevel: database.AppSharingLevelAuthenticated,
|
||||
Health: database.WorkspaceAppHealthHealthy,
|
||||
HealthcheckUrl: "http://localhost:1234/health",
|
||||
HealthcheckInterval: 10,
|
||||
HealthcheckThreshold: 3,
|
||||
},
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Url: sql.NullString{String: "http://google.com", Valid: true},
|
||||
External: true,
|
||||
Slug: "google",
|
||||
DisplayName: "Literally Google",
|
||||
Command: sql.NullString{Valid: false},
|
||||
Icon: "/google.png",
|
||||
Subdomain: false,
|
||||
SharingLevel: database.AppSharingLevelPublic,
|
||||
Health: database.WorkspaceAppHealthDisabled,
|
||||
},
|
||||
{
|
||||
ID: uuid.New(),
|
||||
Url: sql.NullString{String: "http://localhost:4321", Valid: true},
|
||||
External: true,
|
||||
Slug: "cool-app-2",
|
||||
DisplayName: "another COOL app",
|
||||
Command: sql.NullString{Valid: false},
|
||||
Icon: "",
|
||||
Subdomain: false,
|
||||
SharingLevel: database.AppSharingLevelOwner,
|
||||
Health: database.WorkspaceAppHealthUnhealthy,
|
||||
HealthcheckUrl: "http://localhost:4321/health",
|
||||
HealthcheckInterval: 20,
|
||||
HealthcheckThreshold: 5,
|
||||
},
|
||||
}
|
||||
scripts = []database.WorkspaceAgentScript{
|
||||
{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
LogSourceID: uuid.New(),
|
||||
LogPath: "/cool/log/path/1",
|
||||
Script: "cool script 1",
|
||||
Cron: "30 2 * * *",
|
||||
StartBlocksLogin: true,
|
||||
RunOnStart: true,
|
||||
RunOnStop: false,
|
||||
TimeoutSeconds: 60,
|
||||
},
|
||||
{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
LogSourceID: uuid.New(),
|
||||
LogPath: "/cool/log/path/2",
|
||||
Script: "cool script 2",
|
||||
Cron: "",
|
||||
StartBlocksLogin: false,
|
||||
RunOnStart: false,
|
||||
RunOnStop: true,
|
||||
TimeoutSeconds: 30,
|
||||
},
|
||||
}
|
||||
metadata = []database.WorkspaceAgentMetadatum{
|
||||
{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
DisplayName: "cool metadata 1",
|
||||
Key: "cool-key-1",
|
||||
Script: "cool script 1",
|
||||
Value: "cool value 1",
|
||||
Error: "",
|
||||
Timeout: int64(time.Minute),
|
||||
Interval: int64(time.Minute),
|
||||
CollectedAt: someTime,
|
||||
},
|
||||
{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
DisplayName: "cool metadata 2",
|
||||
Key: "cool-key-2",
|
||||
Script: "cool script 2",
|
||||
Value: "cool value 2",
|
||||
Error: "some uncool error",
|
||||
Timeout: int64(5 * time.Second),
|
||||
Interval: int64(20 * time.Minute),
|
||||
CollectedAt: someTime.Add(time.Hour),
|
||||
},
|
||||
}
|
||||
derpMapFn = func() *tailcfg.DERPMap {
|
||||
return &tailcfg.DERPMap{
|
||||
Regions: map[int]*tailcfg.DERPRegion{
|
||||
1: {RegionName: "cool region"},
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// These are done manually to ensure the conversion logic matches what a
|
||||
// human expects.
|
||||
var (
|
||||
protoApps = []*agentproto.WorkspaceApp{
|
||||
{
|
||||
Id: apps[0].ID[:],
|
||||
Url: apps[0].Url.String,
|
||||
External: apps[0].External,
|
||||
Slug: apps[0].Slug,
|
||||
DisplayName: apps[0].DisplayName,
|
||||
Command: apps[0].Command.String,
|
||||
Icon: apps[0].Icon,
|
||||
Subdomain: apps[0].Subdomain,
|
||||
SubdomainName: fmt.Sprintf("%s--%s--%s--%s", apps[0].Slug, agent.Name, workspace.Name, owner.Username),
|
||||
SharingLevel: agentproto.WorkspaceApp_AUTHENTICATED,
|
||||
Healthcheck: &agentproto.WorkspaceApp_Healthcheck{
|
||||
Url: apps[0].HealthcheckUrl,
|
||||
Interval: durationpb.New(time.Duration(apps[0].HealthcheckInterval) * time.Second),
|
||||
Threshold: apps[0].HealthcheckThreshold,
|
||||
},
|
||||
Health: agentproto.WorkspaceApp_HEALTHY,
|
||||
},
|
||||
{
|
||||
Id: apps[1].ID[:],
|
||||
Url: apps[1].Url.String,
|
||||
External: apps[1].External,
|
||||
Slug: apps[1].Slug,
|
||||
DisplayName: apps[1].DisplayName,
|
||||
Command: apps[1].Command.String,
|
||||
Icon: apps[1].Icon,
|
||||
Subdomain: false,
|
||||
SubdomainName: "",
|
||||
SharingLevel: agentproto.WorkspaceApp_PUBLIC,
|
||||
Healthcheck: &agentproto.WorkspaceApp_Healthcheck{
|
||||
Url: "",
|
||||
Interval: durationpb.New(0),
|
||||
Threshold: 0,
|
||||
},
|
||||
Health: agentproto.WorkspaceApp_DISABLED,
|
||||
},
|
||||
{
|
||||
Id: apps[2].ID[:],
|
||||
Url: apps[2].Url.String,
|
||||
External: apps[2].External,
|
||||
Slug: apps[2].Slug,
|
||||
DisplayName: apps[2].DisplayName,
|
||||
Command: apps[2].Command.String,
|
||||
Icon: apps[2].Icon,
|
||||
Subdomain: false,
|
||||
SubdomainName: "",
|
||||
SharingLevel: agentproto.WorkspaceApp_OWNER,
|
||||
Healthcheck: &agentproto.WorkspaceApp_Healthcheck{
|
||||
Url: apps[2].HealthcheckUrl,
|
||||
Interval: durationpb.New(time.Duration(apps[2].HealthcheckInterval) * time.Second),
|
||||
Threshold: apps[2].HealthcheckThreshold,
|
||||
},
|
||||
Health: agentproto.WorkspaceApp_UNHEALTHY,
|
||||
},
|
||||
}
|
||||
protoScripts = []*agentproto.WorkspaceAgentScript{
|
||||
{
|
||||
LogSourceId: scripts[0].LogSourceID[:],
|
||||
LogPath: scripts[0].LogPath,
|
||||
Script: scripts[0].Script,
|
||||
Cron: scripts[0].Cron,
|
||||
RunOnStart: scripts[0].RunOnStart,
|
||||
RunOnStop: scripts[0].RunOnStop,
|
||||
StartBlocksLogin: scripts[0].StartBlocksLogin,
|
||||
Timeout: durationpb.New(time.Duration(scripts[0].TimeoutSeconds) * time.Second),
|
||||
},
|
||||
{
|
||||
LogSourceId: scripts[1].LogSourceID[:],
|
||||
LogPath: scripts[1].LogPath,
|
||||
Script: scripts[1].Script,
|
||||
Cron: scripts[1].Cron,
|
||||
RunOnStart: scripts[1].RunOnStart,
|
||||
RunOnStop: scripts[1].RunOnStop,
|
||||
StartBlocksLogin: scripts[1].StartBlocksLogin,
|
||||
Timeout: durationpb.New(time.Duration(scripts[1].TimeoutSeconds) * time.Second),
|
||||
},
|
||||
}
|
||||
protoMetadata = []*agentproto.WorkspaceAgentMetadata_Description{
|
||||
{
|
||||
DisplayName: metadata[0].DisplayName,
|
||||
Key: metadata[0].Key,
|
||||
Script: metadata[0].Script,
|
||||
Interval: durationpb.New(time.Duration(metadata[0].Interval)),
|
||||
Timeout: durationpb.New(time.Duration(metadata[0].Timeout)),
|
||||
},
|
||||
{
|
||||
DisplayName: metadata[1].DisplayName,
|
||||
Key: metadata[1].Key,
|
||||
Script: metadata[1].Script,
|
||||
Interval: durationpb.New(time.Duration(metadata[1].Interval)),
|
||||
Timeout: durationpb.New(time.Duration(metadata[1].Timeout)),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mDB := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
api := &agentapi.ManifestAPI{
|
||||
AccessURL: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
AppHostname: "*--apps.example.com",
|
||||
ExternalAuthConfigs: []*externalauth.Config{
|
||||
{Type: string(codersdk.EnhancedExternalAuthProviderGitHub)},
|
||||
{Type: "some-provider"},
|
||||
{Type: string(codersdk.EnhancedExternalAuthProviderGitLab)},
|
||||
},
|
||||
DisableDirectConnections: true,
|
||||
DerpForceWebSockets: true,
|
||||
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspace.ID, nil
|
||||
},
|
||||
Database: mDB,
|
||||
DerpMapFn: derpMapFn,
|
||||
}
|
||||
|
||||
mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil)
|
||||
mDB.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), []uuid.UUID{agent.ID}).Return(scripts, nil)
|
||||
mDB.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), database.GetWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Keys: nil, // all
|
||||
}).Return(metadata, nil)
|
||||
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
|
||||
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
|
||||
|
||||
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := &agentproto.Manifest{
|
||||
AgentId: agent.ID[:],
|
||||
AgentName: agent.Name,
|
||||
OwnerUsername: owner.Username,
|
||||
WorkspaceId: workspace.ID[:],
|
||||
WorkspaceName: workspace.Name,
|
||||
GitAuthConfigs: 2, // two "enhanced" external auth configs
|
||||
EnvironmentVariables: expectedEnvVars,
|
||||
Directory: agent.Directory,
|
||||
VsCodePortProxyUri: fmt.Sprintf("https://{{port}}--%s--%s--%s--apps.example.com", agent.Name, workspace.Name, owner.Username),
|
||||
MotdPath: agent.MOTDFile,
|
||||
DisableDirectConnections: true,
|
||||
DerpForceWebsockets: true,
|
||||
// tailnet.DERPMapToProto() is extensively tested elsewhere, so it's
|
||||
// not necessary to manually recreate a big DERP map here like we
|
||||
// did for apps and metadata.
|
||||
DerpMap: tailnet.DERPMapToProto(derpMapFn()),
|
||||
Scripts: protoScripts,
|
||||
Apps: protoApps,
|
||||
Metadata: protoMetadata,
|
||||
}
|
||||
|
||||
// Log got and expected with spew.
|
||||
// t.Log("got:\n" + spew.Sdump(got))
|
||||
// t.Log("expected:\n" + spew.Sdump(expected))
|
||||
|
||||
require.Equal(t, expected, got)
|
||||
})
|
||||
|
||||
t.Run("NoAppHostname", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mDB := dbmock.NewMockStore(gomock.NewController(t))
|
||||
|
||||
api := &agentapi.ManifestAPI{
|
||||
AccessURL: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
AppHostname: "",
|
||||
ExternalAuthConfigs: []*externalauth.Config{
|
||||
{Type: string(codersdk.EnhancedExternalAuthProviderGitHub)},
|
||||
{Type: "some-provider"},
|
||||
{Type: string(codersdk.EnhancedExternalAuthProviderGitLab)},
|
||||
},
|
||||
DisableDirectConnections: true,
|
||||
DerpForceWebSockets: true,
|
||||
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
return workspace.ID, nil
|
||||
},
|
||||
Database: mDB,
|
||||
DerpMapFn: derpMapFn,
|
||||
}
|
||||
|
||||
mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil)
|
||||
mDB.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), []uuid.UUID{agent.ID}).Return(scripts, nil)
|
||||
mDB.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), database.GetWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Keys: nil, // all
|
||||
}).Return(metadata, nil)
|
||||
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
|
||||
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
|
||||
|
||||
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := &agentproto.Manifest{
|
||||
AgentId: agent.ID[:],
|
||||
AgentName: agent.Name,
|
||||
OwnerUsername: owner.Username,
|
||||
WorkspaceId: workspace.ID[:],
|
||||
WorkspaceName: workspace.Name,
|
||||
GitAuthConfigs: 2, // two "enhanced" external auth configs
|
||||
EnvironmentVariables: expectedEnvVars,
|
||||
Directory: agent.Directory,
|
||||
VsCodePortProxyUri: "", // empty with no AppHost
|
||||
MotdPath: agent.MOTDFile,
|
||||
DisableDirectConnections: true,
|
||||
DerpForceWebsockets: true,
|
||||
// tailnet.DERPMapToProto() is extensively tested elsewhere, so it's
|
||||
// not necessary to manually recreate a big DERP map here like we
|
||||
// did for apps and metadata.
|
||||
DerpMap: tailnet.DERPMapToProto(derpMapFn()),
|
||||
Scripts: protoScripts,
|
||||
Apps: protoApps,
|
||||
Metadata: protoMetadata,
|
||||
}
|
||||
|
||||
// Log got and expected with spew.
|
||||
// t.Log("got:\n" + spew.Sdump(got))
|
||||
// t.Log("expected:\n" + spew.Sdump(expected))
|
||||
|
||||
require.Equal(t, expected, got)
|
||||
})
|
||||
}
|
|
@ -12,6 +12,7 @@ import (
|
|||
"cdr.dev/slog"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
|
@ -20,14 +21,26 @@ type MetadataAPI struct {
|
|||
Database database.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
Log slog.Logger
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
}
|
||||
|
||||
func (a *MetadataAPI) now() time.Time {
|
||||
if a.TimeNowFn != nil {
|
||||
return a.TimeNowFn()
|
||||
}
|
||||
return dbtime.Now()
|
||||
}
|
||||
|
||||
func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
|
||||
const (
|
||||
// maxValueLen is set to 2048 to stay under the 8000 byte Postgres
|
||||
// NOTIFY limit. Since both value and error can be set, the real payload
|
||||
// limit is 2 * 2048 * 4/3 <base64 expansion> = 5461 bytes + a few
|
||||
// hundred bytes for JSON syntax, key names, and metadata.
|
||||
// maxAllKeysLen is the maximum length of all metadata keys. This is
|
||||
// 6144 to stay below the Postgres NOTIFY limit of 8000 bytes, with some
|
||||
// headway for the timestamp and JSON encoding. Any values that would
|
||||
// exceed this limit are discarded (the rest are still inserted) and an
|
||||
// error is returned.
|
||||
maxAllKeysLen = 6144 // 1024 * 6
|
||||
|
||||
maxValueLen = 2048
|
||||
maxErrorLen = maxValueLen
|
||||
)
|
||||
|
@ -37,18 +50,36 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
|||
return nil, err
|
||||
}
|
||||
|
||||
collectedAt := time.Now()
|
||||
dbUpdate := database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
Key: make([]string, 0, len(req.Metadata)),
|
||||
Value: make([]string, 0, len(req.Metadata)),
|
||||
Error: make([]string, 0, len(req.Metadata)),
|
||||
CollectedAt: make([]time.Time, 0, len(req.Metadata)),
|
||||
}
|
||||
|
||||
var (
|
||||
collectedAt = a.now()
|
||||
allKeysLen = 0
|
||||
dbUpdate = database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
// These need to be `make(x, 0, len(req.Metadata))` instead of
|
||||
// `make(x, len(req.Metadata))` because we may not insert all
|
||||
// metadata if the keys are large.
|
||||
Key: make([]string, 0, len(req.Metadata)),
|
||||
Value: make([]string, 0, len(req.Metadata)),
|
||||
Error: make([]string, 0, len(req.Metadata)),
|
||||
CollectedAt: make([]time.Time, 0, len(req.Metadata)),
|
||||
}
|
||||
)
|
||||
for _, md := range req.Metadata {
|
||||
metadataError := md.Result.Error
|
||||
|
||||
allKeysLen += len(md.Key)
|
||||
if allKeysLen > maxAllKeysLen {
|
||||
// We still insert the rest of the metadata, and we return an error
|
||||
// after the insert.
|
||||
a.Log.Warn(
|
||||
ctx, "discarded extra agent metadata due to excessive key length",
|
||||
slog.F("collected_at", collectedAt),
|
||||
slog.F("all_keys_len", allKeysLen),
|
||||
slog.F("max_all_keys_len", maxAllKeysLen),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
// We overwrite the error if the provided payload is too long.
|
||||
if len(md.Result.Value) > maxValueLen {
|
||||
metadataError = fmt.Sprintf("value of %d bytes exceeded %d bytes", len(md.Result.Value), maxValueLen)
|
||||
|
@ -71,12 +102,16 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
|||
a.Log.Debug(
|
||||
ctx, "accepted metadata report",
|
||||
slog.F("collected_at", collectedAt),
|
||||
slog.F("original_collected_at", collectedAt),
|
||||
slog.F("key", md.Key),
|
||||
slog.F("value", ellipse(md.Result.Value, 16)),
|
||||
)
|
||||
}
|
||||
|
||||
err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: collectedAt,
|
||||
Keys: dbUpdate.Key,
|
||||
|
@ -84,17 +119,17 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal workspace agent metadata channel payload: %w", err)
|
||||
}
|
||||
|
||||
err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
|
||||
}
|
||||
|
||||
err = a.Pubsub.Publish(WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace agent metadata: %w", err)
|
||||
}
|
||||
|
||||
// If the metadata keys were too large, we return an error so the agent can
|
||||
// log it.
|
||||
if allKeysLen > maxAllKeysLen {
|
||||
return nil, xerrors.Errorf("metadata keys of %d bytes exceeded %d bytes", allKeysLen, maxAllKeysLen)
|
||||
}
|
||||
|
||||
return &agentproto.BatchUpdateMetadataResponse{}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,276 @@
|
|||
package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
type fakePublisher struct {
|
||||
// Nil pointer to pass interface check.
|
||||
pubsub.Pubsub
|
||||
publishes [][]byte
|
||||
}
|
||||
|
||||
var _ pubsub.Pubsub = &fakePublisher{}
|
||||
|
||||
func (f *fakePublisher) Publish(_ string, message []byte) error {
|
||||
f.publishes = append(f.publishes, message)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
}
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := &fakePublisher{}
|
||||
|
||||
now := dbtime.Now()
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "awesome key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-10 * time.Second)),
|
||||
Age: 10,
|
||||
Value: "awesome value",
|
||||
Error: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "uncool key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-3 * time.Second)),
|
||||
Age: 3,
|
||||
Value: "",
|
||||
Error: "uncool value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key},
|
||||
Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value},
|
||||
Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error},
|
||||
// The value from the agent is ignored.
|
||||
CollectedAt: []time.Time{now, now},
|
||||
}).Return(nil)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: slogtest.Make(t, nil),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.BatchUpdateMetadata(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp)
|
||||
|
||||
require.Equal(t, 1, len(pub.publishes))
|
||||
var gotEvent agentapi.WorkspaceAgentMetadataChannelPayload
|
||||
require.NoError(t, json.Unmarshal(pub.publishes[0], &gotEvent))
|
||||
require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: now,
|
||||
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key},
|
||||
}, gotEvent)
|
||||
})
|
||||
|
||||
t.Run("ExceededLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := pubsub.NewInMemory()
|
||||
|
||||
almostLongValue := ""
|
||||
for i := 0; i < 2048; i++ {
|
||||
almostLongValue += "a"
|
||||
}
|
||||
|
||||
now := dbtime.Now()
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "almost long value",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
Value: almostLongValue,
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "too long value",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
Value: almostLongValue + "a",
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "almost long error",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
Error: almostLongValue,
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "too long error",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
Error: almostLongValue + "a",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key, req.Metadata[3].Key},
|
||||
Value: []string{
|
||||
almostLongValue,
|
||||
almostLongValue, // truncated
|
||||
"",
|
||||
"",
|
||||
},
|
||||
Error: []string{
|
||||
"",
|
||||
"value of 2049 bytes exceeded 2048 bytes",
|
||||
almostLongValue,
|
||||
"error of 2049 bytes exceeded 2048 bytes", // replaced
|
||||
},
|
||||
// The value from the agent is ignored.
|
||||
CollectedAt: []time.Time{now, now, now, now},
|
||||
}).Return(nil)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: slogtest.Make(t, nil),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.BatchUpdateMetadata(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp)
|
||||
})
|
||||
|
||||
t.Run("KeysTooLong", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
pub := pubsub.NewInMemory()
|
||||
|
||||
now := dbtime.Now()
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "key 1",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
Value: "value 1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "key 2",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
Value: "value 2",
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: func() string {
|
||||
key := "key 3 "
|
||||
for i := 0; i < (6144 - len("key 1") - len("key 2") - len("key 3") - 1); i++ {
|
||||
key += "a"
|
||||
}
|
||||
return key
|
||||
}(),
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
Value: "value 3",
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: "a", // should be ignored
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
Value: "value 4",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
// No key 4.
|
||||
Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
|
||||
Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value, req.Metadata[2].Result.Value},
|
||||
Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error, req.Metadata[2].Result.Error},
|
||||
// The value from the agent is ignored.
|
||||
CollectedAt: []time.Time{now, now, now},
|
||||
}).Return(nil)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: slogtest.Make(t, nil),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Watch the pubsub for events.
|
||||
var (
|
||||
eventCount int64
|
||||
gotEvent agentapi.WorkspaceAgentMetadataChannelPayload
|
||||
)
|
||||
cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) {
|
||||
if atomic.AddInt64(&eventCount, 1) > 1 {
|
||||
return
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(message, &gotEvent))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
resp, err := api.BatchUpdateMetadata(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "metadata keys of 6145 bytes exceeded 6144 bytes", err.Error())
|
||||
require.Nil(t, resp)
|
||||
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&eventCount))
|
||||
require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{
|
||||
CollectedAt: now,
|
||||
// No key 4.
|
||||
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
|
||||
}, gotEvent)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestGetServiceBanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := codersdk.ServiceBannerConfig{
|
||||
Enabled: true,
|
||||
Message: "hello world",
|
||||
BackgroundColor: "#000000",
|
||||
}
|
||||
cfgJSON, err := json.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().GetServiceBanner(gomock.Any()).Return(string(cfgJSON), nil)
|
||||
|
||||
api := &agentapi.ServiceBannerAPI{
|
||||
Database: dbM,
|
||||
}
|
||||
|
||||
resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &agentproto.ServiceBanner{
|
||||
Enabled: cfg.Enabled,
|
||||
Message: cfg.Message,
|
||||
BackgroundColor: cfg.BackgroundColor,
|
||||
}, resp)
|
||||
})
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().GetServiceBanner(gomock.Any()).Return("", sql.ErrNoRows)
|
||||
|
||||
api := &agentapi.ServiceBannerAPI{
|
||||
Database: dbM,
|
||||
}
|
||||
|
||||
resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, &agentproto.ServiceBanner{
|
||||
Enabled: false,
|
||||
Message: "",
|
||||
BackgroundColor: "",
|
||||
}, resp)
|
||||
})
|
||||
|
||||
t.Run("BadJSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().GetServiceBanner(gomock.Any()).Return("hi", nil)
|
||||
|
||||
api := &agentapi.ServiceBannerAPI{
|
||||
Database: dbM,
|
||||
}
|
||||
|
||||
resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "unmarshal json")
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
}
|
|
@ -9,57 +9,71 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/autobuild"
|
||||
"github.com/coder/coder/v2/coderd/batchstats"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/prometheusmetrics"
|
||||
"github.com/coder/coder/v2/coderd/schedule"
|
||||
)
|
||||
|
||||
type StatsBatcher interface {
|
||||
Add(now time.Time, agentID uuid.UUID, templateID uuid.UUID, userID uuid.UUID, workspaceID uuid.UUID, st *agentproto.Stats) error
|
||||
}
|
||||
|
||||
type StatsAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
StatsBatcher *batchstats.Batcher
|
||||
StatsBatcher StatsBatcher
|
||||
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
}
|
||||
|
||||
func (a *StatsAPI) now() time.Time {
|
||||
if a.TimeNowFn != nil {
|
||||
return a.TimeNowFn()
|
||||
}
|
||||
return dbtime.Now()
|
||||
}
|
||||
|
||||
func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
|
||||
// An empty stat means it's just looking for the report interval.
|
||||
res := &agentproto.UpdateStatsResponse{
|
||||
ReportInterval: durationpb.New(a.AgentStatsRefreshInterval),
|
||||
}
|
||||
if req.Stats == nil || len(req.Stats.ConnectionsByProto) == 0 {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
workspaceAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
row, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
getWorkspaceAgentByIDRow, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
|
||||
}
|
||||
workspace := row.Workspace
|
||||
|
||||
res := &agentproto.UpdateStatsResponse{
|
||||
ReportInterval: durationpb.New(a.AgentStatsRefreshInterval),
|
||||
}
|
||||
|
||||
// An empty stat means it's just looking for the report interval.
|
||||
if len(req.Stats.ConnectionsByProto) == 0 {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
workspace := getWorkspaceAgentByIDRow.Workspace
|
||||
a.Log.Debug(ctx, "read stats report",
|
||||
slog.F("interval", a.AgentStatsRefreshInterval),
|
||||
slog.F("workspace_id", workspace.ID),
|
||||
slog.F("payload", req),
|
||||
)
|
||||
|
||||
now := a.now()
|
||||
if req.Stats.ConnectionCount > 0 {
|
||||
var nextAutostart time.Time
|
||||
if workspace.AutostartSchedule.String != "" {
|
||||
templateSchedule, err := (*(a.TemplateScheduleStore.Load())).Get(ctx, a.Database, workspace.TemplateID)
|
||||
// If the template schedule fails to load, just default to bumping without the next trasition and log it.
|
||||
// If the template schedule fails to load, just default to bumping
|
||||
// without the next transition and log it.
|
||||
if err != nil {
|
||||
a.Log.Error(ctx, "failed to load template schedule bumping activity, defaulting to bumping by 60min",
|
||||
slog.F("workspace_id", workspace.ID),
|
||||
|
@ -67,7 +81,7 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
|||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
next, allowed := autobuild.NextAutostartSchedule(time.Now(), workspace.AutostartSchedule.String, templateSchedule)
|
||||
next, allowed := autobuild.NextAutostartSchedule(now, workspace.AutostartSchedule.String, templateSchedule)
|
||||
if allowed {
|
||||
nextAutostart = next
|
||||
}
|
||||
|
@ -76,13 +90,12 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
|||
ActivityBumpWorkspace(ctx, a.Log.Named("activity_bump"), a.Database, workspace.ID, nextAutostart)
|
||||
}
|
||||
|
||||
now := dbtime.Now()
|
||||
|
||||
var errGroup errgroup.Group
|
||||
errGroup.Go(func() error {
|
||||
if err := a.StatsBatcher.Add(time.Now(), workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, req.Stats); err != nil {
|
||||
a.Log.Error(ctx, "failed to add stats to batcher", slog.Error(err))
|
||||
return xerrors.Errorf("can't insert workspace agent stat: %w", err)
|
||||
err := a.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, req.Stats)
|
||||
if err != nil {
|
||||
a.Log.Error(ctx, "add agent stats to batcher", slog.Error(err))
|
||||
return xerrors.Errorf("insert workspace agent stats batch: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
@ -92,7 +105,7 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
|||
LastUsedAt: now,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("can't update workspace LastUsedAt: %w", err)
|
||||
return xerrors.Errorf("update workspace LastUsedAt: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
@ -100,14 +113,14 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
|||
errGroup.Go(func() error {
|
||||
user, err := a.Database.GetUserByID(ctx, workspace.OwnerID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("can't get user: %w", err)
|
||||
return xerrors.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
a.UpdateAgentMetricsFn(ctx, prometheusmetrics.AgentMetricLabels{
|
||||
Username: user.Username,
|
||||
WorkspaceName: workspace.Name,
|
||||
AgentName: workspaceAgent.Name,
|
||||
TemplateName: row.TemplateName,
|
||||
TemplateName: getWorkspaceAgentByIDRow.TemplateName,
|
||||
}, req.Stats.Metrics)
|
||||
return nil
|
||||
})
|
||||
|
|
|
@ -0,0 +1,379 @@
|
|||
package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/prometheusmetrics"
|
||||
"github.com/coder/coder/v2/coderd/schedule"
|
||||
)
|
||||
|
||||
type statsBatcher struct {
|
||||
mu sync.Mutex
|
||||
|
||||
called int64
|
||||
lastTime time.Time
|
||||
lastAgentID uuid.UUID
|
||||
lastTemplateID uuid.UUID
|
||||
lastUserID uuid.UUID
|
||||
lastWorkspaceID uuid.UUID
|
||||
lastStats *agentproto.Stats
|
||||
}
|
||||
|
||||
var _ agentapi.StatsBatcher = &statsBatcher{}
|
||||
|
||||
func (b *statsBatcher) Add(now time.Time, agentID uuid.UUID, templateID uuid.UUID, userID uuid.UUID, workspaceID uuid.UUID, st *agentproto.Stats) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.called++
|
||||
b.lastTime = now
|
||||
b.lastAgentID = agentID
|
||||
b.lastTemplateID = templateID
|
||||
b.lastUserID = userID
|
||||
b.lastWorkspaceID = workspaceID
|
||||
b.lastStats = st
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateStates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
user = database.User{
|
||||
ID: uuid.New(),
|
||||
Username: "bill",
|
||||
}
|
||||
template = database.Template{
|
||||
ID: uuid.New(),
|
||||
Name: "tpl",
|
||||
}
|
||||
workspace = database.Workspace{
|
||||
ID: uuid.New(),
|
||||
OwnerID: user.ID,
|
||||
TemplateID: template.ID,
|
||||
Name: "xyz",
|
||||
}
|
||||
agent = database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
Name: "abc",
|
||||
}
|
||||
)
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
now = dbtime.Now()
|
||||
dbM = dbmock.NewMockStore(gomock.NewController(t))
|
||||
templateScheduleStore = schedule.MockTemplateScheduleStore{
|
||||
GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) {
|
||||
panic("should not be called")
|
||||
},
|
||||
SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
batcher = &statsBatcher{}
|
||||
updateAgentMetricsFnCalled = false
|
||||
|
||||
req = &agentproto.UpdateStatsRequest{
|
||||
Stats: &agentproto.Stats{
|
||||
ConnectionsByProto: map[string]int64{
|
||||
"tcp": 1,
|
||||
"dean": 2,
|
||||
},
|
||||
ConnectionCount: 3,
|
||||
ConnectionMedianLatencyMs: 23,
|
||||
RxPackets: 120,
|
||||
RxBytes: 1000,
|
||||
TxPackets: 130,
|
||||
TxBytes: 2000,
|
||||
SessionCountVscode: 1,
|
||||
SessionCountJetbrains: 2,
|
||||
SessionCountReconnectingPty: 3,
|
||||
SessionCountSsh: 4,
|
||||
Metrics: []*agentproto.Stats_Metric{
|
||||
{
|
||||
Name: "awesome metric",
|
||||
Value: 42,
|
||||
},
|
||||
{
|
||||
Name: "uncool metric",
|
||||
Value: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
StatsBatcher: batcher,
|
||||
TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore),
|
||||
AgentStatsRefreshInterval: 10 * time.Second,
|
||||
UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) {
|
||||
updateAgentMetricsFnCalled = true
|
||||
assert.Equal(t, prometheusmetrics.AgentMetricLabels{
|
||||
Username: user.Username,
|
||||
WorkspaceName: workspace.Name,
|
||||
AgentName: agent.Name,
|
||||
TemplateName: template.Name,
|
||||
}, labels)
|
||||
assert.Equal(t, req.Stats.Metrics, metrics)
|
||||
},
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Workspace gets fetched.
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{
|
||||
Workspace: workspace,
|
||||
TemplateName: template.Name,
|
||||
}, nil)
|
||||
|
||||
// We expect an activity bump because ConnectionCount > 0.
|
||||
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
NextAutostart: time.Time{}.UTC(),
|
||||
}).Return(nil)
|
||||
|
||||
// Workspace last used at gets bumped.
|
||||
dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{
|
||||
ID: workspace.ID,
|
||||
LastUsedAt: now,
|
||||
}).Return(nil)
|
||||
|
||||
// User gets fetched to hit the UpdateAgentMetricsFn.
|
||||
dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil)
|
||||
|
||||
resp, err := api.UpdateStats(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.UpdateStatsResponse{
|
||||
ReportInterval: durationpb.New(10 * time.Second),
|
||||
}, resp)
|
||||
|
||||
batcher.mu.Lock()
|
||||
defer batcher.mu.Unlock()
|
||||
require.Equal(t, int64(1), batcher.called)
|
||||
require.Equal(t, now, batcher.lastTime)
|
||||
require.Equal(t, agent.ID, batcher.lastAgentID)
|
||||
require.Equal(t, template.ID, batcher.lastTemplateID)
|
||||
require.Equal(t, user.ID, batcher.lastUserID)
|
||||
require.Equal(t, workspace.ID, batcher.lastWorkspaceID)
|
||||
require.Equal(t, req.Stats, batcher.lastStats)
|
||||
|
||||
require.True(t, updateAgentMetricsFnCalled)
|
||||
})
|
||||
|
||||
t.Run("ConnectionCountZero", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
now = dbtime.Now()
|
||||
dbM = dbmock.NewMockStore(gomock.NewController(t))
|
||||
templateScheduleStore = schedule.MockTemplateScheduleStore{
|
||||
GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) {
|
||||
panic("should not be called")
|
||||
},
|
||||
SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
batcher = &statsBatcher{}
|
||||
|
||||
req = &agentproto.UpdateStatsRequest{
|
||||
Stats: &agentproto.Stats{
|
||||
ConnectionsByProto: map[string]int64{
|
||||
"tcp": 1,
|
||||
},
|
||||
ConnectionCount: 0,
|
||||
ConnectionMedianLatencyMs: 23,
|
||||
},
|
||||
}
|
||||
)
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
StatsBatcher: batcher,
|
||||
TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore),
|
||||
AgentStatsRefreshInterval: 10 * time.Second,
|
||||
// Ignored when nil.
|
||||
UpdateAgentMetricsFn: nil,
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Workspace gets fetched.
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{
|
||||
Workspace: workspace,
|
||||
TemplateName: template.Name,
|
||||
}, nil)
|
||||
|
||||
// Workspace last used at gets bumped.
|
||||
dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{
|
||||
ID: workspace.ID,
|
||||
LastUsedAt: now,
|
||||
}).Return(nil)
|
||||
|
||||
_, err := api.UpdateStats(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("NoConnectionsByProto", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
dbM = dbmock.NewMockStore(gomock.NewController(t))
|
||||
req = &agentproto.UpdateStatsRequest{
|
||||
Stats: &agentproto.Stats{
|
||||
ConnectionsByProto: map[string]int64{}, // len() == 0
|
||||
},
|
||||
}
|
||||
)
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
StatsBatcher: nil, // should not be called
|
||||
TemplateScheduleStore: nil, // should not be called
|
||||
AgentStatsRefreshInterval: 10 * time.Second,
|
||||
UpdateAgentMetricsFn: nil, // should not be called
|
||||
TimeNowFn: func() time.Time {
|
||||
panic("should not be called")
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := api.UpdateStats(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.UpdateStatsResponse{
|
||||
ReportInterval: durationpb.New(10 * time.Second),
|
||||
}, resp)
|
||||
})
|
||||
|
||||
t.Run("AutostartAwareBump", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use a workspace with an autostart schedule.
|
||||
workspace := workspace
|
||||
workspace.AutostartSchedule = sql.NullString{
|
||||
String: "CRON_TZ=Australia/Sydney 0 8 * * *",
|
||||
Valid: true,
|
||||
}
|
||||
|
||||
// Use a custom time for now which would trigger the autostart aware
|
||||
// bump.
|
||||
now, err := time.Parse("2006-01-02 15:04:05 -0700 MST", "2023-12-19 07:30:00 +1100 AEDT")
|
||||
require.NoError(t, err)
|
||||
now = dbtime.Time(now)
|
||||
nextAutostart := now.Add(30 * time.Minute).UTC() // always sent to DB as UTC
|
||||
|
||||
var (
|
||||
dbM = dbmock.NewMockStore(gomock.NewController(t))
|
||||
templateScheduleStore = schedule.MockTemplateScheduleStore{
|
||||
GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) {
|
||||
return schedule.TemplateScheduleOptions{
|
||||
UserAutostartEnabled: true,
|
||||
AutostartRequirement: schedule.TemplateAutostartRequirement{
|
||||
DaysOfWeek: 0b01111111, // every day
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
batcher = &statsBatcher{}
|
||||
updateAgentMetricsFnCalled = false
|
||||
|
||||
req = &agentproto.UpdateStatsRequest{
|
||||
Stats: &agentproto.Stats{
|
||||
ConnectionsByProto: map[string]int64{
|
||||
"tcp": 1,
|
||||
"dean": 2,
|
||||
},
|
||||
ConnectionCount: 3,
|
||||
},
|
||||
}
|
||||
)
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
StatsBatcher: batcher,
|
||||
TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore),
|
||||
AgentStatsRefreshInterval: 15 * time.Second,
|
||||
UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) {
|
||||
updateAgentMetricsFnCalled = true
|
||||
assert.Equal(t, prometheusmetrics.AgentMetricLabels{
|
||||
Username: user.Username,
|
||||
WorkspaceName: workspace.Name,
|
||||
AgentName: agent.Name,
|
||||
TemplateName: template.Name,
|
||||
}, labels)
|
||||
assert.Equal(t, req.Stats.Metrics, metrics)
|
||||
},
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Workspace gets fetched.
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{
|
||||
Workspace: workspace,
|
||||
TemplateName: template.Name,
|
||||
}, nil)
|
||||
|
||||
// We expect an activity bump because ConnectionCount > 0. However, the
|
||||
// next autostart time will be set on the bump.
|
||||
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
NextAutostart: nextAutostart,
|
||||
}).Return(nil)
|
||||
|
||||
// Workspace last used at gets bumped.
|
||||
dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{
|
||||
ID: workspace.ID,
|
||||
LastUsedAt: now,
|
||||
}).Return(nil)
|
||||
|
||||
// User gets fetched to hit the UpdateAgentMetricsFn.
|
||||
dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil)
|
||||
|
||||
resp, err := api.UpdateStats(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.UpdateStatsResponse{
|
||||
ReportInterval: durationpb.New(15 * time.Second),
|
||||
}, resp)
|
||||
|
||||
require.True(t, updateAgentMetricsFnCalled)
|
||||
})
|
||||
}
|
||||
|
||||
func templateScheduleStorePtr(store schedule.TemplateScheduleStore) *atomic.Pointer[schedule.TemplateScheduleStore] {
|
||||
var ptr atomic.Pointer[schedule.TemplateScheduleStore]
|
||||
ptr.Store(&store)
|
||||
return &ptr
|
||||
}
|
|
@ -266,16 +266,25 @@ func convertDisplayApps(apps []database.DisplayApp) []codersdk.DisplayApp {
|
|||
return dapps
|
||||
}
|
||||
|
||||
func WorkspaceAgentEnvironment(workspaceAgent database.WorkspaceAgent) (map[string]string, error) {
|
||||
var envs map[string]string
|
||||
if workspaceAgent.EnvironmentVariables.Valid {
|
||||
err := json.Unmarshal(workspaceAgent.EnvironmentVariables.RawMessage, &envs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal environment variables: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return envs, nil
|
||||
}
|
||||
|
||||
func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
|
||||
dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, scripts []codersdk.WorkspaceAgentScript, logSources []codersdk.WorkspaceAgentLogSource,
|
||||
agentInactiveDisconnectTimeout time.Duration, agentFallbackTroubleshootingURL string,
|
||||
) (codersdk.WorkspaceAgent, error) {
|
||||
var envs map[string]string
|
||||
if dbAgent.EnvironmentVariables.Valid {
|
||||
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal env vars: %w", err)
|
||||
}
|
||||
envs, err := WorkspaceAgentEnvironment(dbAgent)
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceAgent{}, err
|
||||
}
|
||||
troubleshootingURL := agentFallbackTroubleshootingURL
|
||||
if dbAgent.TroubleshootingURL != "" {
|
||||
|
|
|
@ -154,18 +154,24 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request)
|
|||
// As this API becomes deprecated, use the new protobuf API and convert the
|
||||
// types back to the SDK types.
|
||||
manifestAPI := &agentapi.ManifestAPI{
|
||||
AccessURL: api.AccessURL,
|
||||
AppHostname: api.AppHostname,
|
||||
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
|
||||
AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
|
||||
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
||||
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
|
||||
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
|
||||
AccessURL: api.AccessURL,
|
||||
AppHostname: api.AppHostname,
|
||||
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
||||
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
|
||||
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
|
||||
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil },
|
||||
Database: api.Database,
|
||||
DerpMapFn: api.DERPMap,
|
||||
TailnetCoordinator: &api.TailnetCoordinator,
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil },
|
||||
WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) {
|
||||
// Sadly this results in a double query, but it's only temporary for
|
||||
// now.
|
||||
ws, err := api.Database.GetWorkspaceByAgentID(ctx, wa.ID)
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
return ws.Workspace.ID, nil
|
||||
},
|
||||
Database: api.Database,
|
||||
DerpMapFn: api.DERPMap,
|
||||
}
|
||||
manifest, err := manifestAPI.GetManifest(ctx, &agentproto.GetManifestRequest{})
|
||||
if err != nil {
|
||||
|
|
|
@ -134,15 +134,13 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
|||
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
|
||||
PublishWorkspaceAgentLogsUpdateFn: api.publishWorkspaceAgentLogsUpdate,
|
||||
|
||||
AccessURL: api.AccessURL,
|
||||
AppHostname: api.AppHostname,
|
||||
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
|
||||
AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
|
||||
AgentStatsRefreshInterval: api.AgentStatsRefreshInterval,
|
||||
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
|
||||
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
|
||||
DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
|
||||
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
||||
AccessURL: api.AccessURL,
|
||||
AppHostname: api.AppHostname,
|
||||
AgentStatsRefreshInterval: api.AgentStatsRefreshInterval,
|
||||
DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
|
||||
DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
|
||||
DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency,
|
||||
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
||||
|
||||
// Optional:
|
||||
WorkspaceID: build.WorkspaceID, // saves the extra lookup later
|
||||
|
|
Loading…
Reference in New Issue