mirror of https://github.com/coder/coder.git
fix(coderd): fix memory leak in `watchWorkspaceAgentMetadata` (#10685)
Fixes #10550
This commit is contained in:
parent
c130f8d6d0
commit
198b56c137
|
@ -1785,10 +1785,10 @@ func (api *API) workspaceAgentUpdateMetadata(ctx context.Context, workspaceAgent
|
|||
|
||||
datum := database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
Key: []string{},
|
||||
Value: []string{},
|
||||
Error: []string{},
|
||||
CollectedAt: []time.Time{},
|
||||
Key: make([]string, 0, len(req.Metadata)),
|
||||
Value: make([]string, 0, len(req.Metadata)),
|
||||
Error: make([]string, 0, len(req.Metadata)),
|
||||
CollectedAt: make([]time.Time, 0, len(req.Metadata)),
|
||||
}
|
||||
|
||||
for _, md := range req.Metadata {
|
||||
|
@ -1853,22 +1853,28 @@ func (api *API) workspaceAgentUpdateMetadata(ctx context.Context, workspaceAgent
|
|||
// @Router /workspaceagents/{workspaceagent}/watch-metadata [get]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
workspaceAgent = httpmw.WorkspaceAgentParam(r)
|
||||
log = api.Logger.Named("workspace_metadata_watcher").With(
|
||||
slog.F("workspace_agent_id", workspaceAgent.ID),
|
||||
)
|
||||
// Allow us to interrupt watch via cancel.
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx) // Rewire context for SSE cancellation.
|
||||
|
||||
workspaceAgent := httpmw.WorkspaceAgentParam(r)
|
||||
log := api.Logger.Named("workspace_metadata_watcher").With(
|
||||
slog.F("workspace_agent_id", workspaceAgent.ID),
|
||||
)
|
||||
|
||||
// Send metadata on updates, we must ensure subscription before sending
|
||||
// initial metadata to guarantee that events in-between are not missed.
|
||||
update := make(chan workspaceAgentMetadataChannelPayload, 1)
|
||||
cancelSub, err := api.Pubsub.Subscribe(watchWorkspaceAgentMetadataChannel(workspaceAgent.ID), func(_ context.Context, byt []byte) {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var payload workspaceAgentMetadataChannelPayload
|
||||
err := json.Unmarshal(byt, &payload)
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err))
|
||||
log.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1876,18 +1882,7 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
|
|||
|
||||
select {
|
||||
case prev := <-update:
|
||||
// This update wasn't consumed yet, merge the keys.
|
||||
newKeysSet := make(map[string]struct{})
|
||||
for _, key := range payload.Keys {
|
||||
newKeysSet[key] = struct{}{}
|
||||
}
|
||||
keys := prev.Keys
|
||||
for _, key := range prev.Keys {
|
||||
if _, ok := newKeysSet[key]; !ok {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
payload.Keys = keys
|
||||
payload.Keys = appendUnique(prev.Keys, payload.Keys)
|
||||
default:
|
||||
}
|
||||
// This can never block since we pop and merge beforehand.
|
||||
|
@ -1899,22 +1894,9 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
|
|||
}
|
||||
defer cancelSub()
|
||||
|
||||
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error setting up server-sent events.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Prevent handler from returning until the sender is closed.
|
||||
defer func() {
|
||||
<-sseSenderClosed
|
||||
}()
|
||||
|
||||
// We always use the original Request context because it contains
|
||||
// the RBAC actor.
|
||||
md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
|
||||
initialMD, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
Keys: nil,
|
||||
})
|
||||
|
@ -1926,15 +1908,45 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
|
|||
return
|
||||
}
|
||||
|
||||
metadataMap := make(map[string]database.WorkspaceAgentMetadatum)
|
||||
for _, datum := range md {
|
||||
log.Debug(ctx, "got initial metadata", "num", len(initialMD))
|
||||
|
||||
metadataMap := make(map[string]database.WorkspaceAgentMetadatum, len(initialMD))
|
||||
for _, datum := range initialMD {
|
||||
metadataMap[datum.Key] = datum
|
||||
}
|
||||
//nolint:ineffassign // Release memory.
|
||||
initialMD = nil
|
||||
|
||||
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error setting up server-sent events.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Prevent handler from returning until the sender is closed.
|
||||
defer func() {
|
||||
cancel()
|
||||
<-sseSenderClosed
|
||||
}()
|
||||
// Synchronize cancellation from SSE -> context, this lets us simplify the
|
||||
// cancellation logic.
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-sseSenderClosed:
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
var lastSend time.Time
|
||||
sendMetadata := func() {
|
||||
lastSend = time.Now()
|
||||
values := maps.Values(metadataMap)
|
||||
|
||||
log.Debug(ctx, "sending metadata", "num", len(values))
|
||||
|
||||
_ = sseSendEvent(ctx, codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: convertWorkspaceAgentMetadata(values),
|
||||
|
@ -1953,10 +1965,11 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
|
|||
fetchedMetadata := make(chan []database.WorkspaceAgentMetadatum)
|
||||
go func() {
|
||||
defer close(fetchedMetadata)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sseSenderClosed:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case payload := <-update:
|
||||
md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{
|
||||
|
@ -1966,24 +1979,35 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
|
|||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
log.Error(ctx, "failed to get metadata", slog.Error(err))
|
||||
_ = sseSendEvent(ctx, codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Failed to get metadata.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-sseSenderClosed:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
// We want to block here to avoid constantly pinging the
|
||||
// database when the metadata isn't being processed.
|
||||
case fetchedMetadata <- md:
|
||||
log.Debug(ctx, "fetched metadata update for keys", "keys", payload.Keys, "num", len(md))
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
<-fetchedMetadata
|
||||
}()
|
||||
|
||||
pendingChanges := true
|
||||
for {
|
||||
select {
|
||||
case <-sseSenderClosed:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case md, ok := <-fetchedMetadata:
|
||||
if !ok {
|
||||
|
@ -2007,9 +2031,24 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
|
|||
}
|
||||
}
|
||||
|
||||
// appendUnique is like append and adds elements from src to dst,
// skipping any element that is already present in dst. An element that
// occurs more than once in src is appended only once, so the call never
// introduces new duplicates. Elements already in dst keep their position
// and new elements follow in the order they first appear in src. Like
// append, the returned slice may share its backing array with dst.
func appendUnique[T comparable](dst, src []T) []T {
	exists := make(map[T]struct{}, len(dst)+len(src))
	for _, key := range dst {
		exists[key] = struct{}{}
	}
	for _, key := range src {
		if _, ok := exists[key]; ok {
			continue
		}
		// Record the key as seen so that a repeated occurrence later in
		// src is not appended a second time.
		exists[key] = struct{}{}
		dst = append(dst, key)
	}
	return dst
}
|
||||
|
||||
func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []codersdk.WorkspaceAgentMetadata {
|
||||
// An empty array is easier for clients to handle than a null.
|
||||
result := []codersdk.WorkspaceAgentMetadata{}
|
||||
result := make([]codersdk.WorkspaceAgentMetadata, 0, len(db))
|
||||
for _, datum := range db {
|
||||
result = append(result, codersdk.WorkspaceAgentMetadata{
|
||||
Result: codersdk.WorkspaceAgentMetadataResult{
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
@ -25,7 +26,9 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmem"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
|
@ -1107,6 +1110,162 @@ func TestWorkspaceAgent_Metadata(t *testing.T) {
|
|||
post("unknown", unknownKeyMetadata)
|
||||
}
|
||||
|
||||
type testWAMErrorStore struct {
|
||||
database.Store
|
||||
err atomic.Pointer[error]
|
||||
}
|
||||
|
||||
func (s *testWAMErrorStore) GetWorkspaceAgentMetadata(ctx context.Context, arg database.GetWorkspaceAgentMetadataParams) ([]database.WorkspaceAgentMetadatum, error) {
|
||||
err := s.err.Load()
|
||||
if err != nil {
|
||||
return nil, *err
|
||||
}
|
||||
return s.Store.GetWorkspaceAgentMetadata(ctx, arg)
|
||||
}
|
||||
|
||||
func TestWorkspaceAgent_Metadata_CatchMemoryLeak(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := &testWAMErrorStore{Store: dbmem.New()}
|
||||
psub := pubsub.NewInMemory()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("coderd").Leveled(slog.LevelDebug)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: psub,
|
||||
IncludeProvisionerDaemon: true,
|
||||
Logger: &logger,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
ws := dbfake.Workspace(t, db, database.Workspace{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
})
|
||||
dbfake.WorkspaceBuild(t, db, ws, database.WorkspaceBuild{}, &proto.Resource{
|
||||
Name: "example",
|
||||
Type: "aws_instance",
|
||||
Agents: []*proto.Agent{{
|
||||
Metadata: []*proto.Agent_Metadata{
|
||||
{
|
||||
DisplayName: "First Meta",
|
||||
Key: "foo1",
|
||||
Script: "echo hi",
|
||||
Interval: 10,
|
||||
Timeout: 3,
|
||||
},
|
||||
{
|
||||
DisplayName: "Second Meta",
|
||||
Key: "foo2",
|
||||
Script: "echo bye",
|
||||
Interval: 10,
|
||||
Timeout: 3,
|
||||
},
|
||||
},
|
||||
Id: uuid.NewString(),
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
})
|
||||
workspace, err := client.Workspace(context.Background(), ws.ID)
|
||||
require.NoError(t, err)
|
||||
for _, res := range workspace.LatestBuild.Resources {
|
||||
for _, a := range res.Agents {
|
||||
require.Equal(t, codersdk.WorkspaceAgentLifecycleCreated, a.LifecycleState)
|
||||
}
|
||||
}
|
||||
|
||||
agentClient := agentsdk.New(client.URL)
|
||||
agentClient.SetSessionToken(authToken)
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitSuperLong))
|
||||
|
||||
manifest, err := agentClient.Manifest(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
post := func(ctx context.Context, key, value string) error {
|
||||
return agentClient.PostMetadata(ctx, agentsdk.PostMetadataRequest{
|
||||
Metadata: []agentsdk.Metadata{
|
||||
{
|
||||
Key: key,
|
||||
WorkspaceAgentMetadataResult: codersdk.WorkspaceAgentMetadataResult{
|
||||
CollectedAt: time.Now(),
|
||||
Value: value,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err, "get workspace")
|
||||
|
||||
// Start the SSE connection.
|
||||
metadata, errors := client.WatchWorkspaceAgentMetadata(ctx, manifest.AgentID)
|
||||
|
||||
// Discard the output, pretending to be a client consuming it.
|
||||
wantErr := xerrors.New("test error")
|
||||
metadataDone := testutil.Go(t, func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case _, ok := <-metadata:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case err := <-errors:
|
||||
if err != nil && !strings.Contains(err.Error(), wantErr.Error()) {
|
||||
assert.NoError(t, err, "watch metadata")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
postDone := testutil.Go(t, func() {
|
||||
for {
|
||||
// We need to send two separate metadata updates to trigger the
|
||||
// memory leak. foo2 will cause the number of foo1 to be doubled, etc.
|
||||
err = post(ctx, "foo1", "hi")
|
||||
if err != nil {
|
||||
if !xerrors.Is(err, context.Canceled) {
|
||||
assert.NoError(t, err, "post metadata foo1")
|
||||
}
|
||||
return
|
||||
}
|
||||
err = post(ctx, "foo2", "bye")
|
||||
if err != nil {
|
||||
if !xerrors.Is(err, context.Canceled) {
|
||||
assert.NoError(t, err, "post metadata foo1")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// In a previously faulty implementation, this database error will trigger
|
||||
// a close of the goroutine that consumes metadata updates for refreshing
|
||||
// the metadata sent over SSE. As it was, the exit of the consumer was not
|
||||
// detected as a trigger to close down the connection.
|
||||
//
|
||||
// Further, there was a memory leak in the pubsub subscription that cause
|
||||
// ballooning of memory (almost double in size every received metadata).
|
||||
//
|
||||
// This db error should trigger a close of the SSE connection in the fixed
|
||||
// implementation. The memory leak should not happen in either case, but
|
||||
// testing it is not straightforward.
|
||||
db.err.Store(&wantErr)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for SSE to close")
|
||||
case <-metadataDone:
|
||||
}
|
||||
cancel()
|
||||
<-postDone
|
||||
}
|
||||
|
||||
func TestWorkspaceAgent_Startup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -492,6 +492,11 @@ func (c *Client) WatchWorkspaceAgentMetadata(ctx context.Context, id uuid.UUID)
|
|||
firstEvent = false
|
||||
}
|
||||
|
||||
// Ignore pings.
|
||||
if sse.Type == ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
|
||||
b, ok := sse.Data.([]byte)
|
||||
if !ok {
|
||||
return xerrors.Errorf("unexpected data type: %T", sse.Data)
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package testutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Go runs fn in a goroutine and waits until fn has completed before
// test completion. Done is returned for optionally waiting for fn to
// exit.
//
// Done is closed from a deferred call, so it is also closed when fn
// ends its goroutine through runtime.Goexit, which is how t.FailNow,
// t.Fatal and t.SkipNow stop execution. Without that, the cleanup
// registered here would block forever and hang the test.
func Go(t *testing.T, fn func()) (done <-chan struct{}) {
	t.Helper()

	doneC := make(chan struct{})
	// Hold the test open until the goroutine has finished.
	t.Cleanup(func() {
		<-doneC
	})
	go func() {
		defer close(doneC)
		fn()
	}()

	return doneC
}
|
Loading…
Reference in New Issue