feat: use agent v2 API to post startup (#11877)

Uses the v2 Agent API to post startup information.
This commit is contained in:
Spike Curtis 2024-01-30 11:23:28 +04:00 committed by GitHub
parent da8bb1c198
commit 2599850e54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 101 additions and 71 deletions

View File

@ -92,7 +92,6 @@ type Client interface {
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
RewriteDERPMap(derpMap *tailcfg.DERPMap)
@ -737,13 +736,18 @@ func (a *agent) run(ctx context.Context) error {
if err != nil {
return xerrors.Errorf("expand directory: %w", err)
}
err = a.client.PostStartup(ctx, agentsdk.PostStartupRequest{
subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems)
if err != nil {
a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err))
return xerrors.Errorf("failed to convert subsystems: %w", err)
}
_, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{
Version: buildinfo.Version(),
ExpandedDirectory: manifest.Directory,
Subsystems: a.subsystems,
})
Subsystems: subsys,
}})
if err != nil {
return xerrors.Errorf("update workspace agent version: %w", err)
return xerrors.Errorf("update workspace agent startup: %w", err)
}
oldManifest := a.manifest.Swap(&manifest)

View File

@ -1394,56 +1394,52 @@ func TestAgent_Startup(t *testing.T) {
t.Run("EmptyDirectory", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Directory: "",
}, 0)
assert.Eventually(t, func() bool {
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, "", client.GetStartup().ExpandedDirectory)
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
require.Equal(t, "", startup.GetExpandedDirectory())
})
t.Run("HomeDirectory", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Directory: "~",
}, 0)
assert.Eventually(t, func() bool {
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
require.Equal(t, homeDir, startup.GetExpandedDirectory())
})
t.Run("NotAbsoluteDirectory", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Directory: "coder/coder",
}, 0)
assert.Eventually(t, func() bool {
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory)
require.Equal(t, filepath.Join(homeDir, "coder/coder"), startup.GetExpandedDirectory())
})
t.Run("HomeEnvironmentVariable", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Directory: "$HOME",
}, 0)
assert.Eventually(t, func() bool {
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
require.Equal(t, homeDir, startup.GetExpandedDirectory())
})
}

View File

@ -88,7 +88,6 @@ type Client struct {
mu sync.Mutex // Protects following.
lifecycleStates []codersdk.WorkspaceAgentLifecycle
startup agentsdk.PostStartupRequest
logs []agentsdk.Log
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
@ -173,10 +172,8 @@ func (c *Client) PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsR
return nil
}
func (c *Client) GetStartup() agentsdk.PostStartupRequest {
c.mu.Lock()
defer c.mu.Unlock()
return c.startup
func (c *Client) GetStartup() <-chan *agentproto.Startup {
return c.fakeAgentAPI.startupCh
}
func (c *Client) GetMetadata() map[string]agentsdk.Metadata {
@ -198,14 +195,6 @@ func (c *Client) PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequ
return nil
}
func (c *Client) PostStartup(ctx context.Context, startup agentsdk.PostStartupRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.startup = startup
c.logger.Debug(ctx, "post startup", slog.F("req", startup))
return nil
}
func (c *Client) GetStartupLogs() []agentsdk.Log {
c.mu.Lock()
defer c.mu.Unlock()
@ -250,7 +239,8 @@ type FakeAgentAPI struct {
t testing.TB
logger slog.Logger
manifest *agentproto.Manifest
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
}
@ -294,9 +284,9 @@ func (*FakeAgentAPI) BatchUpdateAppHealths(context.Context, *agentproto.BatchUpd
panic("implement me")
}
func (*FakeAgentAPI) UpdateStartup(context.Context, *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
// TODO implement me
panic("implement me")
func (f *FakeAgentAPI) UpdateStartup(_ context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
f.startupCh <- req.GetStartup()
return req.GetStartup(), nil
}
func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
@ -311,8 +301,9 @@ func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLog
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI {
return &FakeAgentAPI{
t: t,
logger: logger.Named("FakeAgentAPI"),
manifest: manifest,
t: t,
logger: logger.Named("FakeAgentAPI"),
manifest: manifest,
startupCh: make(chan *agentproto.Startup, 100),
}
}

View File

@ -29,8 +29,6 @@ import (
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
)
const AgentAPIVersionDRPC = "2.0"
// API implements the DRPC agent API interface from agent/proto. This struct is
// instantiated once per agent connection and kept alive for the duration of the
// session.

View File

@ -6,6 +6,7 @@ import (
"time"
"github.com/google/uuid"
"golang.org/x/exp/slices"
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
@ -16,6 +17,12 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
)
type contextKeyAPIVersion struct{}
func WithAPIVersion(ctx context.Context, version string) context.Context {
return context.WithValue(ctx, contextKeyAPIVersion{}, version)
}
type LifecycleAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
@ -123,6 +130,10 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
}
func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
apiVersion, ok := ctx.Value(contextKeyAPIVersion{}).(string)
if !ok {
return nil, xerrors.Errorf("internal error; api version unspecified")
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
@ -164,13 +175,14 @@ func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.Update
dbSubsystems = append(dbSubsystems, dbSubsystem)
}
}
slices.Sort(dbSubsystems)
err = a.Database.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{
ID: workspaceAgent.ID,
Version: req.Startup.Version,
ExpandedDirectory: req.Startup.ExpandedDirectory,
Subsystems: dbSubsystems,
APIVersion: AgentAPIVersionDRPC,
APIVersion: apiVersion,
})
if err != nil {
return nil, xerrors.Errorf("update workspace agent startup in database: %w", err)

View File

@ -382,10 +382,11 @@ func TestUpdateStartup(t *testing.T) {
database.WorkspaceAgentSubsystemEnvbuilder,
database.WorkspaceAgentSubsystemExectrace,
},
APIVersion: agentapi.AgentAPIVersionDRPC,
APIVersion: "2.0",
}).Return(nil)
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.NoError(t, err)
@ -416,7 +417,8 @@ func TestUpdateStartup(t *testing.T) {
Subsystems: []agentproto.Startup_Subsystem{},
}
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.Error(t, err)
@ -451,7 +453,8 @@ func TestUpdateStartup(t *testing.T) {
},
}
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.Error(t, err)

View File

@ -24,7 +24,6 @@ import (
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
@ -1389,13 +1388,13 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
}
)
err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{
err := postStartup(ctx, t, agentClient, &agentproto.Startup{
Version: expectedVersion,
ExpandedDirectory: expectedDir,
Subsystems: []codersdk.AgentSubsystem{
Subsystems: []agentproto.Startup_Subsystem{
// Not sorted.
expectedSubsystems[1],
expectedSubsystems[0],
agentproto.Startup_EXECTRACE,
agentproto.Startup_ENVBOX,
},
})
require.NoError(t, err)
@ -1409,7 +1408,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
require.Equal(t, expectedDir, wsagent.ExpandedDirectory)
// Sorted
require.Equal(t, expectedSubsystems, wsagent.Subsystems)
require.Equal(t, coderd.AgentAPIVersionREST, wsagent.APIVersion)
require.Equal(t, agentproto.CurrentVersion.String(), wsagent.APIVersion)
})
t.Run("InvalidSemver", func(t *testing.T) {
@ -1427,13 +1426,10 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitMedium)
err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{
err := postStartup(ctx, t, agentClient, &agentproto.Startup{
Version: "1.2.3",
})
require.Error(t, err)
cerr, ok := codersdk.AsError(err)
require.True(t, ok)
require.Equal(t, http.StatusBadRequest, cerr.StatusCode())
require.ErrorContains(t, err, "invalid agent semver version")
})
}
@ -1640,3 +1636,15 @@ func requireGetManifest(ctx context.Context, t testing.TB, client agent.Client)
require.NoError(t, err)
return manifest
}
func postStartup(ctx context.Context, t testing.TB, client agent.Client, startup *agentproto.Startup) error {
conn, err := client.Listen(ctx)
require.NoError(t, err)
defer func() {
cErr := conn.Close()
require.NoError(t, cErr)
}()
aAPI := agentproto.NewDRPCAgentClient(conn)
_, err = aAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{Startup: startup})
return err
}

View File

@ -154,6 +154,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
Auth: tailnet.AgentTunnelAuth{},
}
ctx = tailnet.WithStreamID(ctx, streamID)
ctx = agentapi.WithAPIVersion(ctx, version)
err = agentAPI.Serve(ctx, mux)
if err != nil {
api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err))

View File

@ -556,18 +556,6 @@ type PostStartupRequest struct {
Subsystems []codersdk.AgentSubsystem `json:"subsystems"`
}
func (c *Client) PostStartup(ctx context.Context, req PostStartupRequest) error {
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/startup", req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return codersdk.ReadBodyAsError(res)
}
return nil
}
type Log struct {
CreatedAt time.Time `json:"created_at"`
Output string `json:"output"`

View File

@ -266,3 +266,15 @@ func ProtoFromServiceBanner(sb codersdk.ServiceBannerConfig) *proto.ServiceBanne
BackgroundColor: sb.BackgroundColor,
}
}
func ProtoFromSubsystems(ss []codersdk.AgentSubsystem) ([]proto.Startup_Subsystem, error) {
ret := make([]proto.Startup_Subsystem, len(ss))
for i, s := range ss {
pi, ok := proto.Startup_Subsystem_value[strings.ToUpper(string(s))]
if !ok {
return nil, xerrors.Errorf("unknown subsystem: %s", s)
}
ret[i] = proto.Startup_Subsystem(pi)
}
return ret, nil
}

View File

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
@ -144,3 +145,19 @@ func TestManifest(t *testing.T) {
require.Equal(t, manifest.Metadata, back.Metadata)
require.Equal(t, manifest.Scripts, back.Scripts)
}
func TestSubsystems(t *testing.T) {
t.Parallel()
ss := []codersdk.AgentSubsystem{
codersdk.AgentSubsystemEnvbox,
codersdk.AgentSubsystemEnvbuilder,
codersdk.AgentSubsystemExectrace,
}
ps, err := agentsdk.ProtoFromSubsystems(ss)
require.NoError(t, err)
require.Equal(t, ps, []proto.Startup_Subsystem{
proto.Startup_ENVBOX,
proto.Startup_ENVBUILDER,
proto.Startup_EXECTRACE,
})
}