mirror of https://github.com/coder/coder.git
feat: add ssh support over wireguard (#2642)
This commit is contained in:
parent
26e85b0bbc
commit
6aed58f486
|
@ -2,6 +2,8 @@ package agent
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"inet.af/netaddr"
|
||||
|
@ -58,6 +60,38 @@ func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) er
|
|||
}
|
||||
}()
|
||||
|
||||
a.startWireguardListeners(ctx, wg, []handlerPort{
|
||||
{port: 12212, handler: a.sshServer.HandleConn},
|
||||
})
|
||||
|
||||
a.network = wg
|
||||
return nil
|
||||
}
|
||||
|
||||
type handlerPort struct {
|
||||
handler func(conn net.Conn)
|
||||
port uint16
|
||||
}
|
||||
|
||||
func (a *agent) startWireguardListeners(ctx context.Context, network *peerwg.Network, handlers []handlerPort) {
|
||||
for _, h := range handlers {
|
||||
go func(h handlerPort) {
|
||||
a.logger.Debug(ctx, "starting wireguard listener", slog.F("port", h.port))
|
||||
|
||||
listener, err := network.Listen("tcp", net.JoinHostPort("", strconv.Itoa(int(h.port))))
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "listen wireguard", slog.F("port", h.port), slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go h.handler(conn)
|
||||
}
|
||||
}(h)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -135,6 +135,7 @@ func configSSH() *cobra.Command {
|
|||
coderConfigFile string
|
||||
dryRun bool
|
||||
skipProxyCommand bool
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
|
@ -287,7 +288,11 @@ func configSSH() *cobra.Command {
|
|||
"\tLogLevel ERROR",
|
||||
)
|
||||
if !skipProxyCommand {
|
||||
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname))
|
||||
if !wireguard {
|
||||
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --stdio %s", binaryFile, root, hostname))
|
||||
} else {
|
||||
configOptions = append(configOptions, fmt.Sprintf("\tProxyCommand %q --global-config %q ssh --wireguard --stdio %s", binaryFile, root, hostname))
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = buf.WriteString(strings.Join(configOptions, "\n"))
|
||||
|
@ -374,6 +379,8 @@ func configSSH() *cobra.Command {
|
|||
cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.")
|
||||
_ = cmd.Flags().MarkHidden("skip-proxy-command")
|
||||
cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
|
||||
_ = cmd.Flags().MarkHidden("wireguard")
|
||||
|
||||
// Deprecated: Remove after migration period.
|
||||
cmd.Flags().StringVar(&coderConfigFile, "test.ssh-coder-config-file", sshDefaultCoderConfigFileName, "Specifies the path to an Coder SSH config file. Useful for testing.")
|
||||
|
|
121
cli/ssh.go
121
cli/ssh.go
|
@ -18,13 +18,18 @@ import (
|
|||
gosshagent "golang.org/x/crypto/ssh/agent"
|
||||
"golang.org/x/term"
|
||||
"golang.org/x/xerrors"
|
||||
"inet.af/netaddr"
|
||||
tslogger "tailscale.com/types/logger"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/coderd/autobuild/notify"
|
||||
"github.com/coder/coder/coderd/util/ptr"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/peer/peerwg"
|
||||
)
|
||||
|
||||
var workspacePollInterval = time.Minute
|
||||
|
@ -37,6 +42,7 @@ func ssh() *cobra.Command {
|
|||
forwardAgent bool
|
||||
identityAgent string
|
||||
wsPollInterval time.Duration
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
|
@ -61,7 +67,7 @@ func ssh() *cobra.Command {
|
|||
}
|
||||
}
|
||||
|
||||
workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
|
||||
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -71,41 +77,104 @@ func ssh() *cobra.Command {
|
|||
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
|
||||
WorkspaceName: workspace.Name,
|
||||
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
|
||||
return client.WorkspaceAgent(ctx, agent.ID)
|
||||
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("await agent: %w", err)
|
||||
}
|
||||
|
||||
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
var (
|
||||
sshClient *gossh.Client
|
||||
sshSession *gossh.Session
|
||||
)
|
||||
|
||||
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
|
||||
defer stopPolling()
|
||||
|
||||
if stdio {
|
||||
rawSSH, err := conn.SSH()
|
||||
if !wireguard {
|
||||
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
|
||||
}()
|
||||
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
|
||||
return nil
|
||||
}
|
||||
sshClient, err := conn.SSHClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
sshSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace)
|
||||
defer stopPolling()
|
||||
|
||||
if stdio {
|
||||
rawSSH, err := conn.SSH()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
|
||||
}()
|
||||
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
|
||||
return nil
|
||||
}
|
||||
|
||||
sshClient, err = conn.SSHClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sshSession, err = sshClient.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// TODO: more granual control of Tailscale logging.
|
||||
peerwg.Logf = tslogger.Discard
|
||||
|
||||
ipv6 := peerwg.UUIDToNetaddr(uuid.New())
|
||||
wgn, err := peerwg.New(
|
||||
slog.Make(sloghuman.Sink(os.Stderr)),
|
||||
[]netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)},
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create wireguard network: %w", err)
|
||||
}
|
||||
|
||||
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
|
||||
Recipient: workspaceAgent.ID,
|
||||
NodePublicKey: wgn.NodePrivateKey.Public(),
|
||||
DiscoPublicKey: wgn.DiscoPublicKey,
|
||||
IPv6: ipv6,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("post wireguard peer: %w", err)
|
||||
}
|
||||
|
||||
err = wgn.AddPeer(peerwg.Handshake{
|
||||
Recipient: workspaceAgent.ID,
|
||||
DiscoPublicKey: workspaceAgent.DiscoPublicKey,
|
||||
NodePublicKey: workspaceAgent.WireguardPublicKey,
|
||||
IPv6: workspaceAgent.IPv6.IP(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("add workspace agent as peer: %w", err)
|
||||
}
|
||||
|
||||
if stdio {
|
||||
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
|
||||
}()
|
||||
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
|
||||
return nil
|
||||
}
|
||||
|
||||
sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sshSession, err = sshClient.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if identityAgent == "" {
|
||||
|
@ -174,6 +243,8 @@ func ssh() *cobra.Command {
|
|||
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
|
||||
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
|
||||
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
|
||||
_ = cmd.Flags().MarkHidden("wireguard")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
package peerwg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
func (n *Network) SSH(ctx context.Context, ip netaddr.IP) (net.Conn, error) {
|
||||
netConn, err := n.Netstack.DialContextTCP(ctx, netaddr.IPPortFrom(ip, 12212))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial agent ssh: %w", err)
|
||||
}
|
||||
|
||||
return netConn, nil
|
||||
}
|
||||
|
||||
func (n *Network) SSHClient(ctx context.Context, ip netaddr.IP) (*ssh.Client, error) {
|
||||
netConn, err := n.SSH(ctx, ip)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh: %w", err)
|
||||
}
|
||||
|
||||
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
|
||||
// SSH host validation isn't helpful, because obtaining a peer
|
||||
// connection already signifies user-intent to dial a workspace.
|
||||
// #nosec
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("new ssh client conn: %w", err)
|
||||
}
|
||||
|
||||
return ssh.NewClient(sshConn, channels, requests), nil
|
||||
}
|
|
@ -35,7 +35,7 @@ import (
|
|||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
var logf tslogger.Logf = log.Printf
|
||||
var Logf tslogger.Logf = log.Printf
|
||||
|
||||
func init() {
|
||||
// Globally disable network namespacing.
|
||||
|
@ -139,15 +139,15 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
|
|||
DERP: DefaultDerpHome,
|
||||
}
|
||||
|
||||
wgMonitor, err := monitor.New(logf)
|
||||
wgMonitor, err := monitor.New(Logf)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create link monitor: %w", err)
|
||||
}
|
||||
|
||||
dialer := new(tsdial.Dialer)
|
||||
dialer.Logf = logf
|
||||
dialer.Logf = Logf
|
||||
// Create a wireguard engine in userspace.
|
||||
engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
|
||||
engine, err := wgengine.NewUserspaceEngine(Logf, wgengine.Config{
|
||||
LinkMonitor: wgMonitor,
|
||||
Dialer: dialer,
|
||||
})
|
||||
|
@ -172,7 +172,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
|
|||
|
||||
// Create the networking stack.
|
||||
// This is called to route connections.
|
||||
netStack, err := netstack.Create(logf, tunDev, engine, magicConn, dialer, dnsManager)
|
||||
netStack, err := netstack.Create(Logf, tunDev, engine, magicConn, dialer, dnsManager)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create netstack: %w", err)
|
||||
}
|
||||
|
@ -192,7 +192,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
|
|||
engine = wgengine.NewWatchdog(engine)
|
||||
|
||||
// Update the wireguard configuration to allow traffic to flow.
|
||||
cfg, err := nmcfg.WGCfg(netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
|
||||
cfg, err := nmcfg.WGCfg(netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create wgcfg: %w", err)
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
|
|||
|
||||
iplb := netaddr.IPSetBuilder{}
|
||||
ipl, _ := iplb.IPSet()
|
||||
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf))
|
||||
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, Logf))
|
||||
|
||||
wn := &Network{
|
||||
logger: logger,
|
||||
|
@ -319,7 +319,7 @@ func (n *Network) AddPeer(handshake Handshake) error {
|
|||
|
||||
n.netMap.Peers = peers
|
||||
|
||||
cfg, err := nmcfg.WGCfg(n.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
|
||||
cfg, err := nmcfg.WGCfg(n.netMap, Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create wgcfg: %w", err)
|
||||
}
|
||||
|
@ -375,6 +375,12 @@ func (n *Network) Listen(network, addr string) (net.Listener, error) {
|
|||
}
|
||||
|
||||
func (n *Network) Close() error {
|
||||
// Close all listeners.
|
||||
for _, l := range n.listeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
|
||||
// Close the Wireguard netstack and engine.
|
||||
_ = n.Netstack.Close()
|
||||
n.wgEngine.Close()
|
||||
|
||||
|
|
Loading…
Reference in New Issue