feat: use Agent v2 API for Service Banner (#11806)

Agent uses the v2 API for the service banner, rather than the v1 HTTP API.

One of several for #10534
This commit is contained in:
Spike Curtis 2024-01-30 07:44:47 +04:00 committed by GitHub
parent 4f5a2f0a9b
commit 13e24f21e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 274 additions and 133 deletions

View File

@ -41,6 +41,7 @@ import (
"github.com/coder/coder/v2/agent/agentproc" "github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/agent/agentscripts" "github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/agent/reconnectingpty" "github.com/coder/coder/v2/agent/reconnectingpty"
"github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/gitauth" "github.com/coder/coder/v2/cli/gitauth"
@ -95,7 +96,6 @@ type Client interface {
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error)
} }
type Agent interface { type Agent interface {
@ -269,7 +269,6 @@ func (a *agent) init(ctx context.Context) {
func (a *agent) runLoop(ctx context.Context) { func (a *agent) runLoop(ctx context.Context) {
go a.reportLifecycleLoop(ctx) go a.reportLifecycleLoop(ctx)
go a.reportMetadataLoop(ctx) go a.reportMetadataLoop(ctx)
go a.fetchServiceBannerLoop(ctx)
go a.manageProcessPriorityLoop(ctx) go a.manageProcessPriorityLoop(ctx)
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
@ -662,22 +661,23 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL
// fetchServiceBannerLoop fetches the service banner on an interval. It will // fetchServiceBannerLoop fetches the service banner on an interval. It will
// not be fetched immediately; the expectation is that it is primed elsewhere // not be fetched immediately; the expectation is that it is primed elsewhere
// (and must be done before the session actually starts). // (and must be done before the session actually starts).
func (a *agent) fetchServiceBannerLoop(ctx context.Context) { func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgentClient) error {
ticker := time.NewTicker(a.serviceBannerRefreshInterval) ticker := time.NewTicker(a.serviceBannerRefreshInterval)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return ctx.Err()
case <-ticker.C: case <-ticker.C:
serviceBanner, err := a.client.GetServiceBanner(ctx) sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
return return ctx.Err()
} }
a.logger.Error(ctx, "failed to update service banner", slog.Error(err)) a.logger.Error(ctx, "failed to update service banner", slog.Error(err))
continue return err
} }
serviceBanner := proto.SDKServiceBannerFromProto(sbp)
a.serviceBanner.Store(&serviceBanner) a.serviceBanner.Store(&serviceBanner)
} }
} }
@ -693,10 +693,24 @@ func (a *agent) run(ctx context.Context) error {
} }
a.sessionToken.Store(&sessionToken) a.sessionToken.Store(&sessionToken)
serviceBanner, err := a.client.GetServiceBanner(ctx) // Listen returns the dRPC connection we use for the Agent v2+ API
conn, err := a.client.Listen(ctx)
if err != nil {
return err
}
defer func() {
cErr := conn.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
}
}()
aAPI := proto.NewDRPCAgentClient(conn)
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
if err != nil { if err != nil {
return xerrors.Errorf("fetch service banner: %w", err) return xerrors.Errorf("fetch service banner: %w", err)
} }
serviceBanner := proto.SDKServiceBannerFromProto(sbp)
a.serviceBanner.Store(&serviceBanner) a.serviceBanner.Store(&serviceBanner)
manifest, err := a.client.Manifest(ctx) manifest, err := a.client.Manifest(ctx)
@ -821,18 +835,6 @@ func (a *agent) run(ctx context.Context) error {
network.SetBlockEndpoints(manifest.DisableDirectConnections) network.SetBlockEndpoints(manifest.DisableDirectConnections)
} }
// Listen returns the dRPC connection we use for both Coordinator and DERPMap updates
conn, err := a.client.Listen(ctx)
if err != nil {
return err
}
defer func() {
cErr := conn.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
}
}()
eg, egCtx := errgroup.WithContext(ctx) eg, egCtx := errgroup.WithContext(ctx)
eg.Go(func() error { eg.Go(func() error {
a.logger.Debug(egCtx, "running tailnet connection coordinator") a.logger.Debug(egCtx, "running tailnet connection coordinator")
@ -852,6 +854,15 @@ func (a *agent) run(ctx context.Context) error {
return nil return nil
}) })
eg.Go(func() error {
a.logger.Debug(egCtx, "running fetch server banner loop")
err := a.fetchServiceBannerLoop(egCtx, aAPI)
if err != nil {
return xerrors.Errorf("fetch server banner loop: %w", err)
}
return nil
})
return eg.Wait() return eg.Wait()
} }

View File

@ -35,6 +35,7 @@ import (
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/valyala/fasthttp/fasthttputil"
"go.uber.org/goleak" "go.uber.org/goleak"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -2026,7 +2027,10 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
afero.Fs, afero.Fs,
agent.Agent, agent.Agent,
) { ) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, &slogtest.Options{
// we get this error when closing the Agent API
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, fasthttputil.ErrInmemoryListenerClosed),
}).Leveled(slog.LevelDebug)
if metadata.DERPMap == nil { if metadata.DERPMap == nil {
metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
} }

View File

@ -18,6 +18,7 @@ import (
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"cdr.dev/slog" "cdr.dev/slog"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc" drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
@ -48,6 +49,9 @@ func NewClient(t testing.TB,
} }
err := proto.DRPCRegisterTailnet(mux, drpcService) err := proto.DRPCRegisterTailnet(mux, drpcService)
require.NoError(t, err) require.NoError(t, err)
fakeAAPI := NewFakeAgentAPI(t, logger)
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{ server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) { Log: func(err error) {
if xerrors.Is(err, io.EOF) { if xerrors.Is(err, io.EOF) {
@ -64,22 +68,23 @@ func NewClient(t testing.TB,
statsChan: statsChan, statsChan: statsChan,
coordinator: coordinator, coordinator: coordinator,
server: server, server: server,
fakeAgentAPI: fakeAAPI,
derpMapUpdates: derpMapUpdates, derpMapUpdates: derpMapUpdates,
} }
} }
type Client struct { type Client struct {
t testing.TB t testing.TB
logger slog.Logger logger slog.Logger
agentID uuid.UUID agentID uuid.UUID
manifest agentsdk.Manifest manifest agentsdk.Manifest
metadata map[string]agentsdk.Metadata metadata map[string]agentsdk.Metadata
statsChan chan *agentsdk.Stats statsChan chan *agentsdk.Stats
coordinator tailnet.Coordinator coordinator tailnet.Coordinator
server *drpcserver.Server server *drpcserver.Server
LastWorkspaceAgent func() fakeAgentAPI *FakeAgentAPI
PatchWorkspaceLogs func() error LastWorkspaceAgent func()
GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) PatchWorkspaceLogs func() error
mu sync.Mutex // Protects following. mu sync.Mutex // Protects following.
lifecycleStates []codersdk.WorkspaceAgentLifecycle lifecycleStates []codersdk.WorkspaceAgentLifecycle
@ -221,20 +226,7 @@ func (c *Client) PatchLogs(ctx context.Context, logs agentsdk.PatchLogs) error {
} }
func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) { func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
c.mu.Lock() c.fakeAgentAPI.SetServiceBannerFunc(f)
defer c.mu.Unlock()
c.GetServiceBannerFunc = f
}
func (c *Client) GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.logger.Debug(ctx, "get service banner")
if c.GetServiceBannerFunc != nil {
return c.GetServiceBannerFunc()
}
return codersdk.ServiceBannerConfig{}, nil
} }
func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error { func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
@ -254,3 +246,73 @@ type closeFunc func() error
func (c closeFunc) Close() error { func (c closeFunc) Close() error {
return c() return c()
} }
type FakeAgentAPI struct {
sync.Mutex
t testing.TB
logger slog.Logger
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
}
func (*FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
// TODO implement me
panic("implement me")
}
func (f *FakeAgentAPI) SetServiceBannerFunc(fn func() (codersdk.ServiceBannerConfig, error)) {
f.Lock()
defer f.Unlock()
f.getServiceBannerFunc = fn
f.logger.Info(context.Background(), "updated ServiceBannerFunc")
}
func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
f.Lock()
defer f.Unlock()
if f.getServiceBannerFunc == nil {
return &agentproto.ServiceBanner{}, nil
}
sb, err := f.getServiceBannerFunc()
if err != nil {
return nil, err
}
return agentproto.ServiceBannerFromSDK(sb), nil
}
func (*FakeAgentAPI) UpdateStats(context.Context, *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
// TODO implement me
panic("implement me")
}
func (*FakeAgentAPI) UpdateLifecycle(context.Context, *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
// TODO implement me
panic("implement me")
}
func (*FakeAgentAPI) BatchUpdateAppHealths(context.Context, *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
// TODO implement me
panic("implement me")
}
func (*FakeAgentAPI) UpdateStartup(context.Context, *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
// TODO implement me
panic("implement me")
}
func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
// TODO implement me
panic("implement me")
}
func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
// TODO implement me
panic("implement me")
}
func NewFakeAgentAPI(t testing.TB, logger slog.Logger) *FakeAgentAPI {
return &FakeAgentAPI{
t: t,
logger: logger.Named("FakeAgentAPI"),
}
}

View File

@ -83,7 +83,8 @@ func TestWorkspaceAgent(t *testing.T) {
ctx := inv.Context() ctx := inv.Context()
clitest.Start(t, inv) clitest.Start(t, inv)
coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID) workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err) require.NoError(t, err)
resources := workspace.LatestBuild.Resources resources := workspace.LatestBuild.Resources
@ -120,7 +121,9 @@ func TestWorkspaceAgent(t *testing.T) {
clitest.Start(t, inv) clitest.Start(t, inv)
ctx := inv.Context() ctx := inv.Context()
coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID) workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err) require.NoError(t, err)
resources := workspace.LatestBuild.Resources resources := workspace.LatestBuild.Resources
@ -161,7 +164,9 @@ func TestWorkspaceAgent(t *testing.T) {
) )
ctx := inv.Context() ctx := inv.Context()
coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID) workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err) require.NoError(t, err)
resources := workspace.LatestBuild.Resources resources := workspace.LatestBuild.Resources
@ -212,7 +217,8 @@ func TestWorkspaceAgent(t *testing.T) {
clitest.Start(t, inv) clitest.Start(t, inv)
resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithSubsystems).Wait()
require.Len(t, resources, 1) require.Len(t, resources, 1)
require.Len(t, resources[0].Agents, 1) require.Len(t, resources[0].Agents, 1)
require.Len(t, resources[0].Agents[0].Subsystems, 2) require.Len(t, resources[0].Agents[0].Subsystems, 2)
@ -221,3 +227,29 @@ func TestWorkspaceAgent(t *testing.T) {
require.Equal(t, codersdk.AgentSubsystemExectrace, resources[0].Agents[0].Subsystems[1]) require.Equal(t, codersdk.AgentSubsystemExectrace, resources[0].Agents[0].Subsystems[1])
}) })
} }
func matchAgentWithVersion(rs []codersdk.WorkspaceResource) bool {
if len(rs) < 1 {
return false
}
if len(rs[0].Agents) < 1 {
return false
}
if rs[0].Agents[0].Version == "" {
return false
}
return true
}
func matchAgentWithSubsystems(rs []codersdk.WorkspaceResource) bool {
if len(rs) < 1 {
return false
}
if len(rs[0].Agents) < 1 {
return false
}
if len(rs[0].Agents[0].Subsystems) < 1 {
return false
}
return true
}

View File

@ -5,12 +5,12 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors" "golang.org/x/xerrors"
agentproto "github.com/coder/coder/v2/agent/proto" agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/appearance" "github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/stretchr/testify/require"
) )
func TestGetServiceBanner(t *testing.T) { func TestGetServiceBanner(t *testing.T) {

View File

@ -915,23 +915,67 @@ func AwaitWorkspaceBuildJobCompleted(t testing.TB, client *codersdk.Client, buil
// AwaitWorkspaceAgents waits for all resources with agents to be connected. If // AwaitWorkspaceAgents waits for all resources with agents to be connected. If
// specific agents are provided, it will wait for those agents to be connected // specific agents are provided, it will wait for those agents to be connected
// but will not fail if other agents are not connected. // but will not fail if other agents are not connected.
//
// Deprecated: Use NewWorkspaceAgentWaiter
func AwaitWorkspaceAgents(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID, agentNames ...string) []codersdk.WorkspaceResource { func AwaitWorkspaceAgents(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID, agentNames ...string) []codersdk.WorkspaceResource {
t.Helper() return NewWorkspaceAgentWaiter(t, client, workspaceID).AgentNames(agentNames).Wait()
}
agentNamesMap := make(map[string]struct{}, len(agentNames)) // WorkspaceAgentWaiter waits for all resources with agents to be connected. If
for _, name := range agentNames { // specific agents are provided using AgentNames(), it will wait for those agents
// to be connected but will not fail if other agents are not connected.
type WorkspaceAgentWaiter struct {
t testing.TB
client *codersdk.Client
workspaceID uuid.UUID
agentNames []string
resourcesMatcher func([]codersdk.WorkspaceResource) bool
}
// NewWorkspaceAgentWaiter returns an object that waits for agents to connect when
// you call Wait() on it.
func NewWorkspaceAgentWaiter(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID) WorkspaceAgentWaiter {
return WorkspaceAgentWaiter{
t: t,
client: client,
workspaceID: workspaceID,
}
}
// AgentNames instructs the waiter to wait for the given, named agents to be connected and will
// return even if other agents are not connected.
func (w WorkspaceAgentWaiter) AgentNames(names []string) WorkspaceAgentWaiter {
//nolint: revive // returns modified struct
w.agentNames = names
return w
}
// MatchResources instructs the waiter to wait until the workspace has resources that cause the
// provided matcher function to return true.
func (w WorkspaceAgentWaiter) MatchResources(m func([]codersdk.WorkspaceResource) bool) WorkspaceAgentWaiter {
//nolint: revive // returns modified struct
w.resourcesMatcher = m
return w
}
// Wait waits for the agent(s) to connect and fails the test if they do not within testutil.WaitLong
func (w WorkspaceAgentWaiter) Wait() []codersdk.WorkspaceResource {
w.t.Helper()
agentNamesMap := make(map[string]struct{}, len(w.agentNames))
for _, name := range w.agentNames {
agentNamesMap[name] = struct{}{} agentNamesMap[name] = struct{}{}
} }
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
t.Logf("waiting for workspace agents (workspace %s)", workspaceID) w.t.Logf("waiting for workspace agents (workspace %s)", w.workspaceID)
var resources []codersdk.WorkspaceResource var resources []codersdk.WorkspaceResource
require.Eventually(t, func() bool { require.Eventually(w.t, func() bool {
var err error var err error
workspace, err := client.Workspace(ctx, workspaceID) workspace, err := w.client.Workspace(ctx, w.workspaceID)
if !assert.NoError(t, err) { if !assert.NoError(w.t, err) {
return false return false
} }
if workspace.LatestBuild.Job.CompletedAt == nil { if workspace.LatestBuild.Job.CompletedAt == nil {
@ -943,23 +987,25 @@ func AwaitWorkspaceAgents(t testing.TB, client *codersdk.Client, workspaceID uui
for _, resource := range workspace.LatestBuild.Resources { for _, resource := range workspace.LatestBuild.Resources {
for _, agent := range resource.Agents { for _, agent := range resource.Agents {
if len(agentNames) > 0 { if len(w.agentNames) > 0 {
if _, ok := agentNamesMap[agent.Name]; !ok { if _, ok := agentNamesMap[agent.Name]; !ok {
continue continue
} }
} }
if agent.Status != codersdk.WorkspaceAgentConnected { if agent.Status != codersdk.WorkspaceAgentConnected {
t.Logf("agent %s not connected yet", agent.Name) w.t.Logf("agent %s not connected yet", agent.Name)
return false return false
} }
} }
} }
resources = workspace.LatestBuild.Resources resources = workspace.LatestBuild.Resources
if w.resourcesMatcher == nil {
return true return true
}
return w.resourcesMatcher(resources)
}, testutil.WaitLong, testutil.IntervalMedium) }, testutil.WaitLong, testutil.IntervalMedium)
t.Logf("got workspace agents (workspace %s)", workspaceID) w.t.Logf("got workspace agents (workspace %s)", w.workspaceID)
return resources return resources
} }

View File

@ -29,6 +29,8 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/wsconncache" "github.com/coder/coder/v2/coderd/wsconncache"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
@ -171,13 +173,12 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
_ = coordinator.Close() _ = coordinator.Close()
}) })
manifest.AgentID = uuid.New() manifest.AgentID = uuid.New()
aC := &client{ aC := newClient(
t: t, t,
agentID: manifest.AgentID, slogtest.Make(t, nil).Leveled(slog.LevelDebug),
manifest: manifest, manifest,
coordinator: coordinator, coordinator,
derpMapUpdates: make(chan *tailcfg.DERPMap), )
}
t.Cleanup(aC.close) t.Cleanup(aC.close)
closer := agent.New(agent.Options{ closer := agent.New(agent.Options{
Client: aC, Client: aC,
@ -239,6 +240,45 @@ type client struct {
coordinator tailnet.Coordinator coordinator tailnet.Coordinator
closeOnce sync.Once closeOnce sync.Once
derpMapUpdates chan *tailcfg.DERPMap derpMapUpdates chan *tailcfg.DERPMap
server *drpcserver.Server
fakeAgentAPI *agenttest.FakeAgentAPI
}
func newClient(t *testing.T, logger slog.Logger, manifest agentsdk.Manifest, coordinator tailnet.Coordinator) *client {
logger = logger.Named("drpc")
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coordinator)
mux := drpcmux.New()
derpMapUpdates := make(chan *tailcfg.DERPMap)
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
DerpMapUpdateFrequency: time.Microsecond,
DerpMapFn: func() *tailcfg.DERPMap { return <-derpMapUpdates },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
require.NoError(t, err)
fakeAAPI := agenttest.NewFakeAgentAPI(t, logger)
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
return &client{
t: t,
agentID: manifest.AgentID,
manifest: manifest,
coordinator: coordinator,
derpMapUpdates: derpMapUpdates,
server: server,
fakeAgentAPI: fakeAAPI,
}
} }
func (c *client) close() { func (c *client) close() {
@ -250,35 +290,12 @@ func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
} }
func (c *client) Listen(_ context.Context) (drpc.Conn, error) { func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc")
conn, lis := drpcsdk.MemTransportPipe() conn, lis := drpcsdk.MemTransportPipe()
c.t.Cleanup(func() { c.t.Cleanup(func() {
_ = conn.Close() _ = conn.Close()
_ = lis.Close() _ = lis.Close()
}) })
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&c.coordinator)
mux := drpcmux.New()
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
DerpMapUpdateFrequency: time.Microsecond,
DerpMapFn: func() *tailcfg.DERPMap { return <-c.derpMapUpdates },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
if err != nil {
return nil, xerrors.Errorf("register DRPC service: %w", err)
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
serveCtx, cancel := context.WithCancel(context.Background()) serveCtx, cancel := context.WithCancel(context.Background())
c.t.Cleanup(cancel) c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{} auth := tailnet.AgentTunnelAuth{}
@ -289,7 +306,7 @@ func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
} }
serveCtx = tailnet.WithStreamID(serveCtx, streamID) serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() { go func() {
server.Serve(serveCtx, lis) c.server.Serve(serveCtx, lis)
}() }()
return conn, nil return conn, nil
} }
@ -317,7 +334,3 @@ func (*client) PostStartup(_ context.Context, _ agentsdk.PostStartupRequest) err
func (*client) PatchLogs(_ context.Context, _ agentsdk.PatchLogs) error { func (*client) PatchLogs(_ context.Context, _ agentsdk.PatchLogs) error {
return nil return nil
} }
func (*client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
return codersdk.ServiceBannerConfig{}, nil
}

View File

@ -637,24 +637,6 @@ func (c *Client) PostLogSource(ctx context.Context, req PostLogSource) (codersdk
return logSource, json.NewDecoder(res.Body).Decode(&logSource) return logSource, json.NewDecoder(res.Body).Decode(&logSource)
} }
// GetServiceBanner relays the service banner config.
func (c *Client) GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error) {
res, err := c.SDK.Request(ctx, http.MethodGet, "/api/v2/appearance", nil)
if err != nil {
return codersdk.ServiceBannerConfig{}, err
}
defer res.Body.Close()
// If the route does not exist then Enterprise code is not enabled.
if res.StatusCode == http.StatusNotFound {
return codersdk.ServiceBannerConfig{}, nil
}
if res.StatusCode != http.StatusOK {
return codersdk.ServiceBannerConfig{}, codersdk.ReadBodyAsError(res)
}
var cfg codersdk.AppearanceConfig
return cfg.ServiceBanner, json.NewDecoder(res.Body).Decode(&cfg)
}
type ExternalAuthResponse struct { type ExternalAuthResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenExtra map[string]interface{} `json:"token_extra"` TokenExtra map[string]interface{} `json:"token_extra"`

View File

@ -157,34 +157,25 @@ func TestServiceBanners(t *testing.T) {
agentClient := agentsdk.New(client.URL) agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(r.AgentToken) agentClient.SetSessionToken(r.AgentToken)
banner, err := agentClient.GetServiceBanner(ctx) banner := requireGetServiceBanner(ctx, t, agentClient)
require.NoError(t, err)
require.Equal(t, cfg.ServiceBanner, banner)
banner = requireGetServiceBannerV2(ctx, t, agentClient)
require.Equal(t, cfg.ServiceBanner, banner) require.Equal(t, cfg.ServiceBanner, banner)
// Create an AGPL Coderd against the same database // Create an AGPL Coderd against the same database
agplClient := coderdtest.New(t, &coderdtest.Options{Database: store, Pubsub: ps}) agplClient := coderdtest.New(t, &coderdtest.Options{Database: store, Pubsub: ps})
agplAgentClient := agentsdk.New(agplClient.URL) agplAgentClient := agentsdk.New(agplClient.URL)
agplAgentClient.SetSessionToken(r.AgentToken) agplAgentClient.SetSessionToken(r.AgentToken)
banner, err = agplAgentClient.GetServiceBanner(ctx) banner = requireGetServiceBanner(ctx, t, agplAgentClient)
require.NoError(t, err)
require.Equal(t, codersdk.ServiceBannerConfig{}, banner)
banner = requireGetServiceBannerV2(ctx, t, agplAgentClient)
require.Equal(t, codersdk.ServiceBannerConfig{}, banner) require.Equal(t, codersdk.ServiceBannerConfig{}, banner)
// No license means no banner. // No license means no banner.
err = client.DeleteLicense(ctx, lic.ID) err = client.DeleteLicense(ctx, lic.ID)
require.NoError(t, err) require.NoError(t, err)
banner, err = agentClient.GetServiceBanner(ctx) banner = requireGetServiceBanner(ctx, t, agentClient)
require.NoError(t, err)
require.Equal(t, codersdk.ServiceBannerConfig{}, banner)
banner = requireGetServiceBannerV2(ctx, t, agentClient)
require.Equal(t, codersdk.ServiceBannerConfig{}, banner) require.Equal(t, codersdk.ServiceBannerConfig{}, banner)
}) })
} }
func requireGetServiceBannerV2(ctx context.Context, t *testing.T, client *agentsdk.Client) codersdk.ServiceBannerConfig { func requireGetServiceBanner(ctx context.Context, t *testing.T, client *agentsdk.Client) codersdk.ServiceBannerConfig {
cc, err := client.Listen(ctx) cc, err := client.Listen(ctx)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {

2
go.mod
View File

@ -73,7 +73,7 @@ replace github.com/imulab/go-scim/pkg/v2 => github.com/coder/go-scim/pkg/v2 v2.0
replace github.com/pkg/sftp => github.com/mafredri/sftp v1.13.6-0.20231212144145-8218e927edb0 replace github.com/pkg/sftp => github.com/mafredri/sftp v1.13.6-0.20231212144145-8218e927edb0
require ( require (
cdr.dev/slog v1.6.2-0.20230929193652-f0c466fabe10 cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6
cloud.google.com/go/compute/metadata v0.2.3 cloud.google.com/go/compute/metadata v0.2.3
github.com/AlecAivazis/survey/v2 v2.3.5 github.com/AlecAivazis/survey/v2 v2.3.5
github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d

4
go.sum
View File

@ -1,5 +1,5 @@
cdr.dev/slog v1.6.2-0.20230929193652-f0c466fabe10 h1:gnB1By6Hzs2PVQXyi/cvo6L3kHPb8utLuzycWHfCztQ= cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 h1:KHblWIE/KHOwQ6lEbMZt6YpcGve2FEZ1sDtrW1Am5UI=
cdr.dev/slog v1.6.2-0.20230929193652-f0c466fabe10/go.mod h1:NaoTA7KwopCrnaSb0JXTC0PTp/O/Y83Lndnq0OEV3ZQ= cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6/go.mod h1:NaoTA7KwopCrnaSb0JXTC0PTp/O/Y83Lndnq0OEV3ZQ=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiVlk= cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiVlk=
cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI= cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI=