diff --git a/coderd/agentapi/activitybump.go b/coderd/agentapi/activitybump.go index ab0797d612..90afaf7e36 100644 --- a/coderd/agentapi/activitybump.go +++ b/coderd/agentapi/activitybump.go @@ -41,13 +41,14 @@ func ActivityBumpWorkspace(ctx context.Context, log slog.Logger, db database.Sto // low priority operations fail first. ctx, cancel := context.WithTimeout(ctx, time.Second*15) defer cancel() - if err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{ + err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{ NextAutostart: nextAutostart.UTC(), WorkspaceID: workspaceID, - }); err != nil { + }) + if err != nil { if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) { // Bump will fail if the context is canceled, but this is ok. - log.Error(ctx, "bump failed", slog.Error(err), + log.Error(ctx, "activity bump failed", slog.Error(err), slog.F("workspace_id", workspaceID), ) } diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index 1f74685d62..e3c7b1d067 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -17,7 +17,6 @@ import ( "cdr.dev/slog" agentproto "github.com/coder/coder/v2/agent/proto" - "github.com/coder/coder/v2/coderd/batchstats" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/externalauth" @@ -61,19 +60,17 @@ type Options struct { DerpMapFn func() *tailcfg.DERPMap TailnetCoordinator *atomic.Pointer[tailnet.Coordinator] TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] - StatsBatcher *batchstats.Batcher + StatsBatcher StatsBatcher PublishWorkspaceUpdateFn func(ctx context.Context, workspaceID uuid.UUID) PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) - AccessURL *url.URL - AppHostname string - AgentInactiveDisconnectTimeout time.Duration - AgentFallbackTroubleshootingURL string - AgentStatsRefreshInterval time.Duration - DisableDirectConnections bool - DerpForceWebSockets bool - DerpMapUpdateFrequency time.Duration - ExternalAuthConfigs []*externalauth.Config + AccessURL *url.URL + AppHostname string + AgentStatsRefreshInterval time.Duration + DisableDirectConnections bool + DerpForceWebSockets bool + DerpMapUpdateFrequency time.Duration + ExternalAuthConfigs []*externalauth.Config // Optional: // WorkspaceID avoids a future lookup to find the workspace ID by setting @@ -90,17 +87,14 @@ func New(opts Options) *API { } api.ManifestAPI = &ManifestAPI{ - AccessURL: opts.AccessURL, - AppHostname: opts.AppHostname, - AgentInactiveDisconnectTimeout: opts.AgentInactiveDisconnectTimeout, - AgentFallbackTroubleshootingURL: opts.AgentFallbackTroubleshootingURL, - ExternalAuthConfigs: opts.ExternalAuthConfigs, - DisableDirectConnections: opts.DisableDirectConnections, - DerpForceWebSockets: opts.DerpForceWebSockets, - AgentFn: api.agent, - Database: opts.Database, - DerpMapFn: opts.DerpMapFn, - TailnetCoordinator: opts.TailnetCoordinator, + AccessURL: opts.AccessURL, + AppHostname: opts.AppHostname, + ExternalAuthConfigs: opts.ExternalAuthConfigs, + DisableDirectConnections: opts.DisableDirectConnections, + DerpForceWebSockets: opts.DerpForceWebSockets, + AgentFn: api.agent, + Database: opts.Database, + DerpMapFn: opts.DerpMapFn, } api.ServiceBannerAPI = &ServiceBannerAPI{ @@ -214,20 +208,15 @@ func (a *API) workspaceID(ctx context.Context, agent *database.WorkspaceAgent) ( agent = &agnt } - resource, err := 
a.opts.Database.GetWorkspaceResourceByID(ctx, agent.ResourceID) + getWorkspaceAgentByIDRow, err := a.opts.Database.GetWorkspaceByAgentID(ctx, agent.ID) if err != nil { - return uuid.Nil, xerrors.Errorf("get workspace agent resource by id %q: %w", agent.ResourceID, err) - } - - build, err := a.opts.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) - if err != nil { - return uuid.Nil, xerrors.Errorf("get workspace build by job id %q: %w", resource.JobID, err) + return uuid.Nil, xerrors.Errorf("get workspace by agent id %q: %w", agent.ID, err) } a.mu.Lock() - a.cachedWorkspaceID = build.WorkspaceID + a.cachedWorkspaceID = getWorkspaceAgentByIDRow.Workspace.ID a.mu.Unlock() - return build.WorkspaceID, nil + return getWorkspaceAgentByIDRow.Workspace.ID, nil } func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent) error { diff --git a/coderd/agentapi/apps.go b/coderd/agentapi/apps.go index 1346d7a9b4..7e8bda1262 100644 --- a/coderd/agentapi/apps.go +++ b/coderd/agentapi/apps.go @@ -90,9 +90,11 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat } } - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) + if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 { + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + if err != nil { + return nil, xerrors.Errorf("publish workspace update: %w", err) + } } return &agentproto.BatchUpdateAppHealthResponse{}, nil } diff --git a/coderd/agentapi/apps_test.go b/coderd/agentapi/apps_test.go new file mode 100644 index 0000000000..c774c6777b --- /dev/null +++ b/coderd/agentapi/apps_test.go @@ -0,0 +1,252 @@ +package agentapi_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog/sloggers/slogtest" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" +) + +func TestBatchUpdateAppHealths(t *testing.T) { + t.Parallel() + + var ( + agent = database.WorkspaceAgent{ + ID: uuid.New(), + } + app1 = database.WorkspaceApp{ + ID: uuid.New(), + AgentID: agent.ID, + Slug: "code-server-1", + DisplayName: "code-server 1", + HealthcheckUrl: "http://localhost:3000", + Health: database.WorkspaceAppHealthInitializing, + } + app2 = database.WorkspaceApp{ + ID: uuid.New(), + AgentID: agent.ID, + Slug: "code-server-2", + DisplayName: "code-server 2", + HealthcheckUrl: "http://localhost:3001", + Health: database.WorkspaceAppHealthHealthy, + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) + dbM.EXPECT().UpdateWorkspaceAppHealthByID(gomock.Any(), database.UpdateWorkspaceAppHealthByIDParams{ + ID: app1.ID, + Health: database.WorkspaceAppHealthHealthy, + }).Return(nil) + dbM.EXPECT().UpdateWorkspaceAppHealthByID(gomock.Any(), database.UpdateWorkspaceAppHealthByIDParams{ + ID: app2.ID, + Health: database.WorkspaceAppHealthUnhealthy, + }).Return(nil) + + publishCalled := false + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa 
*database.WorkspaceAgent) error { + publishCalled = true + return nil + }, + } + + // Set one to healthy, set another to unhealthy. + resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: app1.ID[:], + Health: agentproto.AppHealth_HEALTHY, + }, + { + Id: app2.ID[:], + Health: agentproto.AppHealth_UNHEALTHY, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) + + require.True(t, publishCalled) + }) + + t.Run("Unchanged", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) + + publishCalled := false + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + publishCalled = true + return nil + }, + } + + // Set both to their current status, neither should be updated in the + // DB. + resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: app1.ID[:], + Health: agentproto.AppHealth_INITIALIZING, + }, + { + Id: app2.ID[:], + Health: agentproto.AppHealth_HEALTHY, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) + + require.False(t, publishCalled) + }) + + t.Run("Empty", func(t *testing.T) { + t.Parallel() + + // No DB queries are made if there are no updates to process. + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + publishCalled := false + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + publishCalled = true + return nil + }, + } + + // Do nothing. + resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{}, + }) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) + + require.False(t, publishCalled) + }) + + t.Run("AppNoHealthcheck", func(t *testing.T) { + t.Parallel() + + app3 := database.WorkspaceApp{ + ID: uuid.New(), + AgentID: agent.ID, + Slug: "code-server-3", + DisplayName: "code-server 3", + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil) + + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: nil, + } + + // Set app3 to healthy, should error. 
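+		// (The handler rejects health updates for apps created without a
+		// healthcheck, hence the "does not have healthchecks enabled" error
+		// asserted below.)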
+ resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: app3.ID[:], + Health: agentproto.AppHealth_HEALTHY, + }, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "does not have healthchecks enabled") + require.Nil(t, resp) + }) + + t.Run("UnknownApp", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) + + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: nil, + } + + // Set an unknown app to healthy, should error. + id := uuid.New() + resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: id[:], + Health: agentproto.AppHealth_HEALTHY, + }, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "not found") + require.Nil(t, resp) + }) + + t.Run("InvalidHealth", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) + + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: nil, + } + + // Set an unknown app to healthy, should error. + resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: app1.ID[:], + Health: -999, + }, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "unknown health status") + require.Nil(t, resp) + }) +} diff --git a/coderd/agentapi/lifecycle.go b/coderd/agentapi/lifecycle.go index d909d35eb8..662d0c0c2e 100644 --- a/coderd/agentapi/lifecycle.go +++ b/coderd/agentapi/lifecycle.go @@ -3,6 +3,7 @@ package agentapi import ( "context" "database/sql" + "time" "github.com/google/uuid" "golang.org/x/mod/semver" @@ -21,6 +22,15 @@ type LifecycleAPI struct { Database database.Store Log slog.Logger PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error + + TimeNowFn func() time.Time // defaults to dbtime.Now() +} + +func (a *LifecycleAPI) now() time.Time { + if a.TimeNowFn != nil { + return a.TimeNowFn() + } + return dbtime.Now() } func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) { @@ -68,7 +78,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda changedAt := req.Lifecycle.ChangedAt.AsTime() if changedAt.IsZero() { - changedAt = dbtime.Now() + changedAt = a.now() req.Lifecycle.ChangedAt = timestamppb.New(changedAt) } dbChangedAt := sql.NullTime{Time: changedAt, Valid: true} @@ -78,8 +88,13 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda switch lifecycleState { case database.WorkspaceAgentLifecycleStateStarting: startedAt = dbChangedAt - readyAt.Valid = false // This agent is re-starting, so it's not ready yet. + // This agent is (re)starting, so it's not ready yet. 
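+		// Zero the Time as well as Valid so the stored value is a fully
+		// reset sql.NullTime rather than one carrying a stale timestamp.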
+ readyAt.Time = time.Time{} + readyAt.Valid = false case database.WorkspaceAgentLifecycleStateReady, database.WorkspaceAgentLifecycleStateStartError: + if !startedAt.Valid { + startedAt = dbChangedAt + } readyAt = dbChangedAt } @@ -97,9 +112,11 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda return nil, xerrors.Errorf("update workspace agent lifecycle state: %w", err) } - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) + if a.PublishWorkspaceUpdateFn != nil { + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + if err != nil { + return nil, xerrors.Errorf("publish workspace update: %w", err) + } } return req.Lifecycle, nil diff --git a/coderd/agentapi/lifecycle_test.go b/coderd/agentapi/lifecycle_test.go new file mode 100644 index 0000000000..855ff9329a --- /dev/null +++ b/coderd/agentapi/lifecycle_test.go @@ -0,0 +1,461 @@ +package agentapi_test + +import ( + "context" + "database/sql" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "google.golang.org/protobuf/types/known/timestamppb" + + "cdr.dev/slog/sloggers/slogtest" + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" +) + +func TestUpdateLifecycle(t *testing.T) { + t.Parallel() + + someTime, err := time.Parse(time.RFC3339, "2023-01-01T00:00:00Z") + require.NoError(t, err) + someTime = dbtime.Time(someTime) + now := dbtime.Now() + + var ( + workspaceID = uuid.New() + agentCreated = database.WorkspaceAgent{ + ID: uuid.New(), + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + StartedAt: sql.NullTime{Valid: false}, + ReadyAt: sql.NullTime{Valid: false}, + } + agentStarting = database.WorkspaceAgent{ + ID: uuid.New(), + LifecycleState: database.WorkspaceAgentLifecycleStateStarting, + StartedAt: sql.NullTime{Valid: true, Time: someTime}, + ReadyAt: sql.NullTime{Valid: false}, + } + ) + + t.Run("OKStarting", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: agentproto.Lifecycle_STARTING, + ChangedAt: timestamppb.New(now), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agentCreated.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateStarting, + StartedAt: sql.NullTime{ + Time: now, + Valid: true, + }, + ReadyAt: sql.NullTime{Valid: false}, + }).Return(nil) + + publishCalled := false + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentCreated, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + publishCalled = true + return nil + }, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + require.True(t, publishCalled) + }) + + t.Run("OKReadying", func(t *testing.T) { + t.Parallel() + + lifecycle := 
&agentproto.Lifecycle{ + State: agentproto.Lifecycle_READY, + ChangedAt: timestamppb.New(now), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agentStarting.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: agentStarting.StartedAt, + ReadyAt: sql.NullTime{ + Time: now, + Valid: true, + }, + }).Return(nil) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentStarting, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Test that nil publish fn works. + PublishWorkspaceUpdateFn: nil, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + }) + + // This test jumps from CREATING to READY, skipping STARTED. Both the + // StartedAt and ReadyAt fields should be set. + t.Run("OKStraightToReady", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: agentproto.Lifecycle_READY, + ChangedAt: timestamppb.New(now), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agentCreated.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: sql.NullTime{ + Time: now, + Valid: true, + }, + ReadyAt: sql.NullTime{ + Time: now, + Valid: true, + }, + }).Return(nil) + + publishCalled := false + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentCreated, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + publishCalled = true + return nil + }, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + require.True(t, publishCalled) + }) + + t.Run("NoTimeSpecified", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: agentproto.Lifecycle_READY, + // Zero time + ChangedAt: timestamppb.New(time.Time{}), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + now := dbtime.Now() + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agentCreated.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: sql.NullTime{ + Time: now, + Valid: true, + }, + ReadyAt: sql.NullTime{ + Time: now, + Valid: true, + }, + }) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentCreated, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: nil, + TimeNowFn: func() time.Time { + return now + }, + } + + resp, err := api.UpdateLifecycle(context.Background(), 
&agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + }) + + t.Run("AllStates", func(t *testing.T) { + t.Parallel() + + agent := database.WorkspaceAgent{ + ID: uuid.New(), + LifecycleState: database.WorkspaceAgentLifecycleState(""), + StartedAt: sql.NullTime{Valid: false}, + ReadyAt: sql.NullTime{Valid: false}, + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishCalled int64 + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + atomic.AddInt64(&publishCalled, 1) + return nil + }, + } + + states := []agentproto.Lifecycle_State{ + agentproto.Lifecycle_CREATED, + agentproto.Lifecycle_STARTING, + agentproto.Lifecycle_START_TIMEOUT, + agentproto.Lifecycle_START_ERROR, + agentproto.Lifecycle_READY, + agentproto.Lifecycle_SHUTTING_DOWN, + agentproto.Lifecycle_SHUTDOWN_TIMEOUT, + agentproto.Lifecycle_SHUTDOWN_ERROR, + agentproto.Lifecycle_OFF, + } + for i, state := range states { + t.Log("state", state) + // Use a time after the last state change to ensure ordering. + stateNow := now.Add(time.Hour * time.Duration(i)) + lifecycle := &agentproto.Lifecycle{ + State: state, + ChangedAt: timestamppb.New(stateNow), + } + + expectedStartedAt := agent.StartedAt + expectedReadyAt := agent.ReadyAt + if state == agentproto.Lifecycle_STARTING { + expectedStartedAt = sql.NullTime{Valid: true, Time: stateNow} + } + if state == agentproto.Lifecycle_READY || state == agentproto.Lifecycle_START_ERROR { + expectedReadyAt = sql.NullTime{Valid: true, Time: stateNow} + } + + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agent.ID, + LifecycleState: database.WorkspaceAgentLifecycleState(strings.ToLower(state.String())), + StartedAt: expectedStartedAt, + ReadyAt: expectedReadyAt, + }).Times(1).Return(nil) + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + require.Equal(t, int64(i+1), atomic.LoadInt64(&publishCalled)) + + // For future iterations: + agent.StartedAt = expectedStartedAt + agent.ReadyAt = expectedReadyAt + } + }) + + t.Run("UnknownLifecycleState", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: -999, + ChangedAt: timestamppb.New(now), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + publishCalled := false + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentCreated, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + publishCalled = true + return nil + }, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.Error(t, err) + require.ErrorContains(t, err, "unknown lifecycle state") + require.Nil(t, resp) + require.False(t, publishCalled) + 
}) +} + +func TestUpdateStartup(t *testing.T) { + t.Parallel() + + var ( + workspaceID = uuid.New() + agent = database.WorkspaceAgent{ + ID: uuid.New(), + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Not used by UpdateStartup. + PublishWorkspaceUpdateFn: nil, + } + + startup := &agentproto.Startup{ + Version: "v1.2.3", + ExpandedDirectory: "/path/to/expanded/dir", + Subsystems: []agentproto.Startup_Subsystem{ + agentproto.Startup_ENVBOX, + agentproto.Startup_ENVBUILDER, + agentproto.Startup_EXECTRACE, + }, + } + + dbM.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), database.UpdateWorkspaceAgentStartupByIDParams{ + ID: agent.ID, + Version: startup.Version, + ExpandedDirectory: startup.ExpandedDirectory, + Subsystems: []database.WorkspaceAgentSubsystem{ + database.WorkspaceAgentSubsystemEnvbox, + database.WorkspaceAgentSubsystemEnvbuilder, + database.WorkspaceAgentSubsystemExectrace, + }, + APIVersion: agentapi.AgentAPIVersionDRPC, + }).Return(nil) + + resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + Startup: startup, + }) + require.NoError(t, err) + require.Equal(t, startup, resp) + }) + + t.Run("BadVersion", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Not used by UpdateStartup. + PublishWorkspaceUpdateFn: nil, + } + + startup := &agentproto.Startup{ + Version: "asdf", + ExpandedDirectory: "/path/to/expanded/dir", + Subsystems: []agentproto.Startup_Subsystem{}, + } + + resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + Startup: startup, + }) + require.Error(t, err) + require.ErrorContains(t, err, "invalid agent semver version") + require.Nil(t, resp) + }) + + t.Run("BadSubsystem", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Not used by UpdateStartup. 
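+			// (UpdateStartup never publishes workspace updates, so nil is
+			// safe in every subtest.)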
+ PublishWorkspaceUpdateFn: nil, + } + + startup := &agentproto.Startup{ + Version: "v1.2.3", + ExpandedDirectory: "/path/to/expanded/dir", + Subsystems: []agentproto.Startup_Subsystem{ + agentproto.Startup_ENVBOX, + -999, + }, + } + + resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + Startup: startup, + }) + require.Error(t, err) + require.ErrorContains(t, err, "invalid agent subsystem") + require.Nil(t, resp) + }) +} diff --git a/coderd/agentapi/logs.go b/coderd/agentapi/logs.go index 7d34b41e13..cb3a920b9a 100644 --- a/coderd/agentapi/logs.go +++ b/coderd/agentapi/logs.go @@ -2,6 +2,7 @@ package agentapi import ( "context" + "time" "github.com/google/uuid" "golang.org/x/xerrors" @@ -19,6 +20,15 @@ type LogsAPI struct { Log slog.Logger PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) + + TimeNowFn func() time.Time // defaults to dbtime.Now() +} + +func (a *LogsAPI) now() time.Time { + if a.TimeNowFn != nil { + return a.TimeNowFn() + } + return dbtime.Now() } func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) { @@ -26,6 +36,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea if err != nil { return nil, err } + if workspaceAgent.LogsOverflowed { + return nil, xerrors.New("workspace agent logs overflowed") + } if len(req.Logs) == 0 { return &agentproto.BatchCreateLogsResponse{}, nil @@ -42,7 +55,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea // Use the external log source externalSources, err := a.Database.InsertWorkspaceAgentLogSources(ctx, database.InsertWorkspaceAgentLogSourcesParams{ WorkspaceAgentID: workspaceAgent.ID, - CreatedAt: dbtime.Now(), + CreatedAt: a.now(), ID: []uuid.UUID{agentsdk.ExternalLogSourceID}, DisplayName: []string{"External"}, Icon: []string{"/emojis/1f310.png"}, @@ -88,7 +101,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea logs, err := a.Database.InsertWorkspaceAgentLogs(ctx, database.InsertWorkspaceAgentLogsParams{ AgentID: workspaceAgent.ID, - CreatedAt: dbtime.Now(), + CreatedAt: a.now(), Output: output, Level: level, LogSourceID: logSourceID, @@ -98,9 +111,6 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea if !database.IsWorkspaceAgentLogsLimitError(err) { return nil, xerrors.Errorf("insert workspace agent logs: %w", err) } - if workspaceAgent.LogsOverflowed { - return nil, xerrors.New("workspace agent logs overflowed") - } err := a.Database.UpdateWorkspaceAgentLogOverflowByID(ctx, database.UpdateWorkspaceAgentLogOverflowByIDParams{ ID: workspaceAgent.ID, LogsOverflowed: true, @@ -112,21 +122,25 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea a.Log.Warn(ctx, "failed to update workspace agent log overflow", slog.Error(err)) } - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) + if a.PublishWorkspaceUpdateFn != nil { + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + if err != nil { + return nil, xerrors.Errorf("publish workspace update: %w", err) + } } return nil, xerrors.New("workspace agent log limit exceeded") } // Publish by the lowest log ID inserted so the log stream will fetch // everything from that point. 
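	// CreatedAfter is exclusive, so subtracting one makes the lowest
	// inserted log visible to subscribers.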
- lowestLogID := logs[0].ID - a.PublishWorkspaceAgentLogsUpdateFn(ctx, workspaceAgent.ID, agentsdk.LogsNotifyMessage{ - CreatedAfter: lowestLogID - 1, - }) + if a.PublishWorkspaceAgentLogsUpdateFn != nil { + lowestLogID := logs[0].ID + a.PublishWorkspaceAgentLogsUpdateFn(ctx, workspaceAgent.ID, agentsdk.LogsNotifyMessage{ + CreatedAfter: lowestLogID - 1, + }) + } - if workspaceAgent.LogsLength == 0 { + if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil { // If these are the first logs being appended, we publish a UI update // to notify the UI that logs are now available. err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) diff --git a/coderd/agentapi/logs_test.go b/coderd/agentapi/logs_test.go new file mode 100644 index 0000000000..66fbaa005d --- /dev/null +++ b/coderd/agentapi/logs_test.go @@ -0,0 +1,427 @@ +package agentapi_test + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "google.golang.org/protobuf/types/known/timestamppb" + + "cdr.dev/slog/sloggers/slogtest" + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/codersdk/agentsdk" +) + +func TestBatchCreateLogs(t *testing.T) { + t.Parallel() + + var ( + agent = database.WorkspaceAgent{ + ID: uuid.New(), + } + logSource = database.WorkspaceAgentLogSource{ + WorkspaceAgentID: agent.ID, + CreatedAt: dbtime.Now(), + ID: uuid.New(), + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false + now := dbtime.Now() + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + publishWorkspaceUpdateCalled = true + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + publishWorkspaceAgentLogsUpdateCalled = true + + // Check the message content, should be for -1 since the lowest + // log we inserted was 0. + assert.Equal(t, agentsdk.LogsNotifyMessage{CreatedAfter: -1}, msg) + }, + TimeNowFn: func() time.Time { return now }, + } + + req := &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(now), + Level: agentproto.Log_TRACE, + Output: "log line 1", + }, + { + CreatedAt: timestamppb.New(now.Add(time.Hour)), + Level: agentproto.Log_DEBUG, + Output: "log line 2", + }, + { + CreatedAt: timestamppb.New(now.Add(2 * time.Hour)), + Level: agentproto.Log_INFO, + Output: "log line 3", + }, + { + CreatedAt: timestamppb.New(now.Add(3 * time.Hour)), + Level: agentproto.Log_WARN, + Output: "log line 4", + }, + { + CreatedAt: timestamppb.New(now.Add(4 * time.Hour)), + Level: agentproto.Log_ERROR, + Output: "log line 5", + }, + { + CreatedAt: timestamppb.New(now.Add(5 * time.Hour)), + Level: -999, // defaults to INFO + Output: "log line 6", + }, + }, + } + + // Craft expected DB request and response dynamically. 
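+		// gomock compares expected arguments with reflect.DeepEqual, so
+		// these params must mirror the handler's conversions exactly:
+		// levels are lowercased and an unrecognized level falls back to
+		// INFO.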
+ insertWorkspaceAgentLogsParams := database.InsertWorkspaceAgentLogsParams{ + AgentID: agent.ID, + LogSourceID: logSource.ID, + CreatedAt: now, + Output: make([]string, len(req.Logs)), + Level: make([]database.LogLevel, len(req.Logs)), + OutputLength: 0, + } + insertWorkspaceAgentLogsReturn := make([]database.WorkspaceAgentLog, len(req.Logs)) + for i, logEntry := range req.Logs { + insertWorkspaceAgentLogsParams.Output[i] = logEntry.Output + level := database.LogLevelInfo + if logEntry.Level >= 0 { + level = database.LogLevel(strings.ToLower(logEntry.Level.String())) + } + insertWorkspaceAgentLogsParams.Level[i] = level + insertWorkspaceAgentLogsParams.OutputLength += int32(len(logEntry.Output)) + + insertWorkspaceAgentLogsReturn[i] = database.WorkspaceAgentLog{ + AgentID: agent.ID, + CreatedAt: logEntry.CreatedAt.AsTime(), + ID: int64(i), + Output: logEntry.Output, + Level: insertWorkspaceAgentLogsParams.Level[i], + LogSourceID: logSource.ID, + } + } + + dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), insertWorkspaceAgentLogsParams).Return(insertWorkspaceAgentLogsReturn, nil) + + resp, err := api.BatchCreateLogs(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + require.True(t, publishWorkspaceUpdateCalled) + require.True(t, publishWorkspaceAgentLogsUpdateCalled) + }) + + t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) { + t.Parallel() + + agentWithLogs := agent + agentWithLogs.LogsLength = 1 + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agentWithLogs, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + publishWorkspaceUpdateCalled = true + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + publishWorkspaceAgentLogsUpdateCalled = true + }, + } + + // Don't really care about the DB call. 
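+		// (gomock.Any() matches the params; only the returned log ID feeds
+		// the notify message.)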
+ dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), gomock.Any()).Return([]database.WorkspaceAgentLog{ + { + ID: 1, + }, + }, nil) + + resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(dbtime.Now()), + Level: agentproto.Log_INFO, + Output: "hello world", + }, + }, + }) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + require.False(t, publishWorkspaceUpdateCalled) + require.True(t, publishWorkspaceAgentLogsUpdateCalled) + }) + + t.Run("AlreadyOverflowed", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + overflowedAgent := agent + overflowedAgent.LogsOverflowed = true + + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return overflowedAgent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + publishWorkspaceUpdateCalled = true + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + publishWorkspaceAgentLogsUpdateCalled = true + }, + } + + resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{}, + }) + require.Error(t, err) + require.ErrorContains(t, err, "workspace agent logs overflowed") + require.Nil(t, resp) + require.False(t, publishWorkspaceUpdateCalled) + require.False(t, publishWorkspaceAgentLogsUpdateCalled) + }) + + t.Run("InvalidLogSourceID", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Test that they are ignored when nil. 
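+			// (Both publish callbacks are optional; the handler must not
+			// panic when they are nil.)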
+ PublishWorkspaceUpdateFn: nil, + PublishWorkspaceAgentLogsUpdateFn: nil, + } + + resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{ + LogSourceId: []byte("invalid"), + Logs: []*agentproto.Log{ + {}, // need at least 1 log + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "parse log source ID") + require.Nil(t, resp) + }) + + t.Run("UseExternalLogSourceID", func(t *testing.T) { + t.Parallel() + + now := dbtime.Now() + req := &agentproto.BatchCreateLogsRequest{ + LogSourceId: uuid.Nil[:], // defaults to "external" + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(now), + Level: agentproto.Log_INFO, + Output: "hello world", + }, + }, + } + dbInsertParams := database.InsertWorkspaceAgentLogsParams{ + AgentID: agent.ID, + LogSourceID: agentsdk.ExternalLogSourceID, + CreatedAt: now, + Output: []string{"hello world"}, + Level: []database.LogLevel{database.LogLevelInfo}, + OutputLength: int32(len(req.Logs[0].Output)), + } + dbInsertRes := []database.WorkspaceAgentLog{ + { + AgentID: agent.ID, + CreatedAt: now, + ID: 1, + Output: "hello world", + Level: database.LogLevelInfo, + LogSourceID: agentsdk.ExternalLogSourceID, + }, + } + + t.Run("Create", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + publishWorkspaceUpdateCalled = true + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + publishWorkspaceAgentLogsUpdateCalled = true + }, + TimeNowFn: func() time.Time { return now }, + } + + dbM.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), database.InsertWorkspaceAgentLogSourcesParams{ + WorkspaceAgentID: agent.ID, + CreatedAt: now, + ID: []uuid.UUID{agentsdk.ExternalLogSourceID}, + DisplayName: []string{"External"}, + Icon: []string{"/emojis/1f310.png"}, + }).Return([]database.WorkspaceAgentLogSource{ + { + // only the ID field is used + ID: agentsdk.ExternalLogSourceID, + }, + }, nil) + dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), dbInsertParams).Return(dbInsertRes, nil) + + resp, err := api.BatchCreateLogs(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + require.True(t, publishWorkspaceUpdateCalled) + require.True(t, publishWorkspaceAgentLogsUpdateCalled) + }) + + t.Run("Exists", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + publishWorkspaceUpdateCalled = true + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + publishWorkspaceAgentLogsUpdateCalled = true + }, + TimeNowFn: func() time.Time { return now }, + } + + // Return a unique violation error to simulate the log source + // already 
existing. This should be handled gracefully. + logSourceInsertErr := &pq.Error{ + Code: pq.ErrorCode("23505"), // unique_violation + Constraint: string(database.UniqueWorkspaceAgentLogSourcesPkey), + } + dbM.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), database.InsertWorkspaceAgentLogSourcesParams{ + WorkspaceAgentID: agent.ID, + CreatedAt: now, + ID: []uuid.UUID{agentsdk.ExternalLogSourceID}, + DisplayName: []string{"External"}, + Icon: []string{"/emojis/1f310.png"}, + }).Return([]database.WorkspaceAgentLogSource{}, logSourceInsertErr) + + dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), dbInsertParams).Return(dbInsertRes, nil) + + resp, err := api.BatchCreateLogs(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + require.True(t, publishWorkspaceUpdateCalled) + require.True(t, publishWorkspaceAgentLogsUpdateCalled) + }) + }) + + t.Run("Overflow", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + publishWorkspaceUpdateCalled = true + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + publishWorkspaceAgentLogsUpdateCalled = true + }, + } + + // Don't really care about the DB call params, just want to return an + // error. + dbErr := &pq.Error{ + Constraint: "max_logs_length", + Table: "workspace_agents", + } + dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), gomock.Any()).Return(nil, dbErr) + + // Should also update the workspace agent. 
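+		// Once LogsOverflowed is set, later batches are rejected up front
+		// (see the AlreadyOverflowed case above).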
+ dbM.EXPECT().UpdateWorkspaceAgentLogOverflowByID(gomock.Any(), database.UpdateWorkspaceAgentLogOverflowByIDParams{ + ID: agent.ID, + LogsOverflowed: true, + }).Return(nil) + + resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(dbtime.Now()), + Level: agentproto.Log_INFO, + Output: "hello world", + }, + }, + }) + require.Error(t, err) + require.Nil(t, resp) + require.True(t, publishWorkspaceUpdateCalled) + require.False(t, publishWorkspaceAgentLogsUpdateCalled) + }) +} diff --git a/coderd/agentapi/manifest.go b/coderd/agentapi/manifest.go index ddd562c969..75f4f953fd 100644 --- a/coderd/agentapi/manifest.go +++ b/coderd/agentapi/manifest.go @@ -5,7 +5,6 @@ import ( "database/sql" "net/url" "strings" - "sync/atomic" "time" "github.com/google/uuid" @@ -25,18 +24,16 @@ import ( ) type ManifestAPI struct { - AccessURL *url.URL - AppHostname string - AgentInactiveDisconnectTimeout time.Duration - AgentFallbackTroubleshootingURL string - ExternalAuthConfigs []*externalauth.Config - DisableDirectConnections bool - DerpForceWebSockets bool + AccessURL *url.URL + AppHostname string + ExternalAuthConfigs []*externalauth.Config + DisableDirectConnections bool + DerpForceWebSockets bool - AgentFn func(context.Context) (database.WorkspaceAgent, error) - Database database.Store - DerpMapFn func() *tailcfg.DERPMap - TailnetCoordinator *atomic.Pointer[tailnet.Coordinator] + AgentFn func(context.Context) (database.WorkspaceAgent, error) + WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) + Database database.Store + DerpMapFn func() *tailcfg.DERPMap } func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) { @@ -44,21 +41,15 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest if err != nil { return nil, err } - - apiAgent, err := db2sdk.WorkspaceAgent( - a.DerpMapFn(), *a.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, a.AgentInactiveDisconnectTimeout, - a.AgentFallbackTroubleshootingURL, - ) + workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent) if err != nil { - return nil, xerrors.Errorf("converting workspace agent: %w", err) + return nil, err } var ( dbApps []database.WorkspaceApp scripts []database.WorkspaceAgentScript metadata []database.WorkspaceAgentMetadatum - resource database.WorkspaceResource - build database.WorkspaceBuild workspace database.Workspace owner database.User ) @@ -79,20 +70,12 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest eg.Go(func() (err error) { metadata, err = a.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{ WorkspaceAgentID: workspaceAgent.ID, - Keys: nil, + Keys: nil, // all }) return err }) eg.Go(func() (err error) { - resource, err = a.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) - if err != nil { - return xerrors.Errorf("getting resource by id: %w", err) - } - build, err = a.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) - if err != nil { - return xerrors.Errorf("getting workspace build by job id: %w", err) - } - workspace, err = a.Database.GetWorkspaceByID(ctx, build.WorkspaceID) + workspace, err = a.Database.GetWorkspaceByID(ctx, workspaceID) if err != nil { return xerrors.Errorf("getting workspace by id: %w", err) } @@ -116,6 +99,11 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ 
*agentproto.GetManifest vscodeProxyURI := vscodeProxyURI(appSlug, a.AccessURL, a.AppHostname) + envs, err := db2sdk.WorkspaceAgentEnvironment(workspaceAgent) + if err != nil { + return nil, err + } + var gitAuthConfigs uint32 for _, cfg := range a.ExternalAuthConfigs { if codersdk.EnhancedExternalAuthProvider(cfg.Type).Git() { @@ -135,8 +123,8 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest WorkspaceId: workspace.ID[:], WorkspaceName: workspace.Name, GitAuthConfigs: gitAuthConfigs, - EnvironmentVariables: apiAgent.EnvironmentVariables, - Directory: apiAgent.Directory, + EnvironmentVariables: envs, + Directory: workspaceAgent.Directory, VsCodePortProxyUri: vscodeProxyURI, MotdPath: workspaceAgent.MOTDFile, DisableDirectConnections: a.DisableDirectConnections, diff --git a/coderd/agentapi/manifest_test.go b/coderd/agentapi/manifest_test.go new file mode 100644 index 0000000000..575bc353f7 --- /dev/null +++ b/coderd/agentapi/manifest_test.go @@ -0,0 +1,396 @@ +package agentapi_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/url" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "google.golang.org/protobuf/types/known/durationpb" + "tailscale.com/tailcfg" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/externalauth" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/tailnet" +) + +func TestGetManifest(t *testing.T) { + t.Parallel() + + someTime, err := time.Parse(time.RFC3339, "2023-01-01T00:00:00Z") + require.NoError(t, err) + someTime = dbtime.Time(someTime) + + expectedEnvVars := map[string]string{ + "FOO": "bar", + "COOL_ENV": "dean was here", + } + expectedEnvVarsJSON, err := json.Marshal(expectedEnvVars) + require.NoError(t, err) + + var ( + owner = database.User{ + ID: uuid.New(), + Username: "cool-user", + } + workspace = database.Workspace{ + ID: uuid.New(), + OwnerID: owner.ID, + Name: "cool-workspace", + } + agent = database.WorkspaceAgent{ + ID: uuid.New(), + Name: "cool-agent", + EnvironmentVariables: pqtype.NullRawMessage{ + RawMessage: expectedEnvVarsJSON, + Valid: true, + }, + Directory: "/cool/dir", + MOTDFile: "/cool/motd", + } + apps = []database.WorkspaceApp{ + { + ID: uuid.New(), + Url: sql.NullString{String: "http://localhost:1234", Valid: true}, + External: false, + Slug: "cool-app-1", + DisplayName: "app 1", + Command: sql.NullString{String: "cool command", Valid: true}, + Icon: "/icon.png", + Subdomain: true, + SharingLevel: database.AppSharingLevelAuthenticated, + Health: database.WorkspaceAppHealthHealthy, + HealthcheckUrl: "http://localhost:1234/health", + HealthcheckInterval: 10, + HealthcheckThreshold: 3, + }, + { + ID: uuid.New(), + Url: sql.NullString{String: "http://google.com", Valid: true}, + External: true, + Slug: "google", + DisplayName: "Literally Google", + Command: sql.NullString{Valid: false}, + Icon: "/google.png", + Subdomain: false, + SharingLevel: database.AppSharingLevelPublic, + Health: database.WorkspaceAppHealthDisabled, + }, + { + ID: uuid.New(), + Url: sql.NullString{String: "http://localhost:4321", Valid: true}, + External: true, + Slug: "cool-app-2", + DisplayName: "another COOL app", + Command: sql.NullString{Valid: 
false}, + Icon: "", + Subdomain: false, + SharingLevel: database.AppSharingLevelOwner, + Health: database.WorkspaceAppHealthUnhealthy, + HealthcheckUrl: "http://localhost:4321/health", + HealthcheckInterval: 20, + HealthcheckThreshold: 5, + }, + } + scripts = []database.WorkspaceAgentScript{ + { + WorkspaceAgentID: agent.ID, + LogSourceID: uuid.New(), + LogPath: "/cool/log/path/1", + Script: "cool script 1", + Cron: "30 2 * * *", + StartBlocksLogin: true, + RunOnStart: true, + RunOnStop: false, + TimeoutSeconds: 60, + }, + { + WorkspaceAgentID: agent.ID, + LogSourceID: uuid.New(), + LogPath: "/cool/log/path/2", + Script: "cool script 2", + Cron: "", + StartBlocksLogin: false, + RunOnStart: false, + RunOnStop: true, + TimeoutSeconds: 30, + }, + } + metadata = []database.WorkspaceAgentMetadatum{ + { + WorkspaceAgentID: agent.ID, + DisplayName: "cool metadata 1", + Key: "cool-key-1", + Script: "cool script 1", + Value: "cool value 1", + Error: "", + Timeout: int64(time.Minute), + Interval: int64(time.Minute), + CollectedAt: someTime, + }, + { + WorkspaceAgentID: agent.ID, + DisplayName: "cool metadata 2", + Key: "cool-key-2", + Script: "cool script 2", + Value: "cool value 2", + Error: "some uncool error", + Timeout: int64(5 * time.Second), + Interval: int64(20 * time.Minute), + CollectedAt: someTime.Add(time.Hour), + }, + } + derpMapFn = func() *tailcfg.DERPMap { + return &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: {RegionName: "cool region"}, + }, + } + } + ) + + // These are done manually to ensure the conversion logic matches what a + // human expects. + var ( + protoApps = []*agentproto.WorkspaceApp{ + { + Id: apps[0].ID[:], + Url: apps[0].Url.String, + External: apps[0].External, + Slug: apps[0].Slug, + DisplayName: apps[0].DisplayName, + Command: apps[0].Command.String, + Icon: apps[0].Icon, + Subdomain: apps[0].Subdomain, + SubdomainName: fmt.Sprintf("%s--%s--%s--%s", apps[0].Slug, agent.Name, workspace.Name, owner.Username), + SharingLevel: agentproto.WorkspaceApp_AUTHENTICATED, + Healthcheck: &agentproto.WorkspaceApp_Healthcheck{ + Url: apps[0].HealthcheckUrl, + Interval: durationpb.New(time.Duration(apps[0].HealthcheckInterval) * time.Second), + Threshold: apps[0].HealthcheckThreshold, + }, + Health: agentproto.WorkspaceApp_HEALTHY, + }, + { + Id: apps[1].ID[:], + Url: apps[1].Url.String, + External: apps[1].External, + Slug: apps[1].Slug, + DisplayName: apps[1].DisplayName, + Command: apps[1].Command.String, + Icon: apps[1].Icon, + Subdomain: false, + SubdomainName: "", + SharingLevel: agentproto.WorkspaceApp_PUBLIC, + Healthcheck: &agentproto.WorkspaceApp_Healthcheck{ + Url: "", + Interval: durationpb.New(0), + Threshold: 0, + }, + Health: agentproto.WorkspaceApp_DISABLED, + }, + { + Id: apps[2].ID[:], + Url: apps[2].Url.String, + External: apps[2].External, + Slug: apps[2].Slug, + DisplayName: apps[2].DisplayName, + Command: apps[2].Command.String, + Icon: apps[2].Icon, + Subdomain: false, + SubdomainName: "", + SharingLevel: agentproto.WorkspaceApp_OWNER, + Healthcheck: &agentproto.WorkspaceApp_Healthcheck{ + Url: apps[2].HealthcheckUrl, + Interval: durationpb.New(time.Duration(apps[2].HealthcheckInterval) * time.Second), + Threshold: apps[2].HealthcheckThreshold, + }, + Health: agentproto.WorkspaceApp_UNHEALTHY, + }, + } + protoScripts = []*agentproto.WorkspaceAgentScript{ + { + LogSourceId: scripts[0].LogSourceID[:], + LogPath: scripts[0].LogPath, + Script: scripts[0].Script, + Cron: scripts[0].Cron, + RunOnStart: scripts[0].RunOnStart, + RunOnStop: 
scripts[0].RunOnStop, + StartBlocksLogin: scripts[0].StartBlocksLogin, + Timeout: durationpb.New(time.Duration(scripts[0].TimeoutSeconds) * time.Second), + }, + { + LogSourceId: scripts[1].LogSourceID[:], + LogPath: scripts[1].LogPath, + Script: scripts[1].Script, + Cron: scripts[1].Cron, + RunOnStart: scripts[1].RunOnStart, + RunOnStop: scripts[1].RunOnStop, + StartBlocksLogin: scripts[1].StartBlocksLogin, + Timeout: durationpb.New(time.Duration(scripts[1].TimeoutSeconds) * time.Second), + }, + } + protoMetadata = []*agentproto.WorkspaceAgentMetadata_Description{ + { + DisplayName: metadata[0].DisplayName, + Key: metadata[0].Key, + Script: metadata[0].Script, + Interval: durationpb.New(time.Duration(metadata[0].Interval)), + Timeout: durationpb.New(time.Duration(metadata[0].Timeout)), + }, + { + DisplayName: metadata[1].DisplayName, + Key: metadata[1].Key, + Script: metadata[1].Script, + Interval: durationpb.New(time.Duration(metadata[1].Interval)), + Timeout: durationpb.New(time.Duration(metadata[1].Timeout)), + }, + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + mDB := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.ManifestAPI{ + AccessURL: &url.URL{Scheme: "https", Host: "example.com"}, + AppHostname: "*--apps.example.com", + ExternalAuthConfigs: []*externalauth.Config{ + {Type: string(codersdk.EnhancedExternalAuthProviderGitHub)}, + {Type: "some-provider"}, + {Type: string(codersdk.EnhancedExternalAuthProviderGitLab)}, + }, + DisableDirectConnections: true, + DerpForceWebSockets: true, + + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) { + return workspace.ID, nil + }, + Database: mDB, + DerpMapFn: derpMapFn, + } + + mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil) + mDB.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), []uuid.UUID{agent.ID}).Return(scripts, nil) + mDB.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), database.GetWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + Keys: nil, // all + }).Return(metadata, nil) + mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil) + mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil) + + got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{}) + require.NoError(t, err) + + expected := &agentproto.Manifest{ + AgentId: agent.ID[:], + AgentName: agent.Name, + OwnerUsername: owner.Username, + WorkspaceId: workspace.ID[:], + WorkspaceName: workspace.Name, + GitAuthConfigs: 2, // two "enhanced" external auth configs + EnvironmentVariables: expectedEnvVars, + Directory: agent.Directory, + VsCodePortProxyUri: fmt.Sprintf("https://{{port}}--%s--%s--%s--apps.example.com", agent.Name, workspace.Name, owner.Username), + MotdPath: agent.MOTDFile, + DisableDirectConnections: true, + DerpForceWebsockets: true, + // tailnet.DERPMapToProto() is extensively tested elsewhere, so it's + // not necessary to manually recreate a big DERP map here like we + // did for apps and metadata. + DerpMap: tailnet.DERPMapToProto(derpMapFn()), + Scripts: protoScripts, + Apps: protoApps, + Metadata: protoMetadata, + } + + // Log got and expected with spew. 
+ // t.Log("got:\n" + spew.Sdump(got)) + // t.Log("expected:\n" + spew.Sdump(expected)) + + require.Equal(t, expected, got) + }) + + t.Run("NoAppHostname", func(t *testing.T) { + t.Parallel() + + mDB := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.ManifestAPI{ + AccessURL: &url.URL{Scheme: "https", Host: "example.com"}, + AppHostname: "", + ExternalAuthConfigs: []*externalauth.Config{ + {Type: string(codersdk.EnhancedExternalAuthProviderGitHub)}, + {Type: "some-provider"}, + {Type: string(codersdk.EnhancedExternalAuthProviderGitLab)}, + }, + DisableDirectConnections: true, + DerpForceWebSockets: true, + + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) { + return workspace.ID, nil + }, + Database: mDB, + DerpMapFn: derpMapFn, + } + + mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil) + mDB.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), []uuid.UUID{agent.ID}).Return(scripts, nil) + mDB.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), database.GetWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + Keys: nil, // all + }).Return(metadata, nil) + mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil) + mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil) + + got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{}) + require.NoError(t, err) + + expected := &agentproto.Manifest{ + AgentId: agent.ID[:], + AgentName: agent.Name, + OwnerUsername: owner.Username, + WorkspaceId: workspace.ID[:], + WorkspaceName: workspace.Name, + GitAuthConfigs: 2, // two "enhanced" external auth configs + EnvironmentVariables: expectedEnvVars, + Directory: agent.Directory, + VsCodePortProxyUri: "", // empty with no AppHost + MotdPath: agent.MOTDFile, + DisableDirectConnections: true, + DerpForceWebsockets: true, + // tailnet.DERPMapToProto() is extensively tested elsewhere, so it's + // not necessary to manually recreate a big DERP map here like we + // did for apps and metadata. + DerpMap: tailnet.DERPMapToProto(derpMapFn()), + Scripts: protoScripts, + Apps: protoApps, + Metadata: protoMetadata, + } + + // Log got and expected with spew. + // t.Log("got:\n" + spew.Sdump(got)) + // t.Log("expected:\n" + spew.Sdump(expected)) + + require.Equal(t, expected, got) + }) +} diff --git a/coderd/agentapi/metadata.go b/coderd/agentapi/metadata.go index a3bf24b203..0c3e0c8630 100644 --- a/coderd/agentapi/metadata.go +++ b/coderd/agentapi/metadata.go @@ -12,6 +12,7 @@ import ( "cdr.dev/slog" agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" ) @@ -20,14 +21,26 @@ type MetadataAPI struct { Database database.Store Pubsub pubsub.Pubsub Log slog.Logger + + TimeNowFn func() time.Time // defaults to dbtime.Now() +} + +func (a *MetadataAPI) now() time.Time { + if a.TimeNowFn != nil { + return a.TimeNowFn() + } + return dbtime.Now() } func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) { const ( - // maxValueLen is set to 2048 to stay under the 8000 byte Postgres - // NOTIFY limit. 
Since both value and error can be set, the real payload - // limit is 2 * 2048 * 4/3 = 5461 bytes + a few - // hundred bytes for JSON syntax, key names, and metadata. + // maxAllKeysLen is the maximum length of all metadata keys. This is + // 6144 to stay below the Postgres NOTIFY limit of 8000 bytes, with some + // headway for the timestamp and JSON encoding. Any values that would + // exceed this limit are discarded (the rest are still inserted) and an + // error is returned. + maxAllKeysLen = 6144 // 1024 * 6 + maxValueLen = 2048 maxErrorLen = maxValueLen ) @@ -37,18 +50,36 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B return nil, err } - collectedAt := time.Now() - dbUpdate := database.UpdateWorkspaceAgentMetadataParams{ - WorkspaceAgentID: workspaceAgent.ID, - Key: make([]string, 0, len(req.Metadata)), - Value: make([]string, 0, len(req.Metadata)), - Error: make([]string, 0, len(req.Metadata)), - CollectedAt: make([]time.Time, 0, len(req.Metadata)), - } - + var ( + collectedAt = a.now() + allKeysLen = 0 + dbUpdate = database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: workspaceAgent.ID, + // These need to be `make(x, 0, len(req.Metadata))` instead of + // `make(x, len(req.Metadata))` because we may not insert all + // metadata if the keys are large. + Key: make([]string, 0, len(req.Metadata)), + Value: make([]string, 0, len(req.Metadata)), + Error: make([]string, 0, len(req.Metadata)), + CollectedAt: make([]time.Time, 0, len(req.Metadata)), + } + ) for _, md := range req.Metadata { metadataError := md.Result.Error + allKeysLen += len(md.Key) + if allKeysLen > maxAllKeysLen { + // We still insert the rest of the metadata, and we return an error + // after the insert. + a.Log.Warn( + ctx, "discarded extra agent metadata due to excessive key length", + slog.F("collected_at", collectedAt), + slog.F("all_keys_len", allKeysLen), + slog.F("max_all_keys_len", maxAllKeysLen), + ) + break + } + // We overwrite the error if the provided payload is too long. if len(md.Result.Value) > maxValueLen { metadataError = fmt.Sprintf("value of %d bytes exceeded %d bytes", len(md.Result.Value), maxValueLen) @@ -71,12 +102,16 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B a.Log.Debug( ctx, "accepted metadata report", slog.F("collected_at", collectedAt), - slog.F("original_collected_at", collectedAt), slog.F("key", md.Key), slog.F("value", ellipse(md.Result.Value, 16)), ) } + err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate) + if err != nil { + return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err) + } + payload, err := json.Marshal(WorkspaceAgentMetadataChannelPayload{ CollectedAt: collectedAt, Keys: dbUpdate.Key, @@ -84,17 +119,17 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B if err != nil { return nil, xerrors.Errorf("marshal workspace agent metadata channel payload: %w", err) } - - err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate) - if err != nil { - return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err) - } - err = a.Pubsub.Publish(WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), payload) if err != nil { return nil, xerrors.Errorf("publish workspace agent metadata: %w", err) } + // If the metadata keys were too large, we return an error so the agent can + // log it. 
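+ // Note that the database insert and pubsub publish above still ran for
+ // the keys that fit. For example, two 4096 byte keys would have the
+ // second key discarded and yield the error "metadata keys of 8192 bytes
+ // exceeded 6144 bytes".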
+ if allKeysLen > maxAllKeysLen { + return nil, xerrors.Errorf("metadata keys of %d bytes exceeded %d bytes", allKeysLen, maxAllKeysLen) + } + return &agentproto.BatchUpdateMetadataResponse{}, nil } diff --git a/coderd/agentapi/metadata_test.go b/coderd/agentapi/metadata_test.go new file mode 100644 index 0000000000..c3d0ec5528 --- /dev/null +++ b/coderd/agentapi/metadata_test.go @@ -0,0 +1,276 @@ +package agentapi_test + +import ( + "context" + "encoding/json" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "google.golang.org/protobuf/types/known/timestamppb" + + "cdr.dev/slog/sloggers/slogtest" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/database/pubsub" +) + +type fakePublisher struct { + // Nil pointer to pass interface check. + pubsub.Pubsub + publishes [][]byte +} + +var _ pubsub.Pubsub = &fakePublisher{} + +func (f *fakePublisher) Publish(_ string, message []byte) error { + f.publishes = append(f.publishes, message) + return nil +} + +func TestBatchUpdateMetadata(t *testing.T) { + t.Parallel() + + agent := database.WorkspaceAgent{ + ID: uuid.New(), + } + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + pub := &fakePublisher{} + + now := dbtime.Now() + req := &agentproto.BatchUpdateMetadataRequest{ + Metadata: []*agentproto.Metadata{ + { + Key: "awesome key", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + CollectedAt: timestamppb.New(now.Add(-10 * time.Second)), + Age: 10, + Value: "awesome value", + Error: "", + }, + }, + { + Key: "uncool key", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + CollectedAt: timestamppb.New(now.Add(-3 * time.Second)), + Age: 3, + Value: "", + Error: "uncool value", + }, + }, + }, + } + + dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + Key: []string{req.Metadata[0].Key, req.Metadata[1].Key}, + Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value}, + Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error}, + // The value from the agent is ignored. 
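+ // The server-side timestamp from TimeNowFn is used for every entry
+ // instead.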
+ CollectedAt: []time.Time{now, now}, + }).Return(nil) + + api := &agentapi.MetadataAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Pubsub: pub, + Log: slogtest.Make(t, nil), + TimeNowFn: func() time.Time { + return now + }, + } + + resp, err := api.BatchUpdateMetadata(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp) + + require.Equal(t, 1, len(pub.publishes)) + var gotEvent agentapi.WorkspaceAgentMetadataChannelPayload + require.NoError(t, json.Unmarshal(pub.publishes[0], &gotEvent)) + require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{ + CollectedAt: now, + Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key}, + }, gotEvent) + }) + + t.Run("ExceededLength", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + pub := pubsub.NewInMemory() + + almostLongValue := "" + for i := 0; i < 2048; i++ { + almostLongValue += "a" + } + + now := dbtime.Now() + req := &agentproto.BatchUpdateMetadataRequest{ + Metadata: []*agentproto.Metadata{ + { + Key: "almost long value", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: almostLongValue, + }, + }, + { + Key: "too long value", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: almostLongValue + "a", + }, + }, + { + Key: "almost long error", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Error: almostLongValue, + }, + }, + { + Key: "too long error", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Error: almostLongValue + "a", + }, + }, + }, + } + + dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key, req.Metadata[3].Key}, + Value: []string{ + almostLongValue, + almostLongValue, // truncated + "", + "", + }, + Error: []string{ + "", + "value of 2049 bytes exceeded 2048 bytes", + almostLongValue, + "error of 2049 bytes exceeded 2048 bytes", // replaced + }, + // The value from the agent is ignored. 
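+ // Every entry gets the server-side timestamp from TimeNowFn.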
+ CollectedAt: []time.Time{now, now, now, now}, + }).Return(nil) + + api := &agentapi.MetadataAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Pubsub: pub, + Log: slogtest.Make(t, nil), + TimeNowFn: func() time.Time { + return now + }, + } + + resp, err := api.BatchUpdateMetadata(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp) + }) + + t.Run("KeysTooLong", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + pub := pubsub.NewInMemory() + + now := dbtime.Now() + req := &agentproto.BatchUpdateMetadataRequest{ + Metadata: []*agentproto.Metadata{ + { + Key: "key 1", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: "value 1", + }, + }, + { + Key: "key 2", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: "value 2", + }, + }, + { + Key: func() string { + key := "key 3 " + for i := 0; i < (6144 - len("key 1") - len("key 2") - len("key 3") - 1); i++ { + key += "a" + } + return key + }(), + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: "value 3", + }, + }, + { + Key: "a", // should be ignored + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: "value 4", + }, + }, + }, + } + + dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + // No key 4. + Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key}, + Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value, req.Metadata[2].Result.Value}, + Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error, req.Metadata[2].Result.Error}, + // The value from the agent is ignored. + CollectedAt: []time.Time{now, now, now}, + }).Return(nil) + + api := &agentapi.MetadataAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Pubsub: pub, + Log: slogtest.Make(t, nil), + TimeNowFn: func() time.Time { + return now + }, + } + + // Watch the pubsub for events. + var ( + eventCount int64 + gotEvent agentapi.WorkspaceAgentMetadataChannelPayload + ) + cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) { + if atomic.AddInt64(&eventCount, 1) > 1 { + return + } + require.NoError(t, json.Unmarshal(message, &gotEvent)) + }) + require.NoError(t, err) + defer cancel() + + resp, err := api.BatchUpdateMetadata(context.Background(), req) + require.Error(t, err) + require.Equal(t, "metadata keys of 6145 bytes exceeded 6144 bytes", err.Error()) + require.Nil(t, resp) + + require.Equal(t, int64(1), atomic.LoadInt64(&eventCount)) + require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{ + CollectedAt: now, + // No key 4. 
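+ // ("a" was discarded before the insert, so it never reaches the
+ // pubsub payload either.)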
+ Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key}, + }, gotEvent) + }) +} diff --git a/coderd/agentapi/servicebanner_test.go b/coderd/agentapi/servicebanner_test.go new file mode 100644 index 0000000000..902af7395e --- /dev/null +++ b/coderd/agentapi/servicebanner_test.go @@ -0,0 +1,84 @@ +package agentapi_test + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/codersdk" +) + +func TestGetServiceBanner(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + cfg := codersdk.ServiceBannerConfig{ + Enabled: true, + Message: "hello world", + BackgroundColor: "#000000", + } + cfgJSON, err := json.Marshal(cfg) + require.NoError(t, err) + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetServiceBanner(gomock.Any()).Return(string(cfgJSON), nil) + + api := &agentapi.ServiceBannerAPI{ + Database: dbM, + } + + resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{}) + require.NoError(t, err) + + require.Equal(t, &agentproto.ServiceBanner{ + Enabled: cfg.Enabled, + Message: cfg.Message, + BackgroundColor: cfg.BackgroundColor, + }, resp) + }) + + t.Run("None", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetServiceBanner(gomock.Any()).Return("", sql.ErrNoRows) + + api := &agentapi.ServiceBannerAPI{ + Database: dbM, + } + + resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{}) + require.NoError(t, err) + + require.Equal(t, &agentproto.ServiceBanner{ + Enabled: false, + Message: "", + BackgroundColor: "", + }, resp) + }) + + t.Run("BadJSON", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetServiceBanner(gomock.Any()).Return("hi", nil) + + api := &agentapi.ServiceBannerAPI{ + Database: dbM, + } + + resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{}) + require.Error(t, err) + require.ErrorContains(t, err, "unmarshal json") + require.Nil(t, resp) + }) +} diff --git a/coderd/agentapi/stats.go b/coderd/agentapi/stats.go index d2098f2cce..1185b99abd 100644 --- a/coderd/agentapi/stats.go +++ b/coderd/agentapi/stats.go @@ -9,57 +9,71 @@ import ( "golang.org/x/xerrors" "google.golang.org/protobuf/types/known/durationpb" + "github.com/google/uuid" + "cdr.dev/slog" agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/autobuild" - "github.com/coder/coder/v2/coderd/batchstats" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/prometheusmetrics" "github.com/coder/coder/v2/coderd/schedule" ) +type StatsBatcher interface { + Add(now time.Time, agentID uuid.UUID, templateID uuid.UUID, userID uuid.UUID, workspaceID uuid.UUID, st *agentproto.Stats) error +} + type StatsAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) Database database.Store Log slog.Logger - StatsBatcher *batchstats.Batcher + StatsBatcher StatsBatcher TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] AgentStatsRefreshInterval time.Duration UpdateAgentMetricsFn func(ctx context.Context, labels 
prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) + + TimeNowFn func() time.Time // defaults to dbtime.Now() +} + +func (a *StatsAPI) now() time.Time { + if a.TimeNowFn != nil { + return a.TimeNowFn() + } + return dbtime.Now() } func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) { + // An empty stat means it's just looking for the report interval. + res := &agentproto.UpdateStatsResponse{ + ReportInterval: durationpb.New(a.AgentStatsRefreshInterval), + } + if req.Stats == nil || len(req.Stats.ConnectionsByProto) == 0 { + return res, nil + } + workspaceAgent, err := a.AgentFn(ctx) if err != nil { return nil, err } - row, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) + getWorkspaceAgentByIDRow, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) if err != nil { return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err) } - workspace := row.Workspace - - res := &agentproto.UpdateStatsResponse{ - ReportInterval: durationpb.New(a.AgentStatsRefreshInterval), - } - - // An empty stat means it's just looking for the report interval. - if len(req.Stats.ConnectionsByProto) == 0 { - return res, nil - } - + workspace := getWorkspaceAgentByIDRow.Workspace a.Log.Debug(ctx, "read stats report", slog.F("interval", a.AgentStatsRefreshInterval), slog.F("workspace_id", workspace.ID), slog.F("payload", req), ) + now := a.now() if req.Stats.ConnectionCount > 0 { var nextAutostart time.Time if workspace.AutostartSchedule.String != "" { templateSchedule, err := (*(a.TemplateScheduleStore.Load())).Get(ctx, a.Database, workspace.TemplateID) - // If the template schedule fails to load, just default to bumping without the next trasition and log it. + // If the template schedule fails to load, just default to bumping + // without the next transition and log it. 
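+ // The activity bump below then falls back to its default 60 minute
+ // extension.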
if err != nil { a.Log.Error(ctx, "failed to load template schedule bumping activity, defaulting to bumping by 60min", slog.F("workspace_id", workspace.ID), @@ -67,7 +81,7 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR slog.Error(err), ) } else { - next, allowed := autobuild.NextAutostartSchedule(time.Now(), workspace.AutostartSchedule.String, templateSchedule) + next, allowed := autobuild.NextAutostartSchedule(now, workspace.AutostartSchedule.String, templateSchedule) if allowed { nextAutostart = next } @@ -76,13 +90,12 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR ActivityBumpWorkspace(ctx, a.Log.Named("activity_bump"), a.Database, workspace.ID, nextAutostart) } - now := dbtime.Now() - var errGroup errgroup.Group errGroup.Go(func() error { - if err := a.StatsBatcher.Add(time.Now(), workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, req.Stats); err != nil { - a.Log.Error(ctx, "failed to add stats to batcher", slog.Error(err)) - return xerrors.Errorf("can't insert workspace agent stat: %w", err) + err := a.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, req.Stats) + if err != nil { + a.Log.Error(ctx, "add agent stats to batcher", slog.Error(err)) + return xerrors.Errorf("insert workspace agent stats batch: %w", err) } return nil }) @@ -92,7 +105,7 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR LastUsedAt: now, }) if err != nil { - return xerrors.Errorf("can't update workspace LastUsedAt: %w", err) + return xerrors.Errorf("update workspace LastUsedAt: %w", err) } return nil }) @@ -100,14 +113,14 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR errGroup.Go(func() error { user, err := a.Database.GetUserByID(ctx, workspace.OwnerID) if err != nil { - return xerrors.Errorf("can't get user: %w", err) + return xerrors.Errorf("get user: %w", err) } a.UpdateAgentMetricsFn(ctx, prometheusmetrics.AgentMetricLabels{ Username: user.Username, WorkspaceName: workspace.Name, AgentName: workspaceAgent.Name, - TemplateName: row.TemplateName, + TemplateName: getWorkspaceAgentByIDRow.TemplateName, }, req.Stats.Metrics) return nil }) diff --git a/coderd/agentapi/stats_test.go b/coderd/agentapi/stats_test.go new file mode 100644 index 0000000000..a26e7fbf6a --- /dev/null +++ b/coderd/agentapi/stats_test.go @@ -0,0 +1,379 @@ +package agentapi_test + +import ( + "context" + "database/sql" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "google.golang.org/protobuf/types/known/durationpb" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/prometheusmetrics" + "github.com/coder/coder/v2/coderd/schedule" +) + +type statsBatcher struct { + mu sync.Mutex + + called int64 + lastTime time.Time + lastAgentID uuid.UUID + lastTemplateID uuid.UUID + lastUserID uuid.UUID + lastWorkspaceID uuid.UUID + lastStats *agentproto.Stats +} + +var _ agentapi.StatsBatcher = &statsBatcher{} + +func (b *statsBatcher) Add(now time.Time, agentID uuid.UUID, templateID uuid.UUID, userID uuid.UUID, workspaceID uuid.UUID, st *agentproto.Stats) error { + b.mu.Lock() 
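+ // Record every argument so tests can assert on the most recent call.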
+ defer b.mu.Unlock() + b.called++ + b.lastTime = now + b.lastAgentID = agentID + b.lastTemplateID = templateID + b.lastUserID = userID + b.lastWorkspaceID = workspaceID + b.lastStats = st + return nil +} + +func TestUpdateStates(t *testing.T) { + t.Parallel() + + var ( + user = database.User{ + ID: uuid.New(), + Username: "bill", + } + template = database.Template{ + ID: uuid.New(), + Name: "tpl", + } + workspace = database.Workspace{ + ID: uuid.New(), + OwnerID: user.ID, + TemplateID: template.ID, + Name: "xyz", + } + agent = database.WorkspaceAgent{ + ID: uuid.New(), + Name: "abc", + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + var ( + now = dbtime.Now() + dbM = dbmock.NewMockStore(gomock.NewController(t)) + templateScheduleStore = schedule.MockTemplateScheduleStore{ + GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) { + panic("should not be called") + }, + SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) { + panic("not implemented") + }, + } + batcher = &statsBatcher{} + updateAgentMetricsFnCalled = false + + req = &agentproto.UpdateStatsRequest{ + Stats: &agentproto.Stats{ + ConnectionsByProto: map[string]int64{ + "tcp": 1, + "dean": 2, + }, + ConnectionCount: 3, + ConnectionMedianLatencyMs: 23, + RxPackets: 120, + RxBytes: 1000, + TxPackets: 130, + TxBytes: 2000, + SessionCountVscode: 1, + SessionCountJetbrains: 2, + SessionCountReconnectingPty: 3, + SessionCountSsh: 4, + Metrics: []*agentproto.Stats_Metric{ + { + Name: "awesome metric", + Value: 42, + }, + { + Name: "uncool metric", + Value: 0, + }, + }, + }, + } + ) + api := agentapi.StatsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + StatsBatcher: batcher, + TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore), + AgentStatsRefreshInterval: 10 * time.Second, + UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) { + updateAgentMetricsFnCalled = true + assert.Equal(t, prometheusmetrics.AgentMetricLabels{ + Username: user.Username, + WorkspaceName: workspace.Name, + AgentName: agent.Name, + TemplateName: template.Name, + }, labels) + assert.Equal(t, req.Stats.Metrics, metrics) + }, + TimeNowFn: func() time.Time { + return now + }, + } + + // Workspace gets fetched. + dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{ + Workspace: workspace, + TemplateName: template.Name, + }, nil) + + // We expect an activity bump because ConnectionCount > 0. + dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{ + WorkspaceID: workspace.ID, + NextAutostart: time.Time{}.UTC(), + }).Return(nil) + + // Workspace last used at gets bumped. + dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{ + ID: workspace.ID, + LastUsedAt: now, + }).Return(nil) + + // User gets fetched to hit the UpdateAgentMetricsFn. 
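+ // (the username feeds the Prometheus metric labels).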
+ dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil) + + resp, err := api.UpdateStats(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.UpdateStatsResponse{ + ReportInterval: durationpb.New(10 * time.Second), + }, resp) + + batcher.mu.Lock() + defer batcher.mu.Unlock() + require.Equal(t, int64(1), batcher.called) + require.Equal(t, now, batcher.lastTime) + require.Equal(t, agent.ID, batcher.lastAgentID) + require.Equal(t, template.ID, batcher.lastTemplateID) + require.Equal(t, user.ID, batcher.lastUserID) + require.Equal(t, workspace.ID, batcher.lastWorkspaceID) + require.Equal(t, req.Stats, batcher.lastStats) + + require.True(t, updateAgentMetricsFnCalled) + }) + + t.Run("ConnectionCountZero", func(t *testing.T) { + t.Parallel() + + var ( + now = dbtime.Now() + dbM = dbmock.NewMockStore(gomock.NewController(t)) + templateScheduleStore = schedule.MockTemplateScheduleStore{ + GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) { + panic("should not be called") + }, + SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) { + panic("not implemented") + }, + } + batcher = &statsBatcher{} + + req = &agentproto.UpdateStatsRequest{ + Stats: &agentproto.Stats{ + ConnectionsByProto: map[string]int64{ + "tcp": 1, + }, + ConnectionCount: 0, + ConnectionMedianLatencyMs: 23, + }, + } + ) + api := agentapi.StatsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + StatsBatcher: batcher, + TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore), + AgentStatsRefreshInterval: 10 * time.Second, + // Ignored when nil. + UpdateAgentMetricsFn: nil, + TimeNowFn: func() time.Time { + return now + }, + } + + // Workspace gets fetched. + dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{ + Workspace: workspace, + TemplateName: template.Name, + }, nil) + + // Workspace last used at gets bumped. + dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{ + ID: workspace.ID, + LastUsedAt: now, + }).Return(nil) + + _, err := api.UpdateStats(context.Background(), req) + require.NoError(t, err) + }) + + t.Run("NoConnectionsByProto", func(t *testing.T) { + t.Parallel() + + var ( + dbM = dbmock.NewMockStore(gomock.NewController(t)) + req = &agentproto.UpdateStatsRequest{ + Stats: &agentproto.Stats{ + ConnectionsByProto: map[string]int64{}, // len() == 0 + }, + } + ) + api := agentapi.StatsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + StatsBatcher: nil, // should not be called + TemplateScheduleStore: nil, // should not be called + AgentStatsRefreshInterval: 10 * time.Second, + UpdateAgentMetricsFn: nil, // should not be called + TimeNowFn: func() time.Time { + panic("should not be called") + }, + } + + resp, err := api.UpdateStats(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.UpdateStatsResponse{ + ReportInterval: durationpb.New(10 * time.Second), + }, resp) + }) + + t.Run("AutostartAwareBump", func(t *testing.T) { + t.Parallel() + + // Use a workspace with an autostart schedule. 
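+ // Shadow the outer variable with a copy so the other subtests keep the
+ // original struct.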
+ workspace := workspace + workspace.AutostartSchedule = sql.NullString{ + String: "CRON_TZ=Australia/Sydney 0 8 * * *", + Valid: true, + } + + // Use a custom time for now which would trigger the autostart aware + // bump. + now, err := time.Parse("2006-01-02 15:04:05 -0700 MST", "2023-12-19 07:30:00 +1100 AEDT") + require.NoError(t, err) + now = dbtime.Time(now) + nextAutostart := now.Add(30 * time.Minute).UTC() // always sent to DB as UTC + + var ( + dbM = dbmock.NewMockStore(gomock.NewController(t)) + templateScheduleStore = schedule.MockTemplateScheduleStore{ + GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) { + return schedule.TemplateScheduleOptions{ + UserAutostartEnabled: true, + AutostartRequirement: schedule.TemplateAutostartRequirement{ + DaysOfWeek: 0b01111111, // every day + }, + }, nil + }, + SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) { + panic("not implemented") + }, + } + batcher = &statsBatcher{} + updateAgentMetricsFnCalled = false + + req = &agentproto.UpdateStatsRequest{ + Stats: &agentproto.Stats{ + ConnectionsByProto: map[string]int64{ + "tcp": 1, + "dean": 2, + }, + ConnectionCount: 3, + }, + } + ) + api := agentapi.StatsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + StatsBatcher: batcher, + TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore), + AgentStatsRefreshInterval: 15 * time.Second, + UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) { + updateAgentMetricsFnCalled = true + assert.Equal(t, prometheusmetrics.AgentMetricLabels{ + Username: user.Username, + WorkspaceName: workspace.Name, + AgentName: agent.Name, + TemplateName: template.Name, + }, labels) + assert.Equal(t, req.Stats.Metrics, metrics) + }, + TimeNowFn: func() time.Time { + return now + }, + } + + // Workspace gets fetched. + dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{ + Workspace: workspace, + TemplateName: template.Name, + }, nil) + + // We expect an activity bump because ConnectionCount > 0. However, the + // next autostart time will be set on the bump. + dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{ + WorkspaceID: workspace.ID, + NextAutostart: nextAutostart, + }).Return(nil) + + // Workspace last used at gets bumped. + dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{ + ID: workspace.ID, + LastUsedAt: now, + }).Return(nil) + + // User gets fetched to hit the UpdateAgentMetricsFn. 
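+ // (as in the OK subtest, the username is needed for the metric labels).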
+ dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil) + + resp, err := api.UpdateStats(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.UpdateStatsResponse{ + ReportInterval: durationpb.New(15 * time.Second), + }, resp) + + require.True(t, updateAgentMetricsFnCalled) + }) +} + +func templateScheduleStorePtr(store schedule.TemplateScheduleStore) *atomic.Pointer[schedule.TemplateScheduleStore] { + var ptr atomic.Pointer[schedule.TemplateScheduleStore] + ptr.Store(&store) + return &ptr +} diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 3d9953dd87..ef58e858b6 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -266,16 +266,25 @@ func convertDisplayApps(apps []database.DisplayApp) []codersdk.DisplayApp { return dapps } +func WorkspaceAgentEnvironment(workspaceAgent database.WorkspaceAgent) (map[string]string, error) { + var envs map[string]string + if workspaceAgent.EnvironmentVariables.Valid { + err := json.Unmarshal(workspaceAgent.EnvironmentVariables.RawMessage, &envs) + if err != nil { + return nil, xerrors.Errorf("unmarshal environment variables: %w", err) + } + } + + return envs, nil +} + func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, scripts []codersdk.WorkspaceAgentScript, logSources []codersdk.WorkspaceAgentLogSource, agentInactiveDisconnectTimeout time.Duration, agentFallbackTroubleshootingURL string, ) (codersdk.WorkspaceAgent, error) { - var envs map[string]string - if dbAgent.EnvironmentVariables.Valid { - err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs) - if err != nil { - return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal env vars: %w", err) - } + envs, err := WorkspaceAgentEnvironment(dbAgent) + if err != nil { + return codersdk.WorkspaceAgent{}, err } troubleshootingURL := agentFallbackTroubleshootingURL if dbAgent.TroubleshootingURL != "" { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index d438d6663d..d5a967b4a7 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -154,18 +154,24 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request) // As this API becomes deprecated, use the new protobuf API and convert the // types back to the SDK types. 
manifestAPI := &agentapi.ManifestAPI{ - AccessURL: api.AccessURL, - AppHostname: api.AppHostname, - AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout, - AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), - ExternalAuthConfigs: api.ExternalAuthConfigs, - DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(), - DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(), + AccessURL: api.AccessURL, + AppHostname: api.AppHostname, + ExternalAuthConfigs: api.ExternalAuthConfigs, + DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(), + DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(), - AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil }, - Database: api.Database, - DerpMapFn: api.DERPMap, - TailnetCoordinator: &api.TailnetCoordinator, + AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil }, + WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) { + // Sadly this results in a double query, but it's only temporary for + // now. + ws, err := api.Database.GetWorkspaceByAgentID(ctx, wa.ID) + if err != nil { + return uuid.Nil, err + } + return ws.Workspace.ID, nil + }, + Database: api.Database, + DerpMapFn: api.DERPMap, } manifest, err := manifestAPI.GetManifest(ctx, &agentproto.GetManifestRequest{}) if err != nil { diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index a0bb039d1f..8e02fc878e 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -134,15 +134,13 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, PublishWorkspaceAgentLogsUpdateFn: api.publishWorkspaceAgentLogsUpdate, - AccessURL: api.AccessURL, - AppHostname: api.AppHostname, - AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout, - AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), - AgentStatsRefreshInterval: api.AgentStatsRefreshInterval, - DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(), - DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(), - DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency, - ExternalAuthConfigs: api.ExternalAuthConfigs, + AccessURL: api.AccessURL, + AppHostname: api.AppHostname, + AgentStatsRefreshInterval: api.AgentStatsRefreshInterval, + DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(), + DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(), + DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency, + ExternalAuthConfigs: api.ExternalAuthConfigs, // Optional: WorkspaceID: build.WorkspaceID, // saves the extra lookup later