From 200a87e7d494953b6aaafd8f944e60160caa6ffe Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 19 Jan 2024 15:21:10 +0200 Subject: [PATCH] feat(cli/ssh): allow multiple remote forwards and allow missing local file (#11648) --- cli/remoteforward.go | 13 +--- cli/ssh.go | 44 +++++++------ cli/ssh_test.go | 98 ++++++++++++++++++++++++++++ cli/testdata/coder_ssh_--help.golden | 2 +- docs/cli/ssh.md | 2 +- 5 files changed, 127 insertions(+), 32 deletions(-) diff --git a/cli/remoteforward.go b/cli/remoteforward.go index 2c4207583b..bffc50694c 100644 --- a/cli/remoteforward.go +++ b/cli/remoteforward.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net" - "os" "regexp" "strconv" @@ -67,19 +66,13 @@ func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) { return localAddr, remoteAddr, nil } +// parseRemoteForwardUnixSocket parses a remote forward flag. Note that +// we don't verify that the local socket path exists because the user +// may create it later. This behavior matches OpenSSH. func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) { remoteSocket := matches[1] localSocket := matches[2] - fileInfo, err := os.Stat(localSocket) - if err != nil { - return nil, nil, err - } - - if fileInfo.Mode()&os.ModeSocket == 0 { - return nil, nil, xerrors.New("File is not a Unix domain socket file") - } - remoteAddr := &net.UnixAddr{ Name: remoteSocket, Net: "unix", diff --git a/cli/ssh.go b/cli/ssh.go index b3fc79d51d..b11f48b9b1 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -53,7 +53,7 @@ func (r *RootCmd) ssh() *clibase.Cmd { waitEnum string noWait bool logDirPath string - remoteForward string + remoteForwards []string disableAutostart bool ) client := new(codersdk.Client) @@ -135,13 +135,15 @@ func (r *RootCmd) ssh() *clibase.Cmd { stack := newCloserStack(ctx, logger) defer stack.close(nil) - if remoteForward != "" { - isValid := validateRemoteForward(remoteForward) - if !isValid { - return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`) - } - if isValid && stdio { - return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`) + if len(remoteForwards) > 0 { + for _, remoteForward := range remoteForwards { + isValid := validateRemoteForward(remoteForward) + if !isValid { + return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`) + } + if isValid && stdio { + return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`) + } } } @@ -311,18 +313,20 @@ func (r *RootCmd) ssh() *clibase.Cmd { } } - if remoteForward != "" { - localAddr, remoteAddr, err := parseRemoteForward(remoteForward) - if err != nil { - return err - } + if len(remoteForwards) > 0 { + for _, remoteForward := range remoteForwards { + localAddr, remoteAddr, err := parseRemoteForward(remoteForward) + if err != nil { + return err + } - closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr) - if err != nil { - return xerrors.Errorf("ssh remote forward: %w", err) - } - if err = stack.push("sshRemoteForward", closer); err != nil { - return err + closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr) + if err != nil { + return xerrors.Errorf("ssh remote forward: %w", err) + } + if err = stack.push("sshRemoteForward", closer); err != nil { + return err + } } } @@ -460,7 +464,7 @@ func (r *RootCmd) ssh() *clibase.Cmd { Description: "Enable remote port forwarding (remote_port:local_address:local_port).", Env: "CODER_SSH_REMOTE_FORWARD", FlagShorthand: "R", - Value: clibase.StringOf(&remoteForward), + Value: clibase.StringArrayOf(&remoteForwards), }, sshDisableAutostartOption(clibase.BoolOf(&disableAutostart)), } diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 684e8700c1..fdde064ce9 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -883,6 +883,104 @@ func TestSSH(t *testing.T) { require.NoError(t, err) }) + // Test that we can remote forward multiple sockets, whether or not the + // local sockets exists at the time of establishing xthe SSH connection. + t.Run("RemoteForwardMultipleUnixSockets", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Test not supported on windows") + } + + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + // Wait super long so this doesn't flake on -race test. + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + + tmpdir := tempDirUnixSocket(t) + + type testSocket struct { + local string + remote string + } + + args := []string{"ssh", workspace.Name} + var sockets []testSocket + for i := 0; i < 2; i++ { + localSock := filepath.Join(tmpdir, fmt.Sprintf("local-%d.sock", i)) + remoteSock := filepath.Join(tmpdir, fmt.Sprintf("remote-%d.sock", i)) + sockets = append(sockets, testSocket{ + local: localSock, + remote: remoteSock, + }) + args = append(args, "--remote-forward", fmt.Sprintf("%s:%s", remoteSock, localSock)) + } + + inv, root := clitest.New(t, args...) + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t).Attach(inv) + inv.Stderr = pty.Output() + + w := clitest.StartWithWaiter(t, inv.WithContext(ctx)) + defer w.Wait() // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly). + + // Since something was output, it should be safe to write input. + // This could show a prompt or "running startup scripts", so it's + // not indicative of the SSH connection being ready. + _ = pty.Peek(ctx, 1) + + // Ensure the SSH connection is ready by testing the shell + // input/output. + pty.WriteLine("echo ping' 'pong") + pty.ExpectMatchContext(ctx, "ping pong") + + for i, sock := range sockets { + i := i + // Start the listener on the "local machine". + l, err := net.Listen("unix", sock.local) + require.NoError(t, err) + defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice. + testutil.Go(t, func() { + for { + fd, err := l.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + assert.NoError(t, err, "listener accept failed", i) + } + return + } + + testutil.Go(t, func() { + defer fd.Close() + agentssh.Bicopy(ctx, fd, fd) + }) + } + }) + + // Dial the forwarded socket on the "remote machine". + d := &net.Dialer{} + fd, err := d.DialContext(ctx, "unix", sock.remote) + require.NoError(t, err, i) + defer fd.Close() //nolint:revive // Defer is fine in this loop, we only run it twice. + + // Ping / pong to ensure the socket is working. + _, err = fd.Write([]byte("hello world")) + require.NoError(t, err, i) + + buf := make([]byte, 11) + _, err = fd.Read(buf) + require.NoError(t, err, i) + require.Equal(t, "hello world", string(buf), i) + } + + // And we're done. + pty.WriteLine("exit") + }) + t.Run("FileLogging", func(t *testing.T) { t.Parallel() diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index b76e56a8ab..ce53948c70 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -33,7 +33,7 @@ OPTIONS: behavior as non-blocking. DEPRECATED: Use --wait instead. - -R, --remote-forward string, $CODER_SSH_REMOTE_FORWARD + -R, --remote-forward string-array, $CODER_SSH_REMOTE_FORWARD Enable remote port forwarding (remote_port:local_address:local_port). --stdio bool, $CODER_SSH_STDIO diff --git a/docs/cli/ssh.md b/docs/cli/ssh.md index b3416f3307..34762d5b2b 100644 --- a/docs/cli/ssh.md +++ b/docs/cli/ssh.md @@ -71,7 +71,7 @@ Enter workspace immediately after the agent has connected. This is the default i | | | | ----------- | -------------------------------------- | -| Type | string | +| Type | string-array | | Environment | $CODER_SSH_REMOTE_FORWARD | Enable remote port forwarding (remote_port:local_address:local_port).