coder/cli/remoteforward.go

148 lines
3.8 KiB
Go

package cli
import (
"context"
"fmt"
"io"
"net"
"regexp"
"strconv"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent/agentssh"
)
// cookieAddr is a special net.Addr accepted by sshRemoteForward() which
// carries a cookie. sshRemoteForward() writes the cookie to the local
// connection right after dialing it, before the connection is used for
// forwarding.
type cookieAddr struct {
// Addr is the embedded local address; its Network() and String() are what
// sshRemoteForward() dials.
net.Addr
// cookie holds the bytes written to the local connection before forwarding.
cookie []byte
}
var (
	// remoteForwardRegexTCP matches a TCP remote forward flag of the form:
	//
	//	remote_port:local_address:local_port
	//
	// Submatches 1-3 are the remote port, local address and local port.
	remoteForwardRegexTCP = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)

	// remoteForwardRegexUnixSocket matches a Unix socket remote forward flag
	// of the form:
	//
	//	remote_socket_path:local_socket_path
	//
	// Both paths must be absolute. Submatches 1-2 are the remote and local
	// socket paths.
	remoteForwardRegexUnixSocket = regexp.MustCompile(`^(\/.+):(\/.+)$`)
)
// isRemoteForwardTCP reports whether flag matches remoteForwardRegexTCP,
// i.e. has the shape remote_port:local_address:local_port.
func isRemoteForwardTCP(flag string) bool {
	ok := remoteForwardRegexTCP.MatchString(flag)
	return ok
}
// isRemoteForwardUnixSocket reports whether flag matches
// remoteForwardRegexUnixSocket, i.e. has the shape
// remote_socket_path:local_socket_path with both paths absolute.
func isRemoteForwardUnixSocket(flag string) bool {
	ok := remoteForwardRegexUnixSocket.MatchString(flag)
	return ok
}
// validateRemoteForward reports whether flag is shaped like either supported
// remote forward form: TCP or Unix socket. The TCP form is checked first.
func validateRemoteForward(flag string) bool {
	if isRemoteForwardTCP(flag) {
		return true
	}
	return isRemoteForwardUnixSocket(flag)
}
// parseRemoteForwardTCP converts the submatches of a TCP remote forward flag
// (matches[1] = remote port, matches[2] = local address, matches[3] = local
// port) into a local and a remote TCP address, returned in that order.
//
// The local address is resolved with net.ResolveIPAddr and any IPv6 zone it
// reports is carried over to the returned local address. The remote address
// always uses 127.0.0.1. Both ports must fit in 16 bits (0-65535); anything
// larger is rejected here instead of producing an address that can never be
// listened on or dialed.
//
// matches must have at least four elements.
func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
	// Parsing as an unsigned 16-bit value enforces the valid port range.
	remotePort, err := strconv.ParseUint(matches[1], 10, 16)
	if err != nil {
		return nil, nil, xerrors.Errorf("remote port is invalid: %w", err)
	}

	localAddress, err := net.ResolveIPAddr("ip", matches[2])
	if err != nil {
		return nil, nil, xerrors.Errorf("local address is invalid: %w", err)
	}

	localPort, err := strconv.ParseUint(matches[3], 10, 16)
	if err != nil {
		return nil, nil, xerrors.Errorf("local port is invalid: %w", err)
	}

	localAddr := &net.TCPAddr{
		IP:   localAddress.IP,
		Port: int(localPort),
		// Keep the zone so scoped IPv6 addresses (e.g. fe80::1%eth0) stay
		// dialable.
		Zone: localAddress.Zone,
	}

	remoteAddr := &net.TCPAddr{
		IP:   net.ParseIP("127.0.0.1"),
		Port: int(remotePort),
	}

	return localAddr, remoteAddr, nil
}
// parseRemoteForwardUnixSocket turns the submatches of a Unix socket remote
// forward flag (matches[1] = remote socket path, matches[2] = local socket
// path) into a local and a remote "unix" address, returned in that order.
//
// The local socket path is deliberately not checked for existence, since the
// user may create it later. This behavior matches OpenSSH.
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
	remotePath, localPath := matches[1], matches[2]

	return &net.UnixAddr{Name: localPath, Net: "unix"},
		&net.UnixAddr{Name: remotePath, Net: "unix"},
		nil
}
// parseRemoteForward parses a remote forward flag into a local and a remote
// address, returned in that order. The TCP form is tried first, then the Unix
// socket form; a flag matching neither yields an error.
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
	if m := remoteForwardRegexTCP.FindStringSubmatch(flag); len(m) > 0 {
		return parseRemoteForwardTCP(m)
	}

	if m := remoteForwardRegexUnixSocket.FindStringSubmatch(flag); len(m) > 0 {
		return parseRemoteForwardUnixSocket(m)
	}

	return nil, nil, xerrors.New("Could not match forward arguments")
}
// sshRemoteForward starts forwarding connections from a remote listener to a
// local address via SSH in a goroutine.
//
// It listens on remoteAddr through sshClient. Each accepted connection is
// handled in its own goroutine, which dials localAddr and hands both
// connections to agentssh.Bicopy. The dial is bound to ctx, so a handler does
// not stay blocked dialing after ctx is canceled.
//
// Accepts a `cookieAddr` as the local address, in which case the cookie is
// written to the local connection before forwarding.
//
// Accept, dial and cookie-write failures are reported on stderr. Accept and
// dial failures are not reported once ctx is done. Any accept failure ends
// the accept loop. The returned io.Closer is the remote listener; an error is
// returned only when listening fails.
func sshRemoteForward(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
	listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
	if err != nil {
		return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
	}

	go func() {
		for {
			remoteConn, err := listener.Accept()
			if err != nil {
				// Stay quiet when the failure follows cancellation.
				if ctx.Err() == nil {
					_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
				}
				return
			}

			go func() {
				defer remoteConn.Close()

				// DialContext lets cancellation of ctx abort a pending dial.
				var dialer net.Dialer
				localConn, err := dialer.DialContext(ctx, localAddr.Network(), localAddr.String())
				if err != nil {
					// Same policy as Accept: no message once ctx is done.
					if ctx.Err() == nil {
						_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
					}
					return
				}
				defer localConn.Close()

				if c, ok := localAddr.(cookieAddr); ok {
					_, err = localConn.Write(c.cookie)
					if err != nil {
						_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
						return
					}
				}

				agentssh.Bicopy(ctx, localConn, remoteConn)
			}()
		}
	}()

	return listener, nil
}