diff --git a/cli/ssh.go b/cli/ssh.go index 990b972309..21437ee6ae 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -876,6 +876,7 @@ type closerStack struct { closed bool logger slog.Logger err error + wg sync.WaitGroup } func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack { @@ -893,10 +894,13 @@ func (c *closerStack) close(err error) { c.Lock() if c.closed { c.Unlock() + c.wg.Wait() return } c.closed = true c.err = err + c.wg.Add(1) + defer c.wg.Done() c.Unlock() for i := len(c.closers) - 1; i >= 0; i-- { diff --git a/cli/ssh_internal_test.go b/cli/ssh_internal_test.go index 0630deb36d..b612dd5ef9 100644 --- a/cli/ssh_internal_test.go +++ b/cli/ssh_internal_test.go @@ -4,6 +4,7 @@ import ( "context" "net/url" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -127,6 +128,43 @@ func TestCloserStack_PushAfterClose(t *testing.T) { require.Equal(t, []*fakeCloser{fc1, fc0}, *closes, "should close fc1") } +func TestCloserStack_CloseAfterContext(t *testing.T) { + t.Parallel() + testCtx := testutil.Context(t, testutil.WaitShort) + ctx, cancel := context.WithCancel(testCtx) + defer cancel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + uut := newCloserStack(ctx, logger) + ac := &asyncCloser{ + t: t, + ctx: testCtx, + complete: make(chan struct{}), + started: make(chan struct{}), + } + err := uut.push("async", ac) + require.NoError(t, err) + cancel() + testutil.RequireRecvCtx(testCtx, t, ac.started) + + closed := make(chan struct{}) + go func() { + defer close(closed) + uut.close(nil) + }() + + // since the asyncCloser is still waiting, we shouldn't complete uut.close() + select { + case <-time.After(testutil.IntervalFast): + // OK! + case <-closed: + t.Fatal("closed before stack was finished") + } + + // complete the asyncCloser + close(ac.complete) + testutil.RequireRecvCtx(testCtx, t, closed) +} + type fakeCloser struct { closes *[]*fakeCloser err error @@ -136,3 +174,21 @@ func (c *fakeCloser) Close() error { *c.closes = append(*c.closes, c) return c.err } + +type asyncCloser struct { + t *testing.T + ctx context.Context + started chan struct{} + complete chan struct{} +} + +func (c *asyncCloser) Close() error { + close(c.started) + select { + case <-c.ctx.Done(): + c.t.Error("timed out") + return c.ctx.Err() + case <-c.complete: + return nil + } +}