fix(coderd): fix memory leak in `watchWorkspaceAgentMetadata` (#10685)

Fixes #10550
This commit is contained in:
Mathias Fredriksson 2023-11-16 17:03:53 +02:00 committed by GitHub
parent c130f8d6d0
commit 198b56c137
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 269 additions and 43 deletions

View File

@ -1785,10 +1785,10 @@ func (api *API) workspaceAgentUpdateMetadata(ctx context.Context, workspaceAgent
datum := database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: workspaceAgent.ID,
Key: []string{},
Value: []string{},
Error: []string{},
CollectedAt: []time.Time{},
Key: make([]string, 0, len(req.Metadata)),
Value: make([]string, 0, len(req.Metadata)),
Error: make([]string, 0, len(req.Metadata)),
CollectedAt: make([]time.Time, 0, len(req.Metadata)),
}
for _, md := range req.Metadata {
@ -1853,22 +1853,28 @@ func (api *API) workspaceAgentUpdateMetadata(ctx context.Context, workspaceAgent
// @Router /workspaceagents/{workspaceagent}/watch-metadata [get]
// @x-apidocgen {"skip": true}
func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
workspaceAgent = httpmw.WorkspaceAgentParam(r)
log = api.Logger.Named("workspace_metadata_watcher").With(
slog.F("workspace_agent_id", workspaceAgent.ID),
)
// Allow us to interrupt watch via cancel.
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
r = r.WithContext(ctx) // Rewire context for SSE cancellation.
workspaceAgent := httpmw.WorkspaceAgentParam(r)
log := api.Logger.Named("workspace_metadata_watcher").With(
slog.F("workspace_agent_id", workspaceAgent.ID),
)
// Send metadata on updates, we must ensure subscription before sending
// initial metadata to guarantee that events in-between are not missed.
update := make(chan workspaceAgentMetadataChannelPayload, 1)
cancelSub, err := api.Pubsub.Subscribe(watchWorkspaceAgentMetadataChannel(workspaceAgent.ID), func(_ context.Context, byt []byte) {
if ctx.Err() != nil {
return
}
var payload workspaceAgentMetadataChannelPayload
err := json.Unmarshal(byt, &payload)
if err != nil {
api.Logger.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err))
log.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err))
return
}
@ -1876,18 +1882,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
select {
case prev := <-update:
// This update wasn't consumed yet, merge the keys.
newKeysSet := make(map[string]struct{})
for _, key := range payload.Keys {
newKeysSet[key] = struct{}{}
}
keys := prev.Keys
for _, key := range prev.Keys {
if _, ok := newKeysSet[key]; !ok {
keys = append(keys, key)
}
}
payload.Keys = keys
payload.Keys = appendUnique(prev.Keys, payload.Keys)
default:
}
// This can never block since we pop and merge beforehand.
@ -1899,22 +1894,9 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
}
defer cancelSub()
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Prevent handler from returning until the sender is closed.
defer func() {
<-sseSenderClosed
}()
// We always use the original Request context because it contains
// the RBAC actor.
md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
initialMD, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
WorkspaceAgentID: workspaceAgent.ID,
Keys: nil,
})
@ -1926,15 +1908,45 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
return
}
metadataMap := make(map[string]database.WorkspaceAgentMetadatum)
for _, datum := range md {
log.Debug(ctx, "got initial metadata", "num", len(initialMD))
metadataMap := make(map[string]database.WorkspaceAgentMetadatum, len(initialMD))
for _, datum := range initialMD {
metadataMap[datum.Key] = datum
}
//nolint:ineffassign // Release memory.
initialMD = nil
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Prevent handler from returning until the sender is closed.
defer func() {
cancel()
<-sseSenderClosed
}()
// Synchronize cancellation from SSE -> context, this lets us simplify the
// cancellation logic.
go func() {
select {
case <-ctx.Done():
case <-sseSenderClosed:
cancel()
}
}()
var lastSend time.Time
sendMetadata := func() {
lastSend = time.Now()
values := maps.Values(metadataMap)
log.Debug(ctx, "sending metadata", "num", len(values))
_ = sseSendEvent(ctx, codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: convertWorkspaceAgentMetadata(values),
@ -1953,10 +1965,11 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
fetchedMetadata := make(chan []database.WorkspaceAgentMetadatum)
go func() {
defer close(fetchedMetadata)
defer cancel()
for {
select {
case <-sseSenderClosed:
case <-ctx.Done():
return
case payload := <-update:
md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
@ -1966,24 +1979,35 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Error(ctx, "failed to get metadata", slog.Error(err))
_ = sseSendEvent(ctx, codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Failed to get metadata.",
Detail: err.Error(),
},
})
}
return
}
select {
case <-sseSenderClosed:
case <-ctx.Done():
return
// We want to block here to avoid constantly pinging the
// database when the metadata isn't being processed.
case fetchedMetadata <- md:
log.Debug(ctx, "fetched metadata update for keys", "keys", payload.Keys, "num", len(md))
}
}
}
}()
defer func() {
<-fetchedMetadata
}()
pendingChanges := true
for {
select {
case <-sseSenderClosed:
case <-ctx.Done():
return
case md, ok := <-fetchedMetadata:
if !ok {
@ -2007,9 +2031,24 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
}
}
// appendUnique is like append and adds elements from src to dst,
// skipping any elements that already exist in dst. An element that
// occurs more than once in src is appended only once. Elements already
// present in dst (including any pre-existing duplicates) are kept as-is
// and in their original order; new elements follow in src order.
//
// Like the built-in append, the returned slice may share its backing
// array with dst, so callers should use the return value and stop using
// dst afterwards.
func appendUnique[T comparable](dst, src []T) []T {
	exists := make(map[T]struct{}, len(dst))
	for _, key := range dst {
		exists[key] = struct{}{}
	}
	for _, key := range src {
		if _, ok := exists[key]; ok {
			continue
		}
		// Record the key as seen so that a repeat later in src is not
		// appended a second time.
		exists[key] = struct{}{}
		dst = append(dst, key)
	}
	return dst
}
func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []codersdk.WorkspaceAgentMetadata {
// An empty array is easier for clients to handle than a null.
result := []codersdk.WorkspaceAgentMetadata{}
result := make([]codersdk.WorkspaceAgentMetadata, 0, len(db))
for _, datum := range db {
result = append(result, codersdk.WorkspaceAgentMetadata{
Result: codersdk.WorkspaceAgentMetadataResult{

View File

@ -16,6 +16,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"cdr.dev/slog"
@ -25,7 +26,9 @@ import (
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/provisioner/echo"
@ -1107,6 +1110,162 @@ func TestWorkspaceAgent_Metadata(t *testing.T) {
post("unknown", unknownKeyMetadata)
}
// testWAMErrorStore wraps a database.Store for tests and allows an error
// to be injected into GetWorkspaceAgentMetadata. All other Store methods
// are served by the embedded Store.
type testWAMErrorStore struct {
database.Store
// err, when it holds a non-nil pointer, is the error returned by
// GetWorkspaceAgentMetadata. It is an atomic pointer so it can be
// stored and loaded without additional locking.
err atomic.Pointer[error]
}
// GetWorkspaceAgentMetadata returns the injected error when one has been
// stored in s.err; otherwise it delegates to the embedded Store.
func (s *testWAMErrorStore) GetWorkspaceAgentMetadata(ctx context.Context, arg database.GetWorkspaceAgentMetadataParams) ([]database.WorkspaceAgentMetadatum, error) {
	if injected := s.err.Load(); injected != nil {
		return nil, *injected
	}
	// No error injected, fall through to the real implementation.
	return s.Store.GetWorkspaceAgentMetadata(ctx, arg)
}
// TestWorkspaceAgent_Metadata_CatchMemoryLeak starts a metadata watch
// (SSE) for an agent, continuously posts metadata for two keys, and then
// injects a database error into GetWorkspaceAgentMetadata. The test
// passes when the watch terminates before the test context expires.
func TestWorkspaceAgent_Metadata_CatchMemoryLeak(t *testing.T) {
	t.Parallel()

	db := &testWAMErrorStore{Store: dbmem.New()}
	psub := pubsub.NewInMemory()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("coderd").Leveled(slog.LevelDebug)
	client := coderdtest.New(t, &coderdtest.Options{
		Database:                 db,
		Pubsub:                   psub,
		IncludeProvisionerDaemon: true,
		Logger:                   &logger,
	})
	user := coderdtest.CreateFirstUser(t, client)
	authToken := uuid.NewString()
	ws := dbfake.Workspace(t, db, database.Workspace{
		OrganizationID: user.OrganizationID,
		OwnerID:        user.UserID,
	})
	dbfake.WorkspaceBuild(t, db, ws, database.WorkspaceBuild{}, &proto.Resource{
		Name: "example",
		Type: "aws_instance",
		Agents: []*proto.Agent{{
			Metadata: []*proto.Agent_Metadata{
				{
					DisplayName: "First Meta",
					Key:         "foo1",
					Script:      "echo hi",
					Interval:    10,
					Timeout:     3,
				},
				{
					DisplayName: "Second Meta",
					Key:         "foo2",
					Script:      "echo bye",
					Interval:    10,
					Timeout:     3,
				},
			},
			Id: uuid.NewString(),
			Auth: &proto.Agent_Token{
				Token: authToken,
			},
		}},
	})
	workspace, err := client.Workspace(context.Background(), ws.ID)
	require.NoError(t, err)
	for _, res := range workspace.LatestBuild.Resources {
		for _, a := range res.Agents {
			require.Equal(t, codersdk.WorkspaceAgentLifecycleCreated, a.LifecycleState)
		}
	}

	agentClient := agentsdk.New(client.URL)
	agentClient.SetSessionToken(authToken)

	ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitSuperLong))
	// Cancel is idempotent; deferring it guarantees the context is
	// released on every exit path, including t.Fatal.
	defer cancel()

	manifest, err := agentClient.Manifest(ctx)
	require.NoError(t, err)

	// post sends a single metadata key/value pair on behalf of the agent.
	post := func(ctx context.Context, key, value string) error {
		return agentClient.PostMetadata(ctx, agentsdk.PostMetadataRequest{
			Metadata: []agentsdk.Metadata{
				{
					Key: key,
					WorkspaceAgentMetadataResult: codersdk.WorkspaceAgentMetadataResult{
						CollectedAt: time.Now(),
						Value:       value,
					},
				},
			},
		})
	}

	workspace, err = client.Workspace(ctx, workspace.ID)
	require.NoError(t, err, "get workspace")

	// Start the SSE connection.
	metadata, errs := client.WatchWorkspaceAgentMetadata(ctx, manifest.AgentID)

	// Discard the output, pretending to be a client consuming it.
	wantErr := xerrors.New("test error")
	metadataDone := testutil.Go(t, func() {
		for {
			select {
			case <-ctx.Done():
				return
			case _, ok := <-metadata:
				if !ok {
					return
				}
			case err := <-errs:
				// The injected error is expected; anything else is a
				// test failure.
				if err != nil && !strings.Contains(err.Error(), wantErr.Error()) {
					assert.NoError(t, err, "watch metadata")
				}
				return
			}
		}
	})

	postDone := testutil.Go(t, func() {
		for {
			// We need to send two separate metadata updates to trigger the
			// memory leak. foo2 will cause the number of foo1 to be doubled, etc.
			//
			// Errors are kept local to this goroutine so that it never
			// writes to the err variable of the enclosing test function.
			if err := post(ctx, "foo1", "hi"); err != nil {
				if !xerrors.Is(err, context.Canceled) {
					assert.NoError(t, err, "post metadata foo1")
				}
				return
			}
			if err := post(ctx, "foo2", "bye"); err != nil {
				if !xerrors.Is(err, context.Canceled) {
					assert.NoError(t, err, "post metadata foo2")
				}
				return
			}
		}
	})

	// In a previously faulty implementation, this database error will trigger
	// a close of the goroutine that consumes metadata updates for refreshing
	// the metadata sent over SSE. As it was, the exit of the consumer was not
	// detected as a trigger to close down the connection.
	//
	// Further, there was a memory leak in the pubsub subscription that caused
	// ballooning of memory (almost double in size every received metadata).
	//
	// This db error should trigger a close of the SSE connection in the fixed
	// implementation. The memory leak should not happen in either case, but
	// testing it is not straightforward.
	db.err.Store(&wantErr)

	select {
	case <-ctx.Done():
		t.Fatal("timeout waiting for SSE to close")
	case <-metadataDone:
	}
	// Stop the posting goroutine and wait for it to exit.
	cancel()
	<-postDone
}
func TestWorkspaceAgent_Startup(t *testing.T) {
t.Parallel()

View File

@ -492,6 +492,11 @@ func (c *Client) WatchWorkspaceAgentMetadata(ctx context.Context, id uuid.UUID)
firstEvent = false
}
// Ignore pings.
if sse.Type == ServerSentEventTypePing {
continue
}
b, ok := sse.Data.([]byte)
if !ok {
return xerrors.Errorf("unexpected data type: %T", sse.Data)

23
testutil/go.go Normal file
View File

@ -0,0 +1,23 @@
package testutil
import (
"testing"
)
// Go runs fn in a goroutine and waits until fn has completed before
// test completion. Done is returned for optionally waiting for fn to
// exit.
//
// The done channel is closed via defer so that it is also closed when
// fn terminates its goroutine with runtime.Goexit (as t.FailNow and
// t.SkipNow do). Otherwise the cleanup registered here would block
// forever waiting for a close that never happens.
func Go(t *testing.T, fn func()) (done <-chan struct{}) {
	t.Helper()

	doneC := make(chan struct{})
	// Block test completion until the goroutine has exited.
	t.Cleanup(func() {
		<-doneC
	})
	go func() {
		defer close(doneC)
		fn()
	}()

	return doneC
}