From 198b56c1379f96a738092564b6ba01b7f1d0fb88 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 16 Nov 2023 17:03:53 +0200 Subject: [PATCH] fix(coderd): fix memory leak in `watchWorkspaceAgentMetadata` (#10685) Fixes #10550 --- coderd/workspaceagents.go | 125 +++++++++++++++++--------- coderd/workspaceagents_test.go | 159 +++++++++++++++++++++++++++++++++ codersdk/workspaceagents.go | 5 ++ testutil/go.go | 23 +++++ 4 files changed, 269 insertions(+), 43 deletions(-) create mode 100644 testutil/go.go diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 82a40f699d..11c9279e36 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1785,10 +1785,10 @@ func (api *API) workspaceAgentUpdateMetadata(ctx context.Context, workspaceAgent datum := database.UpdateWorkspaceAgentMetadataParams{ WorkspaceAgentID: workspaceAgent.ID, - Key: []string{}, - Value: []string{}, - Error: []string{}, - CollectedAt: []time.Time{}, + Key: make([]string, 0, len(req.Metadata)), + Value: make([]string, 0, len(req.Metadata)), + Error: make([]string, 0, len(req.Metadata)), + CollectedAt: make([]time.Time, 0, len(req.Metadata)), } for _, md := range req.Metadata { @@ -1853,22 +1853,28 @@ func (api *API) workspaceAgentUpdateMetadata(ctx context.Context, workspaceAgent // @Router /workspaceagents/{workspaceagent}/watch-metadata [get] // @x-apidocgen {"skip": true} func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - workspaceAgent = httpmw.WorkspaceAgentParam(r) - log = api.Logger.Named("workspace_metadata_watcher").With( - slog.F("workspace_agent_id", workspaceAgent.ID), - ) + // Allow us to interrupt watch via cancel. + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + r = r.WithContext(ctx) // Rewire context for SSE cancellation. + + workspaceAgent := httpmw.WorkspaceAgentParam(r) + log := api.Logger.Named("workspace_metadata_watcher").With( + slog.F("workspace_agent_id", workspaceAgent.ID), ) // Send metadata on updates, we must ensure subscription before sending // initial metadata to guarantee that events in-between are not missed. update := make(chan workspaceAgentMetadataChannelPayload, 1) cancelSub, err := api.Pubsub.Subscribe(watchWorkspaceAgentMetadataChannel(workspaceAgent.ID), func(_ context.Context, byt []byte) { + if ctx.Err() != nil { + return + } + var payload workspaceAgentMetadataChannelPayload err := json.Unmarshal(byt, &payload) if err != nil { - api.Logger.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err)) + log.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err)) return } @@ -1876,18 +1882,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ select { case prev := <-update: - // This update wasn't consumed yet, merge the keys. - newKeysSet := make(map[string]struct{}) - for _, key := range payload.Keys { - newKeysSet[key] = struct{}{} - } - keys := prev.Keys - for _, key := range prev.Keys { - if _, ok := newKeysSet[key]; !ok { - keys = append(keys, key) - } - } - payload.Keys = keys + payload.Keys = appendUnique(prev.Keys, payload.Keys) default: } // This can never block since we pop and merge beforehand. @@ -1899,22 +1894,9 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ } defer cancelSub() - sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error setting up server-sent events.", - Detail: err.Error(), - }) - return - } - // Prevent handler from returning until the sender is closed. - defer func() { - <-sseSenderClosed - }() - // We always use the original Request context because it contains // the RBAC actor. - md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{ + initialMD, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{ WorkspaceAgentID: workspaceAgent.ID, Keys: nil, }) @@ -1926,15 +1908,45 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ return } - metadataMap := make(map[string]database.WorkspaceAgentMetadatum) - for _, datum := range md { + log.Debug(ctx, "got initial metadata", "num", len(initialMD)) + + metadataMap := make(map[string]database.WorkspaceAgentMetadatum, len(initialMD)) + for _, datum := range initialMD { metadataMap[datum.Key] = datum } + //nolint:ineffassign // Release memory. + initialMD = nil + + sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error setting up server-sent events.", + Detail: err.Error(), + }) + return + } + // Prevent handler from returning until the sender is closed. + defer func() { + cancel() + <-sseSenderClosed + }() + // Synchronize cancellation from SSE -> context, this lets us simplify the + // cancellation logic. + go func() { + select { + case <-ctx.Done(): + case <-sseSenderClosed: + cancel() + } + }() var lastSend time.Time sendMetadata := func() { lastSend = time.Now() values := maps.Values(metadataMap) + + log.Debug(ctx, "sending metadata", "num", len(values)) + _ = sseSendEvent(ctx, codersdk.ServerSentEvent{ Type: codersdk.ServerSentEventTypeData, Data: convertWorkspaceAgentMetadata(values), @@ -1953,10 +1965,11 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ fetchedMetadata := make(chan []database.WorkspaceAgentMetadatum) go func() { defer close(fetchedMetadata) + defer cancel() for { select { - case <-sseSenderClosed: + case <-ctx.Done(): return case payload := <-update: md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{ @@ -1966,24 +1979,35 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ if err != nil { if !errors.Is(err, context.Canceled) { log.Error(ctx, "failed to get metadata", slog.Error(err)) + _ = sseSendEvent(ctx, codersdk.ServerSentEvent{ + Type: codersdk.ServerSentEventTypeError, + Data: codersdk.Response{ + Message: "Failed to get metadata.", + Detail: err.Error(), + }, + }) } return } select { - case <-sseSenderClosed: + case <-ctx.Done(): return // We want to block here to avoid constantly pinging the // database when the metadata isn't being processed. case fetchedMetadata <- md: + log.Debug(ctx, "fetched metadata update for keys", "keys", payload.Keys, "num", len(md)) } } } }() + defer func() { + <-fetchedMetadata + }() pendingChanges := true for { select { - case <-sseSenderClosed: + case <-ctx.Done(): return case md, ok := <-fetchedMetadata: if !ok { @@ -2007,9 +2031,24 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ } } +// appendUnique is like append and adds elements from src to dst, +// skipping any elements that already exist in dst. +func appendUnique[T comparable](dst, src []T) []T { + exists := make(map[T]struct{}, len(dst)) + for _, key := range dst { + exists[key] = struct{}{} + } + for _, key := range src { + if _, ok := exists[key]; !ok { + dst = append(dst, key) + } + } + return dst +} + func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []codersdk.WorkspaceAgentMetadata { // An empty array is easier for clients to handle than a null. - result := []codersdk.WorkspaceAgentMetadata{} + result := make([]codersdk.WorkspaceAgentMetadata, 0, len(db)) for _, datum := range db { result = append(result, codersdk.WorkspaceAgentMetadata{ Result: codersdk.WorkspaceAgentMetadataResult{ diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 1e56616b11..e0c4c7f45e 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "tailscale.com/tailcfg" "cdr.dev/slog" @@ -25,7 +26,9 @@ import ( "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/provisioner/echo" @@ -1107,6 +1110,162 @@ func TestWorkspaceAgent_Metadata(t *testing.T) { post("unknown", unknownKeyMetadata) } +type testWAMErrorStore struct { + database.Store + err atomic.Pointer[error] +} + +func (s *testWAMErrorStore) GetWorkspaceAgentMetadata(ctx context.Context, arg database.GetWorkspaceAgentMetadataParams) ([]database.WorkspaceAgentMetadatum, error) { + err := s.err.Load() + if err != nil { + return nil, *err + } + return s.Store.GetWorkspaceAgentMetadata(ctx, arg) +} + +func TestWorkspaceAgent_Metadata_CatchMemoryLeak(t *testing.T) { + t.Parallel() + + db := &testWAMErrorStore{Store: dbmem.New()} + psub := pubsub.NewInMemory() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("coderd").Leveled(slog.LevelDebug) + client := coderdtest.New(t, &coderdtest.Options{ + Database: db, + Pubsub: psub, + IncludeProvisionerDaemon: true, + Logger: &logger, + }) + user := coderdtest.CreateFirstUser(t, client) + authToken := uuid.NewString() + ws := dbfake.Workspace(t, db, database.Workspace{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }) + dbfake.WorkspaceBuild(t, db, ws, database.WorkspaceBuild{}, &proto.Resource{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Metadata: []*proto.Agent_Metadata{ + { + DisplayName: "First Meta", + Key: "foo1", + Script: "echo hi", + Interval: 10, + Timeout: 3, + }, + { + DisplayName: "Second Meta", + Key: "foo2", + Script: "echo bye", + Interval: 10, + Timeout: 3, + }, + }, + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, + }) + workspace, err := client.Workspace(context.Background(), ws.ID) + require.NoError(t, err) + for _, res := range workspace.LatestBuild.Resources { + for _, a := range res.Agents { + require.Equal(t, codersdk.WorkspaceAgentLifecycleCreated, a.LifecycleState) + } + } + + agentClient := agentsdk.New(client.URL) + agentClient.SetSessionToken(authToken) + + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitSuperLong)) + + manifest, err := agentClient.Manifest(ctx) + require.NoError(t, err) + + post := func(ctx context.Context, key, value string) error { + return agentClient.PostMetadata(ctx, agentsdk.PostMetadataRequest{ + Metadata: []agentsdk.Metadata{ + { + Key: key, + WorkspaceAgentMetadataResult: codersdk.WorkspaceAgentMetadataResult{ + CollectedAt: time.Now(), + Value: value, + }, + }, + }, + }) + } + + workspace, err = client.Workspace(ctx, workspace.ID) + require.NoError(t, err, "get workspace") + + // Start the SSE connection. + metadata, errors := client.WatchWorkspaceAgentMetadata(ctx, manifest.AgentID) + + // Discard the output, pretending to be a client consuming it. + wantErr := xerrors.New("test error") + metadataDone := testutil.Go(t, func() { + for { + select { + case <-ctx.Done(): + return + case _, ok := <-metadata: + if !ok { + return + } + case err := <-errors: + if err != nil && !strings.Contains(err.Error(), wantErr.Error()) { + assert.NoError(t, err, "watch metadata") + } + return + } + } + }) + + postDone := testutil.Go(t, func() { + for { + // We need to send two separate metadata updates to trigger the + // memory leak. foo2 will cause the number of foo1 to be doubled, etc. + err = post(ctx, "foo1", "hi") + if err != nil { + if !xerrors.Is(err, context.Canceled) { + assert.NoError(t, err, "post metadata foo1") + } + return + } + err = post(ctx, "foo2", "bye") + if err != nil { + if !xerrors.Is(err, context.Canceled) { + assert.NoError(t, err, "post metadata foo1") + } + return + } + } + }) + + // In a previously faulty implementation, this database error will trigger + // a close of the goroutine that consumes metadata updates for refreshing + // the metadata sent over SSE. As it was, the exit of the consumer was not + // detected as a trigger to close down the connection. + // + // Further, there was a memory leak in the pubsub subscription that cause + // ballooning of memory (almost double in size every received metadata). + // + // This db error should trigger a close of the SSE connection in the fixed + // implementation. The memory leak should not happen in either case, but + // testing it is not straightforward. + db.err.Store(&wantErr) + + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for SSE to close") + case <-metadataDone: + } + cancel() + <-postDone +} + func TestWorkspaceAgent_Startup(t *testing.T) { t.Parallel() diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 1f098a77f6..e020fd579a 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -492,6 +492,11 @@ func (c *Client) WatchWorkspaceAgentMetadata(ctx context.Context, id uuid.UUID) firstEvent = false } + // Ignore pings. + if sse.Type == ServerSentEventTypePing { + continue + } + b, ok := sse.Data.([]byte) if !ok { return xerrors.Errorf("unexpected data type: %T", sse.Data) diff --git a/testutil/go.go b/testutil/go.go new file mode 100644 index 0000000000..7eddac7e87 --- /dev/null +++ b/testutil/go.go @@ -0,0 +1,23 @@ +package testutil + +import ( + "testing" +) + +// Go runs fn in a goroutine and waits until fn has completed before +// test completion. Done is returned for optionally waiting for fn to +// exit. +func Go(t *testing.T, fn func()) (done <-chan struct{}) { + t.Helper() + + doneC := make(chan struct{}) + t.Cleanup(func() { + <-doneC + }) + go func() { + fn() + close(doneC) + }() + + return doneC +}