coder/cli/remoteforward.go

105 lines
2.6 KiB
Go

package cli
import (
"context"
"fmt"
"io"
"net"
"regexp"
"strconv"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent/agentssh"
)
// cookieAddr is a special net.Addr accepted by sshRemoteForward() which includes a
// cookie which is written to the connection before forwarding.
type cookieAddr struct {
net.Addr
cookie []byte
}
// Format:
// remote_port:local_address:local_port
var remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
func validateRemoteForward(flag string) bool {
return remoteForwardRegex.MatchString(flag)
}
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
matches := remoteForwardRegex.FindStringSubmatch(flag)
remotePort, err := strconv.Atoi(matches[1])
if err != nil {
return nil, nil, xerrors.Errorf("remote port is invalid: %w", err)
}
localAddress, err := net.ResolveIPAddr("ip", matches[2])
if err != nil {
return nil, nil, xerrors.Errorf("local address is invalid: %w", err)
}
localPort, err := strconv.Atoi(matches[3])
if err != nil {
return nil, nil, xerrors.Errorf("local port is invalid: %w", err)
}
localAddr := &net.TCPAddr{
IP: localAddress.IP,
Port: localPort,
}
remoteAddr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: remotePort,
}
return localAddr, remoteAddr, nil
}
// sshRemoteForward starts forwarding connections from a remote listener to a
// local address via SSH in a goroutine.
//
// Accepts a `cookieAddr` as the local address.
func sshRemoteForward(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
if err != nil {
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
}
go func() {
for {
remoteConn, err := listener.Accept()
if err != nil {
if ctx.Err() == nil {
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
}
return
}
go func() {
defer remoteConn.Close()
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
if err != nil {
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
return
}
defer localConn.Close()
if c, ok := localAddr.(cookieAddr); ok {
_, err = localConn.Write(c.cookie)
if err != nil {
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
return
}
}
agentssh.Bicopy(ctx, localConn, remoteConn)
}()
}
}()
return listener, nil
}