mirror of https://github.com/coder/coder.git
feat(cli/ssh): allow multiple remote forwards and allow missing local file (#11648)
This commit is contained in:
parent
73e6bbff7e
commit
200a87e7d4
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
@ -67,19 +66,13 @@ func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
|
||||||
return localAddr, remoteAddr, nil
|
return localAddr, remoteAddr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseRemoteForwardUnixSocket parses a remote forward flag. Note that
|
||||||
|
// we don't verify that the local socket path exists because the user
|
||||||
|
// may create it later. This behavior matches OpenSSH.
|
||||||
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
|
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
|
||||||
remoteSocket := matches[1]
|
remoteSocket := matches[1]
|
||||||
localSocket := matches[2]
|
localSocket := matches[2]
|
||||||
|
|
||||||
fileInfo, err := os.Stat(localSocket)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if fileInfo.Mode()&os.ModeSocket == 0 {
|
|
||||||
return nil, nil, xerrors.New("File is not a Unix domain socket file")
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAddr := &net.UnixAddr{
|
remoteAddr := &net.UnixAddr{
|
||||||
Name: remoteSocket,
|
Name: remoteSocket,
|
||||||
Net: "unix",
|
Net: "unix",
|
||||||
|
|
44
cli/ssh.go
44
cli/ssh.go
|
@ -53,7 +53,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
||||||
waitEnum string
|
waitEnum string
|
||||||
noWait bool
|
noWait bool
|
||||||
logDirPath string
|
logDirPath string
|
||||||
remoteForward string
|
remoteForwards []string
|
||||||
disableAutostart bool
|
disableAutostart bool
|
||||||
)
|
)
|
||||||
client := new(codersdk.Client)
|
client := new(codersdk.Client)
|
||||||
|
@ -135,13 +135,15 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
||||||
stack := newCloserStack(ctx, logger)
|
stack := newCloserStack(ctx, logger)
|
||||||
defer stack.close(nil)
|
defer stack.close(nil)
|
||||||
|
|
||||||
if remoteForward != "" {
|
if len(remoteForwards) > 0 {
|
||||||
isValid := validateRemoteForward(remoteForward)
|
for _, remoteForward := range remoteForwards {
|
||||||
if !isValid {
|
isValid := validateRemoteForward(remoteForward)
|
||||||
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
|
if !isValid {
|
||||||
}
|
return xerrors.Errorf(`invalid format of remote-forward, expected: remote_port:local_address:local_port`)
|
||||||
if isValid && stdio {
|
}
|
||||||
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
|
if isValid && stdio {
|
||||||
|
return xerrors.Errorf(`remote-forward can't be enabled in the stdio mode`)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -311,18 +313,20 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if remoteForward != "" {
|
if len(remoteForwards) > 0 {
|
||||||
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
|
for _, remoteForward := range remoteForwards {
|
||||||
if err != nil {
|
localAddr, remoteAddr, err := parseRemoteForward(remoteForward)
|
||||||
return err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
|
closer, err := sshRemoteForward(ctx, inv.Stderr, sshClient, localAddr, remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("ssh remote forward: %w", err)
|
return xerrors.Errorf("ssh remote forward: %w", err)
|
||||||
}
|
}
|
||||||
if err = stack.push("sshRemoteForward", closer); err != nil {
|
if err = stack.push("sshRemoteForward", closer); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -460,7 +464,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
||||||
Description: "Enable remote port forwarding (remote_port:local_address:local_port).",
|
Description: "Enable remote port forwarding (remote_port:local_address:local_port).",
|
||||||
Env: "CODER_SSH_REMOTE_FORWARD",
|
Env: "CODER_SSH_REMOTE_FORWARD",
|
||||||
FlagShorthand: "R",
|
FlagShorthand: "R",
|
||||||
Value: clibase.StringOf(&remoteForward),
|
Value: clibase.StringArrayOf(&remoteForwards),
|
||||||
},
|
},
|
||||||
sshDisableAutostartOption(clibase.BoolOf(&disableAutostart)),
|
sshDisableAutostartOption(clibase.BoolOf(&disableAutostart)),
|
||||||
}
|
}
|
||||||
|
|
|
@ -883,6 +883,104 @@ func TestSSH(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Test that we can remote forward multiple sockets, whether or not the
|
||||||
|
// local sockets exists at the time of establishing xthe SSH connection.
|
||||||
|
t.Run("RemoteForwardMultipleUnixSockets", func(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("Test not supported on windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client, workspace, agentToken := setupWorkspaceForAgent(t)
|
||||||
|
|
||||||
|
_ = agenttest.New(t, client.URL, agentToken)
|
||||||
|
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||||
|
|
||||||
|
// Wait super long so this doesn't flake on -race test.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tmpdir := tempDirUnixSocket(t)
|
||||||
|
|
||||||
|
type testSocket struct {
|
||||||
|
local string
|
||||||
|
remote string
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{"ssh", workspace.Name}
|
||||||
|
var sockets []testSocket
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
localSock := filepath.Join(tmpdir, fmt.Sprintf("local-%d.sock", i))
|
||||||
|
remoteSock := filepath.Join(tmpdir, fmt.Sprintf("remote-%d.sock", i))
|
||||||
|
sockets = append(sockets, testSocket{
|
||||||
|
local: localSock,
|
||||||
|
remote: remoteSock,
|
||||||
|
})
|
||||||
|
args = append(args, "--remote-forward", fmt.Sprintf("%s:%s", remoteSock, localSock))
|
||||||
|
}
|
||||||
|
|
||||||
|
inv, root := clitest.New(t, args...)
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
pty := ptytest.New(t).Attach(inv)
|
||||||
|
inv.Stderr = pty.Output()
|
||||||
|
|
||||||
|
w := clitest.StartWithWaiter(t, inv.WithContext(ctx))
|
||||||
|
defer w.Wait() // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
|
||||||
|
|
||||||
|
// Since something was output, it should be safe to write input.
|
||||||
|
// This could show a prompt or "running startup scripts", so it's
|
||||||
|
// not indicative of the SSH connection being ready.
|
||||||
|
_ = pty.Peek(ctx, 1)
|
||||||
|
|
||||||
|
// Ensure the SSH connection is ready by testing the shell
|
||||||
|
// input/output.
|
||||||
|
pty.WriteLine("echo ping' 'pong")
|
||||||
|
pty.ExpectMatchContext(ctx, "ping pong")
|
||||||
|
|
||||||
|
for i, sock := range sockets {
|
||||||
|
i := i
|
||||||
|
// Start the listener on the "local machine".
|
||||||
|
l, err := net.Listen("unix", sock.local)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
|
||||||
|
testutil.Go(t, func() {
|
||||||
|
for {
|
||||||
|
fd, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
assert.NoError(t, err, "listener accept failed", i)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
testutil.Go(t, func() {
|
||||||
|
defer fd.Close()
|
||||||
|
agentssh.Bicopy(ctx, fd, fd)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Dial the forwarded socket on the "remote machine".
|
||||||
|
d := &net.Dialer{}
|
||||||
|
fd, err := d.DialContext(ctx, "unix", sock.remote)
|
||||||
|
require.NoError(t, err, i)
|
||||||
|
defer fd.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
|
||||||
|
|
||||||
|
// Ping / pong to ensure the socket is working.
|
||||||
|
_, err = fd.Write([]byte("hello world"))
|
||||||
|
require.NoError(t, err, i)
|
||||||
|
|
||||||
|
buf := make([]byte, 11)
|
||||||
|
_, err = fd.Read(buf)
|
||||||
|
require.NoError(t, err, i)
|
||||||
|
require.Equal(t, "hello world", string(buf), i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// And we're done.
|
||||||
|
pty.WriteLine("exit")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("FileLogging", func(t *testing.T) {
|
t.Run("FileLogging", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ OPTIONS:
|
||||||
behavior as non-blocking.
|
behavior as non-blocking.
|
||||||
DEPRECATED: Use --wait instead.
|
DEPRECATED: Use --wait instead.
|
||||||
|
|
||||||
-R, --remote-forward string, $CODER_SSH_REMOTE_FORWARD
|
-R, --remote-forward string-array, $CODER_SSH_REMOTE_FORWARD
|
||||||
Enable remote port forwarding (remote_port:local_address:local_port).
|
Enable remote port forwarding (remote_port:local_address:local_port).
|
||||||
|
|
||||||
--stdio bool, $CODER_SSH_STDIO
|
--stdio bool, $CODER_SSH_STDIO
|
||||||
|
|
|
@ -71,7 +71,7 @@ Enter workspace immediately after the agent has connected. This is the default i
|
||||||
|
|
||||||
| | |
|
| | |
|
||||||
| ----------- | -------------------------------------- |
|
| ----------- | -------------------------------------- |
|
||||||
| Type | <code>string</code> |
|
| Type | <code>string-array</code> |
|
||||||
| Environment | <code>$CODER_SSH_REMOTE_FORWARD</code> |
|
| Environment | <code>$CODER_SSH_REMOTE_FORWARD</code> |
|
||||||
|
|
||||||
Enable remote port forwarding (remote_port:local_address:local_port).
|
Enable remote port forwarding (remote_port:local_address:local_port).
|
||||||
|
|
Loading…
Reference in New Issue