feat: add ssh support over wireguard (#2642)

This commit is contained in:
Colin Adler 2022-06-24 16:21:46 -05:00 committed by GitHub
parent 26e85b0bbc
commit 6aed58f486
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 190 additions and 34 deletions

View File

@ -2,6 +2,8 @@ package agent
import (
"context"
"net"
"strconv"
"golang.org/x/xerrors"
"inet.af/netaddr"
@ -58,6 +60,38 @@ func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) er
}
}()
a.startWireguardListeners(ctx, wg, []handlerPort{
{port: 12212, handler: a.sshServer.HandleConn},
})
a.network = wg
return nil
}
type handlerPort struct {
handler func(conn net.Conn)
port uint16
}
func (a *agent) startWireguardListeners(ctx context.Context, network *peerwg.Network, handlers []handlerPort) {
for _, h := range handlers {
go func(h handlerPort) {
a.logger.Debug(ctx, "starting wireguard listener", slog.F("port", h.port))
listener, err := network.Listen("tcp", net.JoinHostPort("", strconv.Itoa(int(h.port))))
if err != nil {
a.logger.Warn(ctx, "listen wireguard", slog.F("port", h.port), slog.Error(err))
return
}
for {
conn, err := listener.Accept()
if err != nil {
return
}
go h.handler(conn)
}
}(h)
}
}

View File

@ -135,6 +135,7 @@ func configSSH() *cobra.Command {
coderConfigFile string
dryRun bool
skipProxyCommand bool
wireguard bool
)
cmd := &cobra.Command{
Annotations: workspaceCommand,
@ -287,7 +288,11 @@ func configSSH() *cobra.Command {
"\tLogLevel ERROR",
)
if !skipProxyCommand {
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname))
if !wireguard {
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname))
} else {
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --wireguard --stdio %s", binaryFile, root, hostname))
}
}
_, _ = buf.WriteString(strings.Join(configOptions, "\n"))
@ -374,6 +379,8 @@ func configSSH() *cobra.Command {
cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.")
_ = cmd.Flags().MarkHidden("skip-proxy-command")
cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.")
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
_ = cmd.Flags().MarkHidden("wireguard")
// Deprecated: Remove after migration period.
cmd.Flags().StringVar(&coderConfigFile, "test.ssh-coder-config-file", sshDefaultCoderConfigFileName, "Specifies the path to an Coder SSH config file. Useful for testing.")

View File

@ -18,13 +18,18 @@ import (
gosshagent "golang.org/x/crypto/ssh/agent"
"golang.org/x/term"
"golang.org/x/xerrors"
"inet.af/netaddr"
tslogger "tailscale.com/types/logger"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/autobuild/notify"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/peer/peerwg"
)
var workspacePollInterval = time.Minute
@ -37,6 +42,7 @@ func ssh() *cobra.Command {
forwardAgent bool
identityAgent string
wsPollInterval time.Duration
wireguard bool
)
cmd := &cobra.Command{
Annotations: workspaceCommand,
@ -61,7 +67,7 @@ func ssh() *cobra.Command {
}
}
workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
if err != nil {
return err
}
@ -71,41 +77,104 @@ func ssh() *cobra.Command {
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, agent.ID)
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
})
if err != nil {
return xerrors.Errorf("await agent: %w", err)
}
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
if err != nil {
return err
}
defer conn.Close()
var (
sshClient *gossh.Client
sshSession *gossh.Session
)
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
defer stopPolling()
if stdio {
rawSSH, err := conn.SSH()
if !wireguard {
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
if err != nil {
return err
}
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}
sshClient, err := conn.SSHClient()
if err != nil {
return err
}
defer conn.Close()
sshSession, err := sshClient.NewSession()
if err != nil {
return err
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
defer stopPolling()
if stdio {
rawSSH, err := conn.SSH()
if err != nil {
return err
}
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}
sshClient, err = conn.SSHClient()
if err != nil {
return err
}
sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
} else {
// TODO: more granual control of Tailscale logging.
peerwg.Logf = tslogger.Discard
ipv6 := peerwg.UUIDToNetaddr(uuid.New())
wgn, err := peerwg.New(
slog.Make(sloghuman.Sink(os.Stderr)),
[]netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)},
)
if err != nil {
return xerrors.Errorf("create wireguard network: %w", err)
}
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey,
IPv6: ipv6,
})
if err != nil {
return xerrors.Errorf("post wireguard peer: %w", err)
}
err = wgn.AddPeer(peerwg.Handshake{
Recipient: workspaceAgent.ID,
DiscoPublicKey: workspaceAgent.DiscoPublicKey,
NodePublicKey: workspaceAgent.WireguardPublicKey,
IPv6: workspaceAgent.IPv6.IP(),
})
if err != nil {
return xerrors.Errorf("add workspace agent as peer: %w", err)
}
if stdio {
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP())
if err != nil {
return err
}
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}
sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP())
if err != nil {
return err
}
sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
}
if identityAgent == "" {
@ -174,6 +243,8 @@ func ssh() *cobra.Command {
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
_ = cmd.Flags().MarkHidden("wireguard")
return cmd
}

38
peer/peerwg/ssh.go Normal file
View File

@ -0,0 +1,38 @@
package peerwg
import (
"context"
"net"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"inet.af/netaddr"
)
func (n *Network) SSH(ctx context.Context, ip netaddr.IP) (net.Conn, error) {
netConn, err := n.Netstack.DialContextTCP(ctx, netaddr.IPPortFrom(ip, 12212))
if err != nil {
return nil, xerrors.Errorf("dial agent ssh: %w", err)
}
return netConn, nil
}
func (n *Network) SSHClient(ctx context.Context, ip netaddr.IP) (*ssh.Client, error) {
netConn, err := n.SSH(ctx, ip)
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
}
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
return nil, xerrors.Errorf("new ssh client conn: %w", err)
}
return ssh.NewClient(sshConn, channels, requests), nil
}

View File

@ -35,7 +35,7 @@ import (
"cdr.dev/slog"
)
var logf tslogger.Logf = log.Printf
var Logf tslogger.Logf = log.Printf
func init() {
// Globally disable network namespacing.
@ -139,15 +139,15 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
DERP: DefaultDerpHome,
}
wgMonitor, err := monitor.New(logf)
wgMonitor, err := monitor.New(Logf)
if err != nil {
return nil, xerrors.Errorf("create link monitor: %w", err)
}
dialer := new(tsdial.Dialer)
dialer.Logf = logf
dialer.Logf = Logf
// Create a wireguard engine in userspace.
engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
engine, err := wgengine.NewUserspaceEngine(Logf, wgengine.Config{
LinkMonitor: wgMonitor,
Dialer: dialer,
})
@ -172,7 +172,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
// Create the networking stack.
// This is called to route connections.
netStack, err := netstack.Create(logf, tunDev, engine, magicConn, dialer, dnsManager)
netStack, err := netstack.Create(Logf, tunDev, engine, magicConn, dialer, dnsManager)
if err != nil {
return nil, xerrors.Errorf("create netstack: %w", err)
}
@ -192,7 +192,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
engine = wgengine.NewWatchdog(engine)
// Update the wireguard configuration to allow traffic to flow.
cfg, err := nmcfg.WGCfg(netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
cfg, err := nmcfg.WGCfg(netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
if err != nil {
return nil, xerrors.Errorf("create wgcfg: %w", err)
}
@ -216,7 +216,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
iplb := netaddr.IPSetBuilder{}
ipl, _ := iplb.IPSet()
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf))
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, Logf))
wn := &Network{
logger: logger,
@ -319,7 +319,7 @@ func (n *Network) AddPeer(handshake Handshake) error {
n.netMap.Peers = peers
cfg, err := nmcfg.WGCfg(n.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
cfg, err := nmcfg.WGCfg(n.netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
if err != nil {
return xerrors.Errorf("create wgcfg: %w", err)
}
@ -375,6 +375,12 @@ func (n *Network) Listen(network, addr string) (net.Listener, error) {
}
func (n *Network) Close() error {
// Close all listeners.
for _, l := range n.listeners {
_ = l.Close()
}
// Close the Wireguard netstack and engine.
_ = n.Netstack.Close()
n.wgEngine.Close()