fix: ensure ssh cleanup happens on cmd error

I noticed in my logs that sometimes `coder ssh` doesn't gracefully disconnect from the coordinator.

The cause is the `closerStack` construct we use in that command. It has two paths that can start closing things down:

1. an explicit `close()`, which we call in a `defer`
2. context cancellation, which happens if the cli function returns an error

Sometimes the remote ssh command returns an error, which triggers context cancellation of the `closerStack`. That is fine in and of itself, but we still want the explicit `close()` to wait until everything is closed before returning, since that's where we do cleanup, including the graceful disconnect. Prior to this fix, `close()` returned immediately if another goroutine was already closing the stack. This change adds a wait so that every caller blocks until everything is done.
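
To make the failure mode concrete, here is a minimal, runnable sketch of the pattern, not the actual `cli` code: the `stack` type, its plain `func()` closers, and the `main` harness are invented for illustration, while the real `closerStack` also tracks named closers, a logger, and an error. Whichever caller wins the race runs the closers; every other caller blocks on the `sync.WaitGroup` until that work finishes.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// stack is a stripped-down stand-in for closerStack. close may be entered
// twice concurrently: once from the explicit defer, once when the context
// is canceled after the remote command errors.
type stack struct {
	sync.Mutex
	closed  bool
	wg      sync.WaitGroup
	closers []func()
}

func (s *stack) close() {
	s.Lock()
	if s.closed {
		s.Unlock()
		s.wg.Wait() // the fix: wait for the goroutine that is already closing
		return
	}
	s.closed = true
	s.wg.Add(1) // incremented under the lock, so a concurrent caller cannot miss it
	defer s.wg.Done()
	s.Unlock()

	// Run closers in reverse push order, as closerStack does.
	for i := len(s.closers) - 1; i >= 0; i-- {
		s.closers[i]()
	}
}

func main() {
	s := &stack{closers: []func(){func() {
		time.Sleep(100 * time.Millisecond) // stands in for the graceful disconnect
		fmt.Println("cleanup finished")
	}}}

	go s.close() // e.g. triggered by context cancellation
	time.Sleep(10 * time.Millisecond)

	s.close() // the deferred explicit close: it now blocks until cleanup is done
	fmt.Println("explicit close returned only after cleanup")
}
```

Without the `wg.Wait()`, the second call would return immediately and the process could exit before the disconnect completed.
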
Author: Spike Curtis
Date:   2024-03-07 17:26:49 +04:00 (committed via GitHub)
Commit: b96f6b48a4 (parent c8aa99a5b8)
2 files changed, 60 insertions(+), 0 deletions(-)

```diff
@@ -876,6 +876,7 @@ type closerStack struct {
 	closed  bool
 	logger  slog.Logger
 	err     error
+	wg      sync.WaitGroup
 }
 
 func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
@@ -893,10 +894,13 @@ func (c *closerStack) close(err error) {
 	c.Lock()
 	if c.closed {
 		c.Unlock()
+		c.wg.Wait()
 		return
 	}
 	c.closed = true
 	c.err = err
+	c.wg.Add(1)
+	defer c.wg.Done()
 	c.Unlock()
 
 	for i := len(c.closers) - 1; i >= 0; i-- {
```
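
One detail in the hunk above is easy to miss: `c.wg.Add(1)` runs while the mutex is still held. If the counter were incremented after `c.Unlock()`, a concurrent caller could observe `c.closed == true`, reach `c.wg.Wait()` while the count was still zero, and return immediately without waiting for cleanup.
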
```diff
@@ -4,6 +4,7 @@ import (
 	"context"
 	"net/url"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -127,6 +128,43 @@ func TestCloserStack_PushAfterClose(t *testing.T) {
 	require.Equal(t, []*fakeCloser{fc1, fc0}, *closes, "should close fc1")
 }
 
+func TestCloserStack_CloseAfterContext(t *testing.T) {
+	t.Parallel()
+	testCtx := testutil.Context(t, testutil.WaitShort)
+	ctx, cancel := context.WithCancel(testCtx)
+	defer cancel()
+	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+	uut := newCloserStack(ctx, logger)
+	ac := &asyncCloser{
+		t:        t,
+		ctx:      testCtx,
+		complete: make(chan struct{}),
+		started:  make(chan struct{}),
+	}
+	err := uut.push("async", ac)
+	require.NoError(t, err)
+	cancel()
+	testutil.RequireRecvCtx(testCtx, t, ac.started)
+
+	closed := make(chan struct{})
+	go func() {
+		defer close(closed)
+		uut.close(nil)
+	}()
+
+	// since the asyncCloser is still waiting, we shouldn't complete uut.close()
+	select {
+	case <-time.After(testutil.IntervalFast):
+		// OK!
+	case <-closed:
+		t.Fatal("closed before stack was finished")
+	}
+
+	// complete the asyncCloser
+	close(ac.complete)
+	testutil.RequireRecvCtx(testCtx, t, closed)
+}
+
 type fakeCloser struct {
 	closes *[]*fakeCloser
 	err    error
@@ -136,3 +174,21 @@ func (c *fakeCloser) Close() error {
 	*c.closes = append(*c.closes, c)
 	return c.err
 }
+
+type asyncCloser struct {
+	t        *testing.T
+	ctx      context.Context
+	started  chan struct{}
+	complete chan struct{}
+}
+
+func (c *asyncCloser) Close() error {
+	close(c.started)
+	select {
+	case <-c.ctx.Done():
+		c.t.Error("timed out")
+		return c.ctx.Err()
+	case <-c.complete:
+		return nil
+	}
+}
```
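
The `asyncCloser` is what lets the test drive the race deterministically: its `Close` signals `started` and then blocks until the test closes `complete` (or the test context expires). After cancelling the context, the test uses `started` to confirm the cancellation path has begun closing, calls `close(nil)` from a second goroutine, and asserts that this call stays blocked until `complete` is closed, i.e. that the explicit close now waits for cleanup instead of returning early.
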