mirror of https://github.com/coder/coder.git
feat(cli/ssh): allow multiple remote forwards and allow missing local file (#11648)
This commit is contained in:
parent
73e6bbff7e
commit
200a87e7d4
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
|
@ -67,19 +66,13 @@ func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
|
|||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseRemoteForwardUnixSocket parses a remote forward flag. Note that
|
||||
// we don't verify that the local socket path exists because the user
|
||||
// may create it later. This behavior matches OpenSSH.
|
||||
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
|
||||
remoteSocket := matches[1]
|
||||
localSocket := matches[2]
|
||||
|
||||
fileInfo, err := os.Stat(localSocket)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if fileInfo.Mode()&os.ModeSocket == 0 {
|
||||
return nil, nil, xerrors.New("File is not a Unix domain socket file")
|
||||
}
|
||||
|
||||
remoteAddr := &net.UnixAddr{
|
||||
Name: remoteSocket,
|
||||
Net: "unix",
|
||||
|
|
44
cli/ssh.go
44
cli/ssh.go
|
@ -53,7 +53,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
waitEnum string
|
||||
noWait bool
|
||||
logDirPath string
|
||||
remoteForward string
|
||||
remoteForwards []string
|
||||
disableAutostart bool
|
||||
)
|
||||
client := new(codersdk.Client)
|
||||
|
@ -135,13 +135,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
stack := newCloserStack(ctx, logger)
|
||||
defer stack.close(nil)
|
||||
|
||||
if remoteForward != "" {
|
||||
isValid := validateRemoteForward(remoteForward)
|
||||
if !isValid {
|
||||
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
|
||||
}
|
||||
if isValid && stdio {
|
||||
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
|
||||
if len(remoteForwards) > 0 {
|
||||
for _, remoteForward := range remoteForwards {
|
||||
isValid := validateRemoteForward(remoteForward)
|
||||
if !isValid {
|
||||
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
|
||||
}
|
||||
if isValid && stdio {
|
||||
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -311,18 +313,20 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
}
|
||||
}
|
||||
|
||||
if remoteForward != "" {
|
||||
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(remoteForwards) > 0 {
|
||||
for _, remoteForward := range remoteForwards {
|
||||
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("ssh remote forward: %w", err)
|
||||
}
|
||||
if err = stack.push("sshRemoteForward", closer); err != nil {
|
||||
return err
|
||||
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("ssh remote forward: %w", err)
|
||||
}
|
||||
if err = stack.push("sshRemoteForward", closer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -460,7 +464,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
Description: "Enable remote port forwarding (remote_port:local_address:local_port).",
|
||||
Env: "CODER_SSH_REMOTE_FORWARD",
|
||||
FlagShorthand: "R",
|
||||
Value: clibase.StringOf(&remoteForward),
|
||||
Value: clibase.StringArrayOf(&remoteForwards),
|
||||
},
|
||||
sshDisableAutostartOption(clibase.BoolOf(&disableAutostart)),
|
||||
}
|
||||
|
|
|
@ -883,6 +883,104 @@ func TestSSH(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
// Test that we can remote forward multiple sockets, whether or not the
|
||||
// local sockets exists at the time of establishing xthe SSH connection.
|
||||
t.Run("RemoteForwardMultipleUnixSockets", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Test not supported on windows")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t)
|
||||
|
||||
_ = agenttest.New(t, client.URL, agentToken)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
|
||||
// Wait super long so this doesn't flake on -race test.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
|
||||
tmpdir := tempDirUnixSocket(t)
|
||||
|
||||
type testSocket struct {
|
||||
local string
|
||||
remote string
|
||||
}
|
||||
|
||||
args := []string{"ssh", workspace.Name}
|
||||
var sockets []testSocket
|
||||
for i := 0; i < 2; i++ {
|
||||
localSock := filepath.Join(tmpdir, fmt.Sprintf("local-%d.sock", i))
|
||||
remoteSock := filepath.Join(tmpdir, fmt.Sprintf("remote-%d.sock", i))
|
||||
sockets = append(sockets, testSocket{
|
||||
local: localSock,
|
||||
remote: remoteSock,
|
||||
})
|
||||
args = append(args, "--remote-forward", fmt.Sprintf("%s:%s", remoteSock, localSock))
|
||||
}
|
||||
|
||||
inv, root := clitest.New(t, args...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
inv.Stderr = pty.Output()
|
||||
|
||||
w := clitest.StartWithWaiter(t, inv.WithContext(ctx))
|
||||
defer w.Wait() // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
|
||||
|
||||
// Since something was output, it should be safe to write input.
|
||||
// This could show a prompt or "running startup scripts", so it's
|
||||
// not indicative of the SSH connection being ready.
|
||||
_ = pty.Peek(ctx, 1)
|
||||
|
||||
// Ensure the SSH connection is ready by testing the shell
|
||||
// input/output.
|
||||
pty.WriteLine("echo ping' 'pong")
|
||||
pty.ExpectMatchContext(ctx, "ping pong")
|
||||
|
||||
for i, sock := range sockets {
|
||||
i := i
|
||||
// Start the listener on the "local machine".
|
||||
l, err := net.Listen("unix", sock.local)
|
||||
require.NoError(t, err)
|
||||
defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
|
||||
testutil.Go(t, func() {
|
||||
for {
|
||||
fd, err := l.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
assert.NoError(t, err, "listener accept failed", i)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
testutil.Go(t, func() {
|
||||
defer fd.Close()
|
||||
agentssh.Bicopy(ctx, fd, fd)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Dial the forwarded socket on the "remote machine".
|
||||
d := &net.Dialer{}
|
||||
fd, err := d.DialContext(ctx, "unix", sock.remote)
|
||||
require.NoError(t, err, i)
|
||||
defer fd.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
|
||||
|
||||
// Ping / pong to ensure the socket is working.
|
||||
_, err = fd.Write([]byte("hello world"))
|
||||
require.NoError(t, err, i)
|
||||
|
||||
buf := make([]byte, 11)
|
||||
_, err = fd.Read(buf)
|
||||
require.NoError(t, err, i)
|
||||
require.Equal(t, "hello world", string(buf), i)
|
||||
}
|
||||
|
||||
// And we're done.
|
||||
pty.WriteLine("exit")
|
||||
})
|
||||
|
||||
t.Run("FileLogging", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ OPTIONS:
|
|||
behavior as non-blocking.
|
||||
DEPRECATED: Use --wait instead.
|
||||
|
||||
-R, --remote-forward string, $CODER_SSH_REMOTE_FORWARD
|
||||
-R, --remote-forward string-array, $CODER_SSH_REMOTE_FORWARD
|
||||
Enable remote port forwarding (remote_port:local_address:local_port).
|
||||
|
||||
--stdio bool, $CODER_SSH_STDIO
|
||||
|
|
|
@ -71,7 +71,7 @@ Enter workspace immediately after the agent has connected. This is the default i
|
|||
|
||||
| | |
|
||||
| ----------- | -------------------------------------- |
|
||||
| Type | <code>string</code> |
|
||||
| Type | <code>string-array</code> |
|
||||
| Environment | <code>$CODER_SSH_REMOTE_FORWARD</code> |
|
||||
|
||||
Enable remote port forwarding (remote_port:local_address:local_port).
|
||||
|
|
Loading…
Reference in New Issue