fix: handle SIGHUP from OpenSSH (#10638)

Fixes an issue where remote forwards are not correctly torn down when using OpenSSH with `coder ssh --stdio`.  OpenSSH sends a disconnect signal, but then also sends SIGHUP to `coder`.  Previously, we just exited when we got SIGHUP, and this raced against properly disconnecting.

Fixes https://github.com/coder/customers/issues/327
This commit is contained in:
Spike Curtis 2023-11-13 15:14:42 +04:00 committed by GitHub
parent be0436afbe
commit f400d8a0c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 248 additions and 29 deletions

View File

@ -37,6 +37,7 @@ type forwardedUnixHandler struct {
}
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
h.log.Debug(ctx, "handling SSH unix forward")
h.Lock()
if h.forwards == nil {
h.forwards = make(map[string]net.Listener)
@ -47,22 +48,25 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
return false, nil
}
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
switch req.Type {
case "streamlocal-forward@openssh.com":
var reqPayload streamLocalForwardPayload
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request payload from client", slog.Error(err))
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request (SSH unix forward) payload from client", slog.Error(err))
return false, nil
}
addr := reqPayload.SocketPath
log = log.With(slog.F("socket_path", addr))
log.Debug(ctx, "request begin SSH unix forward")
h.Lock()
_, ok := h.forwards[addr]
h.Unlock()
if ok {
h.log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
slog.F("socket_path", addr),
)
return false, nil
@ -72,9 +76,8 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
parentDir := filepath.Dir(addr)
err = os.MkdirAll(parentDir, 0o700)
if err != nil {
h.log.Warn(ctx, "create parent dir for SSH unix forward request",
log.Warn(ctx, "create parent dir for SSH unix forward request",
slog.F("parent_dir", parentDir),
slog.F("socket_path", addr),
slog.Error(err),
)
return false, nil
@ -82,12 +85,13 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
ln, err := net.Listen("unix", addr)
if err != nil {
h.log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
slog.F("socket_path", addr),
slog.Error(err),
)
return false, nil
}
log.Debug(ctx, "SSH unix forward listening on socket")
// The listener needs to successfully start before it can be added to
// the map, so we don't have to worry about checking for an existing
@ -97,6 +101,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
h.Lock()
h.forwards[addr] = ln
h.Unlock()
log.Debug(ctx, "SSH unix forward added to cache")
ctx, cancel := context.WithCancel(ctx)
go func() {
@ -110,14 +115,15 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
c, err := ln.Accept()
if err != nil {
if !xerrors.Is(err, net.ErrClosed) {
h.log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
slog.F("socket_path", addr),
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
slog.Error(err),
)
}
// closed below
log.Debug(ctx, "SSH unix forward listener closed")
break
}
log.Debug(ctx, "accepted SSH unix forward connection")
payload := gossh.Marshal(&forwardedStreamLocalPayload{
SocketPath: addr,
})
@ -125,7 +131,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
go func() {
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
if err != nil {
h.log.Warn(ctx, "open SSH channel to forward Unix connection to client",
h.log.Warn(ctx, "open SSH unix forward channel to client",
slog.F("socket_path", addr),
slog.Error(err),
)
@ -143,6 +149,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
delete(h.forwards, addr)
}
h.Unlock()
log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr))
_ = ln.Close()
}()
@ -152,9 +159,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
var reqPayload streamLocalForwardPayload
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err))
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err))
return false, nil
}
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath))
h.Lock()
ln, ok := h.forwards[reqPayload.SocketPath]
h.Unlock()

View File

@ -8,7 +8,6 @@ import (
"net/http/pprof"
"net/url"
"os"
"os/signal"
"path/filepath"
"runtime"
"strconv"
@ -144,7 +143,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
// Note that we don't want to handle these signals in the
// process that runs as PID 1, that's why we do this after
// the reaper forked.
ctx, stopNotify := signal.NotifyContext(ctx, InterruptSignals...)
ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stopNotify()
// DumpHandler does signal handling, so we call it after the

View File

@ -7,7 +7,9 @@ import (
"fmt"
"io"
"os"
"os/signal"
"strings"
"testing"
"unicode"
"github.com/spf13/pflag"
@ -183,6 +185,9 @@ type Invocation struct {
Stdout io.Writer
Stderr io.Writer
Stdin io.Reader
// testing
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
}
// WithOS returns the invocation as a main package, filling in the invocation's unset
@ -197,6 +202,26 @@ func (inv *Invocation) WithOS() *Invocation {
})
}
// WithTestSignalNotifyContext installs f as the function that
// SignalNotifyContext delegates to, in place of signal.NotifyContext, and
// returns the invocation produced by inv.with.
//
// The testing.TB argument is never read; requiring it means a caller has to
// hold a test handle, which keeps this override out of non-test code.
func (inv *Invocation) WithTestSignalNotifyContext(
	_ testing.TB, // ensure we only call this from tests
	f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc),
) *Invocation {
	install := func(target *Invocation) {
		target.signalNotifyContext = f
	}
	return inv.with(install)
}
// SignalNotifyContext behaves like signal.NotifyContext unless a replacement
// was installed with WithTestSignalNotifyContext, in which case the
// replacement is called with the same arguments and its results are returned.
func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
	// A test-provided override takes precedence over the real OS hook.
	if override := inv.signalNotifyContext; override != nil {
		return override(parent, signals...)
	}
	return signal.NotifyContext(parent, signals...)
}
func (inv *Invocation) Context() context.Context {
if inv.ctx == nil {
return context.Background()

59
cli/clitest/signal.go Normal file
View File

@ -0,0 +1,59 @@
package clitest
import (
"context"
"os"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
// FakeSignalNotifier is a test double for OS signal delivery. Its
// NotifyContext method has the same signature as signal.NotifyContext, so it
// can be handed to an invocation in place of the real thing; the test then
// calls Notify to simulate a signal arriving.
type FakeSignalNotifier struct {
// Mutex guards every field below; each method locks it before touching state.
sync.Mutex
// t receives errors and assertion failures raised by the fake.
t *testing.T
// ctx and cancel are the context pair created by the most recent call to
// NotifyContext. cancel is nil until NotifyContext has been called.
ctx context.Context
cancel context.CancelFunc
// signals records the signals passed to the most recent NotifyContext call.
signals []os.Signal
// stopped is set once Stop has been called; AssertStopped checks it.
stopped bool
}
// NewFakeSignalNotifier returns a FakeSignalNotifier that reports errors and
// assertion failures to t. The fake is inert until NotifyContext is called.
func NewFakeSignalNotifier(t *testing.T) *FakeSignalNotifier {
	return &FakeSignalNotifier{t: t}
}
// Stop is the stop function handed back by NotifyContext. It records that a
// stop was requested and cancels the context created by NotifyContext. If
// NotifyContext has not been called yet, it reports a test error instead.
func (f *FakeSignalNotifier) Stop() {
	f.Lock()
	defer f.Unlock()

	// Record the stop even when it arrives too early, so AssertStopped
	// reflects that Stop was invoked.
	f.stopped = true
	if cancel := f.cancel; cancel != nil {
		cancel()
		return
	}
	f.t.Error("stopped before started")
}
// NotifyContext mirrors the signature of signal.NotifyContext. It remembers
// the requested signals, derives a cancellable context from parent, and
// returns that context together with f.Stop as the stop function. No OS
// signal handler is registered; the returned context is cancelled by Notify,
// by Stop, or when parent is cancelled.
func (f *FakeSignalNotifier) NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
	f.Lock()
	defer f.Unlock()

	child, cancel := context.WithCancel(parent)
	f.signals = signals
	f.ctx = child
	f.cancel = cancel
	return child, f.Stop
}
// Notify simulates delivery of a signal by cancelling the context that
// NotifyContext returned. If NotifyContext has not been called yet there is
// nothing to cancel, so a test error is reported instead.
func (f *FakeSignalNotifier) Notify() {
	// Attribute any reported error to the calling test line rather than to
	// this helper.
	f.t.Helper()

	f.Lock()
	defer f.Unlock()
	if f.cancel == nil {
		f.t.Error("notified before started")
		return
	}
	f.cancel()
}
// AssertStopped fails the test (without aborting it) unless Stop has been
// called on this notifier.
func (f *FakeSignalNotifier) AssertStopped() {
	// Attribute a failed assertion to the calling test line rather than to
	// this helper.
	f.t.Helper()

	f.Lock()
	defer f.Unlock()
	assert.True(f.t, f.stopped, "signal notifier was never stopped")
}

View File

@ -2,7 +2,6 @@ package cli
import (
"encoding/json"
"os/signal"
"golang.org/x/xerrors"
@ -63,7 +62,7 @@ fi
Handler: func(inv *clibase.Invocation) error {
ctx := inv.Context()
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stop()
client, err := r.createAgentClient()

View File

@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"os/signal"
"time"
"golang.org/x/xerrors"
@ -26,7 +25,7 @@ func (r *RootCmd) gitAskpass() *clibase.Cmd {
Handler: func(inv *clibase.Invocation) error {
ctx := inv.Context()
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stop()
user, host, err := gitauth.ParseAskpass(inv.Args[0])

View File

@ -8,7 +8,6 @@ import (
"io"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
@ -30,7 +29,7 @@ func (r *RootCmd) gitssh() *clibase.Cmd {
// Catch interrupt signals to ensure the temporary private
// key file is cleaned up on most cases.
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
ctx, stop := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stop()
// Early check so errors are reported immediately.

View File

@ -22,7 +22,6 @@ import (
"net/http/pprof"
"net/url"
"os"
"os/signal"
"os/user"
"path/filepath"
"regexp"
@ -333,7 +332,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
//
// To get out of a graceful shutdown, the user can send
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer notifyStop()
cacheDir := vals.CacheDir.String()
@ -1098,7 +1097,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
logger = logger.Leveled(slog.LevelDebug)
}
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer cancel()
url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)

View File

@ -4,7 +4,6 @@ package cli
import (
"fmt"
"os/signal"
"sort"
"github.com/google/uuid"
@ -48,7 +47,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd {
logger = logger.Leveled(slog.LevelDebug)
}
ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer cancel()
if newUserDBURL == "" {

View File

@ -62,7 +62,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) (retErr error) {
ctx, cancel := context.WithCancel(inv.Context())
// Before dialing the SSH server over TCP, capture Interrupt signals
// so that if we are interrupted, we have a chance to tear down the
// TCP session cleanly before exiting. If we don't, then the TCP
// session can persist for up to 72 hours, since we set a long
// timeout on the Agent side of the connection. In particular,
// OpenSSH sends SIGHUP to terminate a proxy command.
ctx, stop := inv.SignalNotifyContext(inv.Context(), InterruptSignals...)
defer stop()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger := slog.Make() // empty logger
@ -227,8 +235,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
go func() {
defer wg.Done()
// Ensure stdout copy closes in case stdin is closed
// unexpectedly. Typically we wouldn't worry about
// this since OpenSSH should kill the proxy command.
// unexpectedly.
defer rawSSH.Close()
_, err := io.Copy(rawSSH, inv.Stdin)

View File

@ -14,12 +14,15 @@ import (
"net/http/httptest"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"golang.org/x/xerrors"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -245,6 +248,119 @@ func TestSSH(t *testing.T) {
<-cmdDone
})
// Test that we handle OS signals properly while remote forwarding, and don't just leave the TCP
// socket hanging.
t.Run("RemoteForward_Unix_Signal", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("No unix sockets on windows")
}
t.Parallel()
ctx := testutil.Context(t, testutil.WaitSuperLong)
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
_, _ = tGoContext(t, func(ctx context.Context) {
// Run this async so the SSH command has to wait for
// the build and agent to connect!
_ = agenttest.New(t, client.URL, agentToken)
<-ctx.Done()
})
tmpdir := tempDirUnixSocket(t)
localSock := filepath.Join(tmpdir, "local.sock")
l, err := net.Listen("unix", localSock)
require.NoError(t, err)
defer l.Close()
remoteSock := path.Join(tmpdir, "remote.sock")
for i := 0; i < 2; i++ {
t.Logf("connect %d of 2", i+1)
inv, root := clitest.New(t,
"ssh",
workspace.Name,
"--remote-forward",
remoteSock+":"+localSock,
)
fsn := clitest.NewFakeSignalNotifier(t)
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
inv.Stdout = io.Discard
inv.Stderr = io.Discard
clitest.SetupConfig(t, client, root)
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.Error(t, err)
})
// accept a single connection
msgs := make(chan string, 1)
go func() {
conn, err := l.Accept()
if !assert.NoError(t, err) {
return
}
msg, err := io.ReadAll(conn)
if !assert.NoError(t, err) {
return
}
msgs <- string(msg)
}()
// Unfortunately, there is a race in crypto/ssh where it sends the request to forward
// unix sockets before it is prepared to receive the response, meaning that even after
// the socket exists on the file system, the client might not be ready to accept the
// channel.
//
// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33
//
// To work around this, we attempt to send messages in a loop until one succeeds
success := make(chan struct{})
go func() {
var (
conn net.Conn
err error
)
for {
time.Sleep(testutil.IntervalMedium)
select {
case <-ctx.Done():
t.Error("timeout")
return
case <-success:
return
default:
// Ok
}
conn, err = net.Dial("unix", remoteSock)
if err != nil {
t.Logf("dial error: %s", err)
continue
}
_, err = conn.Write([]byte("test"))
if err != nil {
t.Logf("write error: %s", err)
}
err = conn.Close()
if err != nil {
t.Logf("close error: %s", err)
}
}
}()
msg := testutil.RequireRecvCtx(ctx, t, msgs)
require.Equal(t, "test", msg)
close(success)
fsn.Notify()
<-cmdDone
fsn.AssertStopped()
// wait for the remote socket to get cleaned up before retrying,
// because cleaning up the socket happens asynchronously, and we
// might connect to an old listener on the agent side.
require.Eventually(t, func() bool {
_, err = os.Stat(remoteSock)
return xerrors.Is(err, os.ErrNotExist)
}, testutil.WaitShort, testutil.IntervalFast)
}
})
t.Run("StdioExitOnStop", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {

View File

@ -6,7 +6,6 @@ import (
"context"
"fmt"
"os"
"os/signal"
"time"
"github.com/google/uuid"
@ -61,7 +60,7 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel()
notifyCtx, notifyStop := signal.NotifyContext(ctx, agpl.InterruptSignals...)
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, agpl.InterruptSignals...)
defer notifyStop()
tags, err := agpl.ParseProvisionerTags(rawTags)

View File

@ -10,7 +10,6 @@ import (
"net"
"net/http"
"net/http/pprof"
"os/signal"
"regexp"
rpprof "runtime/pprof"
"time"
@ -142,7 +141,7 @@ func (*RootCmd) proxyServer() *clibase.Cmd {
//
// To get out of a graceful shutdown, the user can send
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
notifyCtx, notifyStop := signal.NotifyContext(ctx, cli.InterruptSignals...)
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, cli.InterruptSignals...)
defer notifyStop()
// Clean up idle connections at the end, e.g.

View File

@ -788,6 +788,7 @@ func (c *Conn) Close() error {
}
_ = c.netStack.Close()
c.logger.Debug(context.Background(), "closed netstack")
c.dialCancel()
_ = c.wireguardMonitor.Close()
_ = c.dialer.Close()

View File

@ -11,3 +11,14 @@ func Context(t *testing.T, dur time.Duration) context.Context {
t.Cleanup(cancel)
return ctx
}
// RequireRecvCtx waits for a value on c and returns it. If ctx is done before
// a value arrives, it calls t.Fatal("timeout"); should t.Fatal return (as a
// stub testing.TB might), the zero value of A is returned.
func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A) {
	t.Helper()
	select {
	case got := <-c:
		a = got
	case <-ctx.Done():
		t.Fatal("timeout")
	}
	return a
}