mirror of https://github.com/coder/coder.git
fix: handle SIGHUP from OpenSSH (#10638)
Fixes an issue where remote forwards are not correctly torn down when using OpenSSH with `coder ssh --stdio`. OpenSSH sends a disconnect signal, but then also sends SIGHUP to `coder`. Previously, we just exited when we got SIGHUP, and this raced against properly disconnecting. Fixes https://github.com/coder/customers/issues/327
This commit is contained in:
parent
be0436afbe
commit
f400d8a0c5
|
@ -37,6 +37,7 @@ type forwardedUnixHandler struct {
|
|||
}
|
||||
|
||||
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
|
||||
h.log.Debug(ctx, "handling SSH unix forward")
|
||||
h.Lock()
|
||||
if h.forwards == nil {
|
||||
h.forwards = make(map[string]net.Listener)
|
||||
|
@ -47,22 +48,25 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
|
||||
return false, nil
|
||||
}
|
||||
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
|
||||
|
||||
switch req.Type {
|
||||
case "streamlocal-forward@openssh.com":
|
||||
var reqPayload streamLocalForwardPayload
|
||||
err := gossh.Unmarshal(req.Payload, &reqPayload)
|
||||
if err != nil {
|
||||
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request payload from client", slog.Error(err))
|
||||
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request (SSH unix forward) payload from client", slog.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
addr := reqPayload.SocketPath
|
||||
log = log.With(slog.F("socket_path", addr))
|
||||
log.Debug(ctx, "request begin SSH unix forward")
|
||||
h.Lock()
|
||||
_, ok := h.forwards[addr]
|
||||
h.Unlock()
|
||||
if ok {
|
||||
h.log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
|
||||
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
|
||||
slog.F("socket_path", addr),
|
||||
)
|
||||
return false, nil
|
||||
|
@ -72,9 +76,8 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
parentDir := filepath.Dir(addr)
|
||||
err = os.MkdirAll(parentDir, 0o700)
|
||||
if err != nil {
|
||||
h.log.Warn(ctx, "create parent dir for SSH unix forward request",
|
||||
log.Warn(ctx, "create parent dir for SSH unix forward request",
|
||||
slog.F("parent_dir", parentDir),
|
||||
slog.F("socket_path", addr),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false, nil
|
||||
|
@ -82,12 +85,13 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
|
||||
ln, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
h.log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
|
||||
log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
|
||||
slog.F("socket_path", addr),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false, nil
|
||||
}
|
||||
log.Debug(ctx, "SSH unix forward listening on socket")
|
||||
|
||||
// The listener needs to successfully start before it can be added to
|
||||
// the map, so we don't have to worry about checking for an existing
|
||||
|
@ -97,6 +101,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
h.Lock()
|
||||
h.forwards[addr] = ln
|
||||
h.Unlock()
|
||||
log.Debug(ctx, "SSH unix forward added to cache")
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
|
@ -110,14 +115,15 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
if !xerrors.Is(err, net.ErrClosed) {
|
||||
h.log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
|
||||
slog.F("socket_path", addr),
|
||||
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
// closed below
|
||||
log.Debug(ctx, "SSH unix forward listener closed")
|
||||
break
|
||||
}
|
||||
log.Debug(ctx, "accepted SSH unix forward connection")
|
||||
payload := gossh.Marshal(&forwardedStreamLocalPayload{
|
||||
SocketPath: addr,
|
||||
})
|
||||
|
@ -125,7 +131,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
go func() {
|
||||
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
|
||||
if err != nil {
|
||||
h.log.Warn(ctx, "open SSH channel to forward Unix connection to client",
|
||||
h.log.Warn(ctx, "open SSH unix forward channel to client",
|
||||
slog.F("socket_path", addr),
|
||||
slog.Error(err),
|
||||
)
|
||||
|
@ -143,6 +149,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
delete(h.forwards, addr)
|
||||
}
|
||||
h.Unlock()
|
||||
log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr))
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
|
@ -152,9 +159,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
var reqPayload streamLocalForwardPayload
|
||||
err := gossh.Unmarshal(req.Payload, &reqPayload)
|
||||
if err != nil {
|
||||
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err))
|
||||
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath))
|
||||
h.Lock()
|
||||
ln, ok := h.forwards[reqPayload.SocketPath]
|
||||
h.Unlock()
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"net/http/pprof"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
@ -144,7 +143,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
|
|||
// Note that we don't want to handle these signals in the
|
||||
// process that runs as PID 1, that's why we do this after
|
||||
// the reaper forked.
|
||||
ctx, stopNotify := signal.NotifyContext(ctx, InterruptSignals...)
|
||||
ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
||||
defer stopNotify()
|
||||
|
||||
// DumpHandler does signal handling, so we call it after the
|
||||
|
|
|
@ -7,7 +7,9 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
|
@ -183,6 +185,9 @@ type Invocation struct {
|
|||
Stdout io.Writer
|
||||
Stderr io.Writer
|
||||
Stdin io.Reader
|
||||
|
||||
// testing
|
||||
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
|
||||
}
|
||||
|
||||
// WithOS returns the invocation as a main package, filling in the invocation's unset
|
||||
|
@ -197,6 +202,26 @@ func (inv *Invocation) WithOS() *Invocation {
|
|||
})
|
||||
}
|
||||
|
||||
// WithTestSignalNotifyContext allows overriding the default implementation of SignalNotifyContext.
|
||||
// This should only be used in testing.
|
||||
func (inv *Invocation) WithTestSignalNotifyContext(
|
||||
_ testing.TB, // ensure we only call this from tests
|
||||
f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc),
|
||||
) *Invocation {
|
||||
return inv.with(func(i *Invocation) {
|
||||
i.signalNotifyContext = f
|
||||
})
|
||||
}
|
||||
|
||||
// SignalNotifyContext is equivalent to signal.NotifyContext, but supports being overridden in
|
||||
// tests.
|
||||
func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
|
||||
if inv.signalNotifyContext == nil {
|
||||
return signal.NotifyContext(parent, signals...)
|
||||
}
|
||||
return inv.signalNotifyContext(parent, signals...)
|
||||
}
|
||||
|
||||
func (inv *Invocation) Context() context.Context {
|
||||
if inv.ctx == nil {
|
||||
return context.Background()
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
package clitest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type FakeSignalNotifier struct {
|
||||
sync.Mutex
|
||||
t *testing.T
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
signals []os.Signal
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func NewFakeSignalNotifier(t *testing.T) *FakeSignalNotifier {
|
||||
fsn := &FakeSignalNotifier{t: t}
|
||||
return fsn
|
||||
}
|
||||
|
||||
func (f *FakeSignalNotifier) Stop() {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.stopped = true
|
||||
if f.cancel == nil {
|
||||
f.t.Error("stopped before started")
|
||||
return
|
||||
}
|
||||
f.cancel()
|
||||
}
|
||||
|
||||
func (f *FakeSignalNotifier) NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.signals = signals
|
||||
f.ctx, f.cancel = context.WithCancel(parent)
|
||||
return f.ctx, f.Stop
|
||||
}
|
||||
|
||||
func (f *FakeSignalNotifier) Notify() {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
if f.cancel == nil {
|
||||
f.t.Error("notified before started")
|
||||
return
|
||||
}
|
||||
f.cancel()
|
||||
}
|
||||
|
||||
func (f *FakeSignalNotifier) AssertStopped() {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
assert.True(f.t, f.stopped)
|
||||
}
|
|
@ -2,7 +2,6 @@ package cli
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os/signal"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
|
@ -63,7 +62,7 @@ fi
|
|||
Handler: func(inv *clibase.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
|
||||
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
||||
defer stop()
|
||||
|
||||
client, err := r.createAgentClient()
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -26,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd {
|
|||
Handler: func(inv *clibase.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
|
||||
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
||||
defer stop()
|
||||
|
||||
user, host, err := gitauth.ParseAskpass(inv.Args[0])
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
|
@ -30,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd {
|
|||
|
||||
// Catch interrupt signals to ensure the temporary private
|
||||
// key file is cleaned up on most cases.
|
||||
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
||||
defer stop()
|
||||
|
||||
// Early check so errors are reported immediately.
|
||||
|
|
|
@ -22,7 +22,6 @@ import (
|
|||
"net/http/pprof"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
@ -333,7 +332,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
|||
//
|
||||
// To get out of a graceful shutdown, the user can send
|
||||
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
|
||||
notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
|
||||
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
||||
defer notifyStop()
|
||||
|
||||
cacheDir := vals.CacheDir.String()
|
||||
|
@ -1098,7 +1097,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
|||
logger = logger.Leveled(slog.LevelDebug)
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
|
||||
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
||||
defer cancel()
|
||||
|
||||
url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)
|
||||
|
|
|
@ -4,7 +4,6 @@ package cli
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os/signal"
|
||||
"sort"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -48,7 +47,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd {
|
|||
logger = logger.Leveled(slog.LevelDebug)
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
|
||||
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
||||
defer cancel()
|
||||
|
||||
if newUserDBURL == "" {
|
||||
|
|
13
cli/ssh.go
13
cli/ssh.go
|
@ -62,7 +62,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *clibase.Invocation) (retErr error) {
|
||||
ctx, cancel := context.WithCancel(inv.Context())
|
||||
// Before dialing the SSH server over TCP, capture Interrupt signals
|
||||
// so that if we are interrupted, we have a chance to tear down the
|
||||
// TCP session cleanly before exiting. If we don't, then the TCP
|
||||
// session can persist for up to 72 hours, since we set a long
|
||||
// timeout on the Agent side of the connection. In particular,
|
||||
// OpenSSH sends SIGHUP to terminate a proxy command.
|
||||
ctx, stop := inv.SignalNotifyContext(inv.Context(), InterruptSignals...)
|
||||
defer stop()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
logger := slog.Make() // empty logger
|
||||
|
@ -227,8 +235,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
go func() {
|
||||
defer wg.Done()
|
||||
// Ensure stdout copy closes incase stdin is closed
|
||||
// unexpectedly. Typically we wouldn't worry about
|
||||
// this since OpenSSH should kill the proxy command.
|
||||
// unexpectedly.
|
||||
defer rawSSH.Close()
|
||||
|
||||
_, err := io.Copy(rawSSH, inv.Stdin)
|
||||
|
|
116
cli/ssh_test.go
116
cli/ssh_test.go
|
@ -14,12 +14,15 @@ import (
|
|||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -245,6 +248,119 @@ func TestSSH(t *testing.T) {
|
|||
<-cmdDone
|
||||
})
|
||||
|
||||
// Test that we handle OS signals properly while remote forwarding, and don't just leave the TCP
|
||||
// socket hanging.
|
||||
t.Run("RemoteForward_Unix_Signal", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("No unix sockets on windows")
|
||||
}
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
|
||||
_, _ = tGoContext(t, func(ctx context.Context) {
|
||||
// Run this async so the SSH command has to wait for
|
||||
// the build and agent to connect!
|
||||
_ = agenttest.New(t, client.URL, agentToken)
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
tmpdir := tempDirUnixSocket(t)
|
||||
localSock := filepath.Join(tmpdir, "local.sock")
|
||||
l, err := net.Listen("unix", localSock)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
remoteSock := path.Join(tmpdir, "remote.sock")
|
||||
for i := 0; i < 2; i++ {
|
||||
t.Logf("connect %d of 2", i+1)
|
||||
inv, root := clitest.New(t,
|
||||
"ssh",
|
||||
workspace.Name,
|
||||
"--remote-forward",
|
||||
remoteSock+":"+localSock,
|
||||
)
|
||||
fsn := clitest.NewFakeSignalNotifier(t)
|
||||
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
|
||||
inv.Stdout = io.Discard
|
||||
inv.Stderr = io.Discard
|
||||
|
||||
clitest.SetupConfig(t, client, root)
|
||||
cmdDone := tGo(t, func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
// accept a single connection
|
||||
msgs := make(chan string, 1)
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
msg, err := io.ReadAll(conn)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
msgs <- string(msg)
|
||||
}()
|
||||
|
||||
// Unfortunately, there is a race in crypto/ssh where it sends the request to forward
|
||||
// unix sockets before it is prepared to receive the response, meaning that even after
|
||||
// the socket exists on the file system, the client might not be ready to accept the
|
||||
// channel.
|
||||
//
|
||||
// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33
|
||||
//
|
||||
// To work around this, we attempt to send messages in a loop until one succeeds
|
||||
success := make(chan struct{})
|
||||
go func() {
|
||||
var (
|
||||
conn net.Conn
|
||||
err error
|
||||
)
|
||||
for {
|
||||
time.Sleep(testutil.IntervalMedium)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timeout")
|
||||
return
|
||||
case <-success:
|
||||
return
|
||||
default:
|
||||
// Ok
|
||||
}
|
||||
conn, err = net.Dial("unix", remoteSock)
|
||||
if err != nil {
|
||||
t.Logf("dial error: %s", err)
|
||||
continue
|
||||
}
|
||||
_, err = conn.Write([]byte("test"))
|
||||
if err != nil {
|
||||
t.Logf("write error: %s", err)
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Logf("close error: %s", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
msg := testutil.RequireRecvCtx(ctx, t, msgs)
|
||||
require.Equal(t, "test", msg)
|
||||
close(success)
|
||||
fsn.Notify()
|
||||
<-cmdDone
|
||||
fsn.AssertStopped()
|
||||
|
||||
// wait for the remote socket to get cleaned up before retrying,
|
||||
// because cleaning up the socket happens asynchronously, and we
|
||||
// might connect to an old listener on the agent side.
|
||||
require.Eventually(t, func() bool {
|
||||
_, err = os.Stat(remoteSock)
|
||||
return xerrors.Is(err, os.ErrNotExist)
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("StdioExitOnStop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -61,7 +60,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
|
|||
ctx, cancel := context.WithCancel(inv.Context())
|
||||
defer cancel()
|
||||
|
||||
notifyCtx, notifyStop := signal.NotifyContext(ctx, agpl.InterruptSignals...)
|
||||
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, agpl.InterruptSignals...)
|
||||
defer notifyStop()
|
||||
|
||||
tags, err := agpl.ParseProvisionerTags(rawTags)
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
rpprof "runtime/pprof"
|
||||
"time"
|
||||
|
@ -142,7 +141,7 @@ func (*RootCmd) proxyServer() *clibase.Cmd {
|
|||
//
|
||||
// To get out of a graceful shutdown, the user can send
|
||||
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
|
||||
notifyCtx, notifyStop := signal.NotifyContext(ctx, cli.InterruptSignals...)
|
||||
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, cli.InterruptSignals...)
|
||||
defer notifyStop()
|
||||
|
||||
// Clean up idle connections at the end, e.g.
|
||||
|
|
|
@ -788,6 +788,7 @@ func (c *Conn) Close() error {
|
|||
}
|
||||
|
||||
_ = c.netStack.Close()
|
||||
c.logger.Debug(context.Background(), "closed netstack")
|
||||
c.dialCancel()
|
||||
_ = c.wireguardMonitor.Close()
|
||||
_ = c.dialer.Close()
|
||||
|
|
|
@ -11,3 +11,14 @@ func Context(t *testing.T, dur time.Duration) context.Context {
|
|||
t.Cleanup(cancel)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A) {
|
||||
t.Helper()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
return a
|
||||
case a = <-c:
|
||||
return a
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue