mirror of https://github.com/coder/coder.git
fix(agent/agentssh): allow remote forwarding a socket multiple times (#11631)
* fix(agent/agentssh): allow remote forwarding a socket multiple times Fixes #11198 Fixes https://github.com/coder/customers/issues/407
This commit is contained in:
parent
08b4eb3124
commit
385d58caf6
|
@ -99,7 +99,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
|||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := &forwardedUnixHandler{log: logger}
|
||||
unixForwardHandler := newForwardedUnixHandler(logger)
|
||||
|
||||
metrics := newSSHServerMetrics(prometheusRegistry)
|
||||
s := &Server{
|
||||
|
|
|
@ -2,11 +2,14 @@ package agentssh
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
@ -33,22 +36,29 @@ type forwardedStreamLocalPayload struct {
|
|||
type forwardedUnixHandler struct {
|
||||
sync.Mutex
|
||||
log slog.Logger
|
||||
forwards map[string]net.Listener
|
||||
forwards map[forwardKey]net.Listener
|
||||
}
|
||||
|
||||
type forwardKey struct {
|
||||
sessionID string
|
||||
addr string
|
||||
}
|
||||
|
||||
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
|
||||
return &forwardedUnixHandler{
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
|
||||
h.log.Debug(ctx, "handling SSH unix forward")
|
||||
h.Lock()
|
||||
if h.forwards == nil {
|
||||
h.forwards = make(map[string]net.Listener)
|
||||
}
|
||||
h.Unlock()
|
||||
conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
||||
if !ok {
|
||||
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
|
||||
return false, nil
|
||||
}
|
||||
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
|
||||
log := h.log.With(slog.F("session_id", ctx.SessionID()), slog.F("remote_addr", conn.RemoteAddr()))
|
||||
|
||||
switch req.Type {
|
||||
case "streamlocal-forward@openssh.com":
|
||||
|
@ -62,14 +72,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
addr := reqPayload.SocketPath
|
||||
log = log.With(slog.F("socket_path", addr))
|
||||
log.Debug(ctx, "request begin SSH unix forward")
|
||||
|
||||
key := forwardKey{
|
||||
sessionID: ctx.SessionID(),
|
||||
addr: addr,
|
||||
}
|
||||
|
||||
h.Lock()
|
||||
_, ok := h.forwards[addr]
|
||||
_, ok := h.forwards[key]
|
||||
h.Unlock()
|
||||
if ok {
|
||||
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
|
||||
slog.F("socket_path", addr),
|
||||
)
|
||||
return false, nil
|
||||
// In cases where `ExitOnForwardFailure=yes` is set, returning false
|
||||
// here will cause the connection to be closed. To avoid this, and
|
||||
// to match OpenSSH behavior, we silently ignore the second forward
|
||||
// request.
|
||||
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded on this session, ignoring")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Create socket parent dir if not exists.
|
||||
|
@ -83,12 +101,20 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
return false, nil
|
||||
}
|
||||
|
||||
ln, err := net.Listen("unix", addr)
|
||||
// Remove existing socket if it exists. We do not use os.Remove() here
|
||||
// so that directories are kept. Note that it's possible that we will
|
||||
// overwrite a regular file here. Both of these behaviors match OpenSSH,
|
||||
// however, which is why we unlink.
|
||||
err = unlink(addr)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
log.Warn(ctx, "remove existing socket for SSH unix forward request", slog.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
lc := &net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "unix", addr)
|
||||
if err != nil {
|
||||
log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
|
||||
slog.F("socket_path", addr),
|
||||
slog.Error(err),
|
||||
)
|
||||
log.Warn(ctx, "listen on Unix socket for SSH unix forward request", slog.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
log.Debug(ctx, "SSH unix forward listening on socket")
|
||||
|
@ -99,7 +125,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
//
|
||||
// This is also what the upstream TCP version of this code does.
|
||||
h.Lock()
|
||||
h.forwards[addr] = ln
|
||||
h.forwards[key] = ln
|
||||
h.Unlock()
|
||||
log.Debug(ctx, "SSH unix forward added to cache")
|
||||
|
||||
|
@ -115,9 +141,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
if !xerrors.Is(err, net.ErrClosed) {
|
||||
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
|
||||
slog.Error(err),
|
||||
)
|
||||
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", slog.Error(err))
|
||||
}
|
||||
// closed below
|
||||
log.Debug(ctx, "SSH unix forward listener closed")
|
||||
|
@ -131,10 +155,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
go func() {
|
||||
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
|
||||
if err != nil {
|
||||
h.log.Warn(ctx, "open SSH unix forward channel to client",
|
||||
slog.F("socket_path", addr),
|
||||
slog.Error(err),
|
||||
)
|
||||
h.log.Warn(ctx, "open SSH unix forward channel to client", slog.Error(err))
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
|
@ -144,12 +165,11 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
}
|
||||
|
||||
h.Lock()
|
||||
ln2, ok := h.forwards[addr]
|
||||
if ok && ln2 == ln {
|
||||
delete(h.forwards, addr)
|
||||
if ln2, ok := h.forwards[key]; ok && ln2 == ln {
|
||||
delete(h.forwards, key)
|
||||
}
|
||||
h.Unlock()
|
||||
log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr))
|
||||
log.Debug(ctx, "SSH unix forward listener removed from cache")
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
|
@ -162,13 +182,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
|||
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath))
|
||||
h.Lock()
|
||||
ln, ok := h.forwards[reqPayload.SocketPath]
|
||||
h.Unlock()
|
||||
if ok {
|
||||
_ = ln.Close()
|
||||
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("socket_path", reqPayload.SocketPath))
|
||||
|
||||
key := forwardKey{
|
||||
sessionID: ctx.SessionID(),
|
||||
addr: reqPayload.SocketPath,
|
||||
}
|
||||
|
||||
h.Lock()
|
||||
ln, ok := h.forwards[key]
|
||||
delete(h.forwards, key)
|
||||
h.Unlock()
|
||||
if !ok {
|
||||
log.Warn(ctx, "SSH unix forward not found in cache")
|
||||
return true, nil
|
||||
}
|
||||
_ = ln.Close()
|
||||
return true, nil
|
||||
|
||||
default:
|
||||
|
@ -209,3 +238,15 @@ func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh.
|
|||
|
||||
Bicopy(ctx, ch, dconn)
|
||||
}
|
||||
|
||||
// unlink removes files and unlike os.Remove, directories are kept.
|
||||
func unlink(path string) error {
|
||||
// Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
|
||||
// for more details.
|
||||
for {
|
||||
err := syscall.Unlink(path)
|
||||
if !errors.Is(err, syscall.EINTR) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
118
cli/ssh_test.go
118
cli/ssh_test.go
|
@ -26,12 +26,14 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
gosshagent "golang.org/x/crypto/ssh/agent"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
|
@ -738,8 +740,8 @@ func TestSSH(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
tmpdir := tempDirUnixSocket(t)
|
||||
agentSock := filepath.Join(tmpdir, "agent.sock")
|
||||
l, err := net.Listen("unix", agentSock)
|
||||
localSock := filepath.Join(tmpdir, "local.sock")
|
||||
l, err := net.Listen("unix", localSock)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
remoteSock := filepath.Join(tmpdir, "remote.sock")
|
||||
|
@ -748,7 +750,7 @@ func TestSSH(t *testing.T) {
|
|||
"ssh",
|
||||
workspace.Name,
|
||||
"--remote-forward",
|
||||
fmt.Sprintf("%s:%s", remoteSock, agentSock),
|
||||
fmt.Sprintf("%s:%s", remoteSock, localSock),
|
||||
)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
|
@ -771,6 +773,116 @@ func TestSSH(t *testing.T) {
|
|||
<-cmdDone
|
||||
})
|
||||
|
||||
// Test that we can forward a local unix socket to a remote unix socket and
|
||||
// that new SSH sessions take over the socket without closing active socket
|
||||
// connections.
|
||||
t.Run("RemoteForwardUnixSocketMultipleSessionsOverwrite", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Test not supported on windows")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t)
|
||||
|
||||
_ = agenttest.New(t, client.URL, agentToken)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
|
||||
// Wait super super long so this doesn't flake on -race test.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong*2)
|
||||
defer cancel()
|
||||
|
||||
tmpdir := tempDirUnixSocket(t)
|
||||
|
||||
localSock := filepath.Join(tmpdir, "local.sock")
|
||||
l, err := net.Listen("unix", localSock)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
testutil.Go(t, func() {
|
||||
for {
|
||||
fd, err := l.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
assert.NoError(t, err, "listener accept failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
testutil.Go(t, func() {
|
||||
defer fd.Close()
|
||||
agentssh.Bicopy(ctx, fd, fd)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
remoteSock := filepath.Join(tmpdir, "remote.sock")
|
||||
|
||||
var done []func() error
|
||||
for i := 0; i < 2; i++ {
|
||||
id := fmt.Sprintf("ssh-%d", i)
|
||||
inv, root := clitest.New(t,
|
||||
"ssh",
|
||||
workspace.Name,
|
||||
"--remote-forward",
|
||||
fmt.Sprintf("%s:%s", remoteSock, localSock),
|
||||
)
|
||||
inv.Logger = inv.Logger.Named(id)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
inv.Stderr = pty.Output()
|
||||
cmdDone := tGo(t, func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
assert.NoError(t, err, "ssh command failed: %s", id)
|
||||
})
|
||||
|
||||
// Since something was output, it should be safe to write input.
|
||||
// This could show a prompt or "running startup scripts", so it's
|
||||
// not indicative of the SSH connection being ready.
|
||||
_ = pty.Peek(ctx, 1)
|
||||
|
||||
// Ensure the SSH connection is ready by testing the shell
|
||||
// input/output.
|
||||
pty.WriteLine("echo ping' 'pong")
|
||||
pty.ExpectMatchContext(ctx, "ping pong")
|
||||
|
||||
d := &net.Dialer{}
|
||||
fd, err := d.DialContext(ctx, "unix", remoteSock)
|
||||
require.NoError(t, err, id)
|
||||
|
||||
// Ping / pong to ensure the socket is working.
|
||||
_, err = fd.Write([]byte("hello world"))
|
||||
require.NoError(t, err, id)
|
||||
|
||||
buf := make([]byte, 11)
|
||||
_, err = fd.Read(buf)
|
||||
require.NoError(t, err, id)
|
||||
require.Equal(t, "hello world", string(buf), id)
|
||||
|
||||
done = append(done, func() error {
|
||||
// Redo ping / pong to ensure that the socket
|
||||
// connections still work.
|
||||
_, err := fd.Write([]byte("hello world"))
|
||||
assert.NoError(t, err, id)
|
||||
|
||||
buf := make([]byte, 11)
|
||||
_, err = fd.Read(buf)
|
||||
assert.NoError(t, err, id)
|
||||
assert.Equal(t, "hello world", string(buf), id)
|
||||
|
||||
pty.WriteLine("exit")
|
||||
<-cmdDone
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var eg errgroup.Group
|
||||
for _, d := range done {
|
||||
eg.Go(d)
|
||||
}
|
||||
err = eg.Wait()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("FileLogging", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue