feat(cli/ssh): allow multiple remote forwards and allow missing local file (#11648)

This commit is contained in:
Mathias Fredriksson 2024-01-19 15:21:10 +02:00 committed by GitHub
parent 73e6bbff7e
commit 200a87e7d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 127 additions and 32 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"os"
"regexp" "regexp"
"strconv" "strconv"
@ -67,19 +66,13 @@ func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
return localAddr, remoteAddr, nil return localAddr, remoteAddr, nil
} }
// parseRemoteForwardUnixSocket parses a remote forward flag. Note that
// we don't verify that the local socket path exists because the user
// may create it later. This behavior matches OpenSSH.
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) { func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
remoteSocket := matches[1] remoteSocket := matches[1]
localSocket := matches[2] localSocket := matches[2]
fileInfo, err := os.Stat(localSocket)
if err != nil {
return nil, nil, err
}
if fileInfo.Mode()&os.ModeSocket == 0 {
return nil, nil, xerrors.New("File is not a Unix domain socket file")
}
remoteAddr := &net.UnixAddr{ remoteAddr := &net.UnixAddr{
Name: remoteSocket, Name: remoteSocket,
Net: "unix", Net: "unix",

View File

@ -53,7 +53,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
waitEnum string waitEnum string
noWait bool noWait bool
logDirPath string logDirPath string
remoteForward string remoteForwards []string
disableAutostart bool disableAutostart bool
) )
client := new(codersdk.Client) client := new(codersdk.Client)
@ -135,13 +135,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
stack := newCloserStack(ctx, logger) stack := newCloserStack(ctx, logger)
defer stack.close(nil) defer stack.close(nil)
if remoteForward != "" { if len(remoteForwards) > 0 {
isValid := validateRemoteForward(remoteForward) for _, remoteForward := range remoteForwards {
if !isValid { isValid := validateRemoteForward(remoteForward)
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`) if !isValid {
} return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
if isValid && stdio { }
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`) if isValid && stdio {
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
}
} }
} }
@ -311,18 +313,20 @@ func (r *RootCmd) ssh() *clibase.Cmd {
} }
} }
if remoteForward != "" { if len(remoteForwards) > 0 {
localAddr, remoteAddr, err := parseRemoteForward(remoteForward) for _, remoteForward := range remoteForwards {
if err != nil { localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
return err if err != nil {
} return err
}
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr) closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
if err != nil { if err != nil {
return xerrors.Errorf("ssh remote forward: %w", err) return xerrors.Errorf("ssh remote forward: %w", err)
} }
if err = stack.push("sshRemoteForward", closer); err != nil { if err = stack.push("sshRemoteForward", closer); err != nil {
return err return err
}
} }
} }
@ -460,7 +464,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
Description: "Enable remote port forwarding (remote_port:local_address:local_port).", Description: "Enable remote port forwarding (remote_port:local_address:local_port).",
Env: "CODER_SSH_REMOTE_FORWARD", Env: "CODER_SSH_REMOTE_FORWARD",
FlagShorthand: "R", FlagShorthand: "R",
Value: clibase.StringOf(&remoteForward), Value: clibase.StringArrayOf(&remoteForwards),
}, },
sshDisableAutostartOption(clibase.BoolOf(&disableAutostart)), sshDisableAutostartOption(clibase.BoolOf(&disableAutostart)),
} }

View File

@ -883,6 +883,104 @@ func TestSSH(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}) })
// Test that we can remote forward multiple sockets, whether or not the
// local sockets exists at the time of establishing xthe SSH connection.
t.Run("RemoteForwardMultipleUnixSockets", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Test not supported on windows")
}
t.Parallel()
client, workspace, agentToken := setupWorkspaceForAgent(t)
_ = agenttest.New(t, client.URL, agentToken)
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
// Wait super long so this doesn't flake on -race test.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
tmpdir := tempDirUnixSocket(t)
type testSocket struct {
local string
remote string
}
args := []string{"ssh", workspace.Name}
var sockets []testSocket
for i := 0; i < 2; i++ {
localSock := filepath.Join(tmpdir, fmt.Sprintf("local-%d.sock", i))
remoteSock := filepath.Join(tmpdir, fmt.Sprintf("remote-%d.sock", i))
sockets = append(sockets, testSocket{
local: localSock,
remote: remoteSock,
})
args = append(args, "--remote-forward", fmt.Sprintf("%s:%s", remoteSock, localSock))
}
inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t).Attach(inv)
inv.Stderr = pty.Output()
w := clitest.StartWithWaiter(t, inv.WithContext(ctx))
defer w.Wait() // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
// Since something was output, it should be safe to write input.
// This could show a prompt or "running startup scripts", so it's
// not indicative of the SSH connection being ready.
_ = pty.Peek(ctx, 1)
// Ensure the SSH connection is ready by testing the shell
// input/output.
pty.WriteLine("echo ping' 'pong")
pty.ExpectMatchContext(ctx, "ping pong")
for i, sock := range sockets {
i := i
// Start the listener on the "local machine".
l, err := net.Listen("unix", sock.local)
require.NoError(t, err)
defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
testutil.Go(t, func() {
for {
fd, err := l.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
assert.NoError(t, err, "listener accept failed", i)
}
return
}
testutil.Go(t, func() {
defer fd.Close()
agentssh.Bicopy(ctx, fd, fd)
})
}
})
// Dial the forwarded socket on the "remote machine".
d := &net.Dialer{}
fd, err := d.DialContext(ctx, "unix", sock.remote)
require.NoError(t, err, i)
defer fd.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
// Ping / pong to ensure the socket is working.
_, err = fd.Write([]byte("hello world"))
require.NoError(t, err, i)
buf := make([]byte, 11)
_, err = fd.Read(buf)
require.NoError(t, err, i)
require.Equal(t, "hello world", string(buf), i)
}
// And we're done.
pty.WriteLine("exit")
})
t.Run("FileLogging", func(t *testing.T) { t.Run("FileLogging", func(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -33,7 +33,7 @@ OPTIONS:
behavior as non-blocking. behavior as non-blocking.
DEPRECATED: Use --wait instead. DEPRECATED: Use --wait instead.
-R, --remote-forward string, $CODER_SSH_REMOTE_FORWARD -R, --remote-forward string-array, $CODER_SSH_REMOTE_FORWARD
Enable remote port forwarding (remote_port:local_address:local_port). Enable remote port forwarding (remote_port:local_address:local_port).
--stdio bool, $CODER_SSH_STDIO --stdio bool, $CODER_SSH_STDIO

2
docs/cli/ssh.md generated
View File

@ -71,7 +71,7 @@ Enter workspace immediately after the agent has connected. This is the default i
| | | | | |
| ----------- | -------------------------------------- | | ----------- | -------------------------------------- |
| Type | <code>string</code> | | Type | <code>string-array</code> |
| Environment | <code>$CODER_SSH_REMOTE_FORWARD</code> | | Environment | <code>$CODER_SSH_REMOTE_FORWARD</code> |
Enable remote port forwarding (remote_port:local_address:local_port). Enable remote port forwarding (remote_port:local_address:local_port).