From b96f6b48a46a979496d48e515e66ea587fd1f0a1 Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Thu, 7 Mar 2024 17:26:49 +0400
Subject: [PATCH] fix: ensure ssh cleanup happens on cmd error

I noticed in my logs that sometimes `coder ssh` doesn't gracefully
disconnect from the coordinator. The cause is the `closerStack`
construct we use in that function. It has two paths to start closing
things down:

1. an explicit `close()`, which we do in a `defer`
2. context cancellation, which happens if the CLI function returns an
   error

Sometimes the ssh remote command returns an error, and this triggers
context cancellation of the `closerStack`. That is fine in and of
itself, but we still want the explicit `close()` to wait until
everything is closed before returning, since that's where we do
cleanup, including the graceful disconnect. Prior to this fix,
`close()` returned immediately if another goroutine was already
closing the stack. Here we add a wait until everything is done.
---
 cli/ssh.go               |  4 +++
 cli/ssh_internal_test.go | 56 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/cli/ssh.go b/cli/ssh.go
index 990b972309..21437ee6ae 100644
--- a/cli/ssh.go
+++ b/cli/ssh.go
@@ -876,6 +876,7 @@ type closerStack struct {
 	closed  bool
 	logger  slog.Logger
 	err     error
+	wg      sync.WaitGroup
 }
 
 func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
@@ -893,10 +894,13 @@ func (c *closerStack) close(err error) {
 	c.Lock()
 	if c.closed {
 		c.Unlock()
+		c.wg.Wait()
 		return
 	}
 	c.closed = true
 	c.err = err
+	c.wg.Add(1)
+	defer c.wg.Done()
 	c.Unlock()
 
 	for i := len(c.closers) - 1; i >= 0; i-- {
diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go
index 0630deb36d..b612dd5ef9 100644
--- a/cli/ssh_internal_test.go
+++ b/cli/ssh_internal_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"net/url"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -127,6 +128,43 @@ func TestCloserStack_PushAfterClose(t *testing.T) {
 	require.Equal(t, []*fakeCloser{fc1, fc0}, *closes, "should close fc1")
 }
 
+func TestCloserStack_CloseAfterContext(t *testing.T) {
+	t.Parallel()
+	testCtx := testutil.Context(t, testutil.WaitShort)
+	ctx, cancel := context.WithCancel(testCtx)
+	defer cancel()
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+	uut := newCloserStack(ctx, logger)
+	ac := &asyncCloser{
+		t:        t,
+		ctx:      testCtx,
+		complete: make(chan struct{}),
+		started:  make(chan struct{}),
+	}
+	err := uut.push("async", ac)
+	require.NoError(t, err)
+	cancel()
+	testutil.RequireRecvCtx(testCtx, t, ac.started)
+
+	closed := make(chan struct{})
+	go func() {
+		defer close(closed)
+		uut.close(nil)
+	}()
+
+	// since the asyncCloser is still waiting, we shouldn't complete uut.close()
+	select {
+	case <-time.After(testutil.IntervalFast):
+		// OK!
+	case <-closed:
+		t.Fatal("closed before stack was finished")
+	}
+
+	// complete the asyncCloser
+	close(ac.complete)
+	testutil.RequireRecvCtx(testCtx, t, closed)
+}
+
 type fakeCloser struct {
 	closes *[]*fakeCloser
 	err    error
@@ -136,3 +174,21 @@ func (c *fakeCloser) Close() error {
 	*c.closes = append(*c.closes, c)
 	return c.err
 }
+
+type asyncCloser struct {
+	t        *testing.T
+	ctx      context.Context
+	started  chan struct{}
+	complete chan struct{}
+}
+
+func (c *asyncCloser) Close() error {
+	close(c.started)
+	select {
+	case <-c.ctx.Done():
+		c.t.Error("timed out")
+		return c.ctx.Err()
+	case <-c.complete:
+		return nil
+	}
+}
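
A note on the pattern (not part of the patch): the fix hinges on a losing
`close()` waiting on a `sync.WaitGroup` for the in-flight close rather than
returning early. Below is a minimal, self-contained Go sketch of that
pattern. The names `stack` and `slowCloser` are hypothetical stand-ins, not
types from the coder codebase, and `slowCloser.Close` stands in for cleanup
work like the graceful disconnect from the coordinator.

package main

import (
	"fmt"
	"io"
	"sync"
	"time"
)

// slowCloser stands in for a closer whose Close does real cleanup work.
type slowCloser struct{}

func (slowCloser) Close() error {
	time.Sleep(100 * time.Millisecond)
	fmt.Println("cleanup finished")
	return nil
}

// stack is a stripped-down analogue of closerStack: a LIFO of closers
// that may be closed from more than one goroutine.
type stack struct {
	sync.Mutex
	closers []io.Closer
	closed  bool
	wg      sync.WaitGroup
}

func (s *stack) push(c io.Closer) {
	s.Lock()
	defer s.Unlock()
	s.closers = append(s.closers, c)
}

func (s *stack) close() {
	s.Lock()
	if s.closed {
		s.Unlock()
		// Another goroutine is (or was) closing the stack; wait for it
		// to finish instead of returning immediately. This is the wait
		// the patch adds.
		s.wg.Wait()
		return
	}
	s.closed = true
	// Register with the WaitGroup while still holding the lock, so any
	// caller that observes closed == true is guaranteed to wait on it.
	s.wg.Add(1)
	defer s.wg.Done()
	s.Unlock()

	for i := len(s.closers) - 1; i >= 0; i-- {
		_ = s.closers[i].Close()
	}
}

func main() {
	s := &stack{}
	s.push(slowCloser{})

	go s.close() // e.g. triggered by context cancellation
	time.Sleep(10 * time.Millisecond)

	s.close()                   // the explicit close() done in a defer
	fmt.Println("safe to exit") // only printed after "cleanup finished"
}

Because `wg.Add(1)` happens under the mutex before the lock is released, any
caller that sees `closed == true` is guaranteed the counter is already
registered, so its `wg.Wait()` cannot return before the winning goroutine
has run every closer.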