feat: add GPG forwarding to coder ssh (#5482)

This commit is contained in:
Dean Sheather 2023-01-06 01:52:19 -06:00 committed by GitHub
parent 59e919ab4a
commit f1fe2b5c06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1051 additions and 22 deletions

View File

@ -480,12 +480,16 @@ func (a *agent) init(ctx context.Context) {
if err != nil {
panic(err)
}
sshLogger := a.logger.Named("ssh-server")
forwardHandler := &ssh.ForwardedTCPHandler{}
unixForwardHandler := &forwardedUnixHandler{log: a.logger}
a.sshServer = &ssh.Server{
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": ssh.DirectTCPIPHandler,
"session": ssh.DefaultSessionHandler,
"direct-tcpip": ssh.DirectTCPIPHandler,
"direct-streamlocal@openssh.com": directStreamLocalHandler,
"session": ssh.DefaultSessionHandler,
},
ConnectionFailedCallback: func(conn net.Conn, err error) {
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
@ -525,8 +529,10 @@ func (a *agent) init(ctx context.Context) {
return true
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": forwardHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
"tcpip-forward": forwardHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
},
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return &gossh.ServerConfig{

View File

@ -273,7 +273,7 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
}
//nolint:paralleltest // This test reserves a port.
func TestAgent_LocalForwarding(t *testing.T) {
func TestAgent_TCPLocalForwarding(t *testing.T) {
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
_ = random.Close()
@ -286,7 +286,7 @@ func TestAgent_LocalForwarding(t *testing.T) {
defer local.Close()
tcpAddr, valid = local.Addr().(*net.TCPAddr)
require.True(t, valid)
localPort := tcpAddr.Port
remotePort := tcpAddr.Port
done := make(chan struct{})
go func() {
defer close(done)
@ -294,16 +294,231 @@ func TestAgent_LocalForwarding(t *testing.T) {
if !assert.NoError(t, err) {
return
}
_ = conn.Close()
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()
err = setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, localPort)}, []string{"echo", "test"}).Start()
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "10"})
err = cmd.Start()
require.NoError(t, err)
conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(localPort))
require.NoError(t, err)
conn.Close()
require.Eventually(t, func() bool {
conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort))
if err != nil {
return false
}
defer conn.Close()
_, err = conn.Write([]byte("test"))
if !assert.NoError(t, err) {
return false
}
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return false
}
if !assert.Equal(t, "test", string(b)) {
return false
}
return true
}, testutil.WaitLong, testutil.IntervalSlow)
<-done
_ = cmd.Process.Kill()
}
//nolint:paralleltest // This test reserves a port.
func TestAgent_TCPRemoteForwarding(t *testing.T) {
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
_ = random.Close()
tcpAddr, valid := random.Addr().(*net.TCPAddr)
require.True(t, valid)
randomPort := tcpAddr.Port
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer l.Close()
tcpAddr, valid = l.Addr().(*net.TCPAddr)
require.True(t, valid)
localPort := tcpAddr.Port
done := make(chan struct{})
go func() {
defer close(done)
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "10"})
err = cmd.Start()
require.NoError(t, err)
require.Eventually(t, func() bool {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
if err != nil {
return false
}
defer conn.Close()
_, err = conn.Write([]byte("test"))
if !assert.NoError(t, err) {
return false
}
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return false
}
if !assert.Equal(t, "test", string(b)) {
return false
}
return true
}, testutil.WaitLong, testutil.IntervalSlow)
<-done
_ = cmd.Process.Kill()
}
func TestAgent_UnixLocalForwarding(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}
tmpdir := tempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
localSocketPath := filepath.Join(tmpdir, "local-socket")
l, err := net.Listen("unix", remoteSocketPath)
require.NoError(t, err)
defer l.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "10"})
err = cmd.Start()
require.NoError(t, err)
require.Eventually(t, func() bool {
_, err := os.Stat(localSocketPath)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
conn, err := net.Dial("unix", localSocketPath)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("test"))
require.NoError(t, err)
b := make([]byte, 4)
_, err = conn.Read(b)
require.NoError(t, err)
require.Equal(t, "test", string(b))
_ = conn.Close()
<-done
_ = cmd.Process.Kill()
}
func TestAgent_UnixRemoteForwarding(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}
tmpdir := tempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
localSocketPath := filepath.Join(tmpdir, "local-socket")
l, err := net.Listen("unix", localSocketPath)
require.NoError(t, err)
defer l.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()
cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "10"})
err = cmd.Start()
require.NoError(t, err)
require.Eventually(t, func() bool {
_, err := os.Stat(remoteSocketPath)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
conn, err := net.Dial("unix", remoteSocketPath)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("test"))
require.NoError(t, err)
b := make([]byte, 4)
_, err = conn.Read(b)
require.NoError(t, err)
require.Equal(t, "test", string(b))
_ = conn.Close()
<-done
_ = cmd.Process.Kill()
}
func TestAgent_SFTP(t *testing.T) {
@ -733,7 +948,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
args := append(beforeArgs,
"-o", "HostName "+tcpAddr.IP.String(),
"-o", "Port "+strconv.Itoa(tcpAddr.Port),
"-o", "StrictHostKeyChecking=no", "host")
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"host",
)
args = append(args, afterArgs...)
return exec.Command("ssh", args...)
}
@ -919,3 +1137,26 @@ func (*client) PostWorkspaceAgentAppHealth(_ context.Context, _ codersdk.PostWor
func (*client) PostWorkspaceAgentVersion(_ context.Context, _ string) error {
return nil
}
// tempDirUnixSocket returns a temporary directory that can safely hold unix
// sockets (probably).
//
// During tests on darwin we hit the max path length limit for unix sockets
// pretty easily in the default location, so this function uses /tmp instead to
// get shorter paths.
func tempDirUnixSocket(t *testing.T) string {
t.Helper()
if runtime.GOOS == "darwin" {
testName := strings.ReplaceAll(t.Name(), "/", "_")
dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("coder-test-%s-", testName))
require.NoError(t, err, "create temp dir for gpg test")
t.Cleanup(func() {
err := os.RemoveAll(dir)
assert.NoError(t, err, "remove temp dir", dir)
})
return dir
}
return t.TempDir()
}

203
agent/ssh.go Normal file
View File

@ -0,0 +1,203 @@
package agent
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"sync"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
// streamLocalForwardPayload describes the extra data sent in a
// streamlocal-forward@openssh.com containing the socket path to bind to.
type streamLocalForwardPayload struct {
SocketPath string
}
// forwardedStreamLocalPayload describes the data sent as the payload in the new
// channel request when a Unix connection is accepted by the listener.
type forwardedStreamLocalPayload struct {
SocketPath string
Reserved uint32
}
// forwardedUnixHandler is a clone of ssh.ForwardedTCPHandler that does
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
type forwardedUnixHandler struct {
sync.Mutex
log slog.Logger
forwards map[string]net.Listener
}
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
h.Lock()
if h.forwards == nil {
h.forwards = make(map[string]net.Listener)
}
h.Unlock()
conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
if !ok {
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
return false, nil
}
switch req.Type {
case "streamlocal-forward@openssh.com":
var reqPayload streamLocalForwardPayload
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request payload from client", slog.Error(err))
return false, nil
}
addr := reqPayload.SocketPath
h.Lock()
_, ok := h.forwards[addr]
h.Unlock()
if ok {
h.log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
slog.F("socket_path", addr),
)
return false, nil
}
// Create socket parent dir if not exists.
parentDir := filepath.Dir(addr)
err = os.MkdirAll(parentDir, 0700)
if err != nil {
h.log.Warn(ctx, "create parent dir for SSH unix forward request",
slog.F("parent_dir", parentDir),
slog.F("socket_path", addr),
slog.Error(err),
)
return false, nil
}
ln, err := net.Listen("unix", addr)
if err != nil {
h.log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
slog.F("socket_path", addr),
slog.Error(err),
)
return false, nil
}
// The listener needs to successfully start before it can be added to
// the map, so we don't have to worry about checking for an existing
// listener.
//
// This is also what the upstream TCP version of this code does.
h.Lock()
h.forwards[addr] = ln
h.Unlock()
ctx, cancel := context.WithCancel(ctx)
go func() {
<-ctx.Done()
_ = ln.Close()
}()
go func() {
defer cancel()
for {
c, err := ln.Accept()
if err != nil {
if !xerrors.Is(err, net.ErrClosed) {
h.log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
slog.F("socket_path", addr),
slog.Error(err),
)
}
// closed below
break
}
payload := gossh.Marshal(&forwardedStreamLocalPayload{
SocketPath: addr,
})
go func() {
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
if err != nil {
h.log.Warn(ctx, "open SSH channel to forward Unix connection to client",
slog.F("socket_path", addr),
slog.Error(err),
)
_ = c.Close()
return
}
go gossh.DiscardRequests(reqs)
Bicopy(ctx, ch, c)
}()
}
h.Lock()
ln2, ok := h.forwards[addr]
if ok && ln2 == ln {
delete(h.forwards, addr)
}
h.Unlock()
_ = ln.Close()
}()
return true, nil
case "cancel-streamlocal-forward@openssh.com":
var reqPayload streamLocalForwardPayload
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err))
return false, nil
}
h.Lock()
ln, ok := h.forwards[reqPayload.SocketPath]
h.Unlock()
if ok {
_ = ln.Close()
}
return true, nil
default:
return false, nil
}
}
// directStreamLocalPayload describes the extra data sent in a
// direct-streamlocal@openssh.com channel request containing the socket path.
type directStreamLocalPayload struct {
SocketPath string
Reserved1 string
Reserved2 uint32
}
func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
var reqPayload directStreamLocalPayload
err := gossh.Unmarshal(newChan.ExtraData(), &reqPayload)
if err != nil {
_ = newChan.Reject(gossh.ConnectionFailed, "could not parse direct-streamlocal@openssh.com channel payload")
return
}
var dialer net.Dialer
dconn, err := dialer.DialContext(ctx, "unix", reqPayload.SocketPath)
if err != nil {
_ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", reqPayload.SocketPath, err.Error()))
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
_ = dconn.Close()
return
}
go gossh.DiscardRequests(reqs)
Bicopy(ctx, ch, dconn)
}

View File

@ -575,14 +575,17 @@ func TestServer(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
httpListenAddr := ""
if c.httpListener {
httpListenAddr = ":0"
}
certPath, keyPath := generateTLSCertificate(t)
flags := []string{
"server",
"--in-memory",
"--cache-dir", t.TempDir(),
}
if c.httpListener {
flags = append(flags, "--http-address", ":0")
"--http-address", httpListenAddr,
}
if c.tlsListener {
flags = append(flags,

View File

@ -1,12 +1,15 @@
package cli
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
@ -21,6 +24,7 @@ import (
"golang.org/x/term"
"golang.org/x/xerrors"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/autobuild/notify"
@ -39,6 +43,7 @@ func ssh() *cobra.Command {
stdio bool
shuffle bool
forwardAgent bool
forwardGPG bool
identityAgent string
wsPollInterval time.Duration
)
@ -138,7 +143,7 @@ func ssh() *cobra.Command {
if forwardAgent && identityAgent != "" {
err = gosshagent.ForwardToRemote(sshClient, identityAgent)
if err != nil {
return xerrors.Errorf("forward agent failed: %w", err)
return xerrors.Errorf("forward agent: %w", err)
}
err = gosshagent.RequestAgentForwarding(sshSession)
if err != nil {
@ -146,6 +151,22 @@ func ssh() *cobra.Command {
}
}
if forwardGPG {
if workspaceAgent.OperatingSystem == "windows" {
return xerrors.New("GPG forwarding is not supported for Windows workspaces")
}
err = uploadGPGKeys(ctx, sshClient)
if err != nil {
return xerrors.Errorf("upload GPG public keys and ownertrust to workspace: %w", err)
}
closer, err := forwardGPGAgent(ctx, cmd.ErrOrStderr(), sshClient)
if err != nil {
return xerrors.Errorf("forward GPG socket: %w", err)
}
defer closer.Close()
}
stdoutFile, validOut := cmd.OutOrStdout().(*os.File)
stdinFile, validIn := cmd.InOrStdin().(*os.File)
if validOut && validIn && isatty.IsTerminal(stdoutFile.Fd()) {
@ -199,10 +220,12 @@ func ssh() *cobra.Command {
_ = sshSession.WindowChange(height, width)
}
}
err = sshSession.Wait()
if err != nil {
// If the connection drops unexpectedly, we get an ExitMissingError but no other
// error details, so try to at least give the user a better message
// If the connection drops unexpectedly, we get an
// ExitMissingError but no other error details, so try to at
// least give the user a better message
if errors.Is(err, &gossh.ExitMissingError{}) {
return xerrors.New("SSH connection ended unexpectedly")
}
@ -216,6 +239,7 @@ func ssh() *cobra.Command {
cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace")
_ = cmd.Flags().MarkHidden("shuffle")
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
cliflag.BoolVarP(cmd.Flags(), &forwardGPG, "forward-gpg", "G", "CODER_SSH_FORWARD_GPG", false, "Specifies whether to forward the GPG agent. Unsupported on Windows workspaces, but supports all clients. Requires gnupg (gpg, gpgconf) on both the client and workspace. The GPG agent must already be running locally and will not be started for you. If a GPG agent is already running in the workspace, it will be attempted to be killed.")
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
return cmd
@ -364,3 +388,184 @@ func verifyWorkspaceOutdated(client *codersdk.Client, workspace codersdk.Workspa
func buildWorkspaceLink(serverURL *url.URL, workspace codersdk.Workspace) *url.URL {
return serverURL.ResolveReference(&url.URL{Path: fmt.Sprintf("@%s/%s", workspace.OwnerName, workspace.Name)})
}
// runLocal runs a command on the local machine.
func runLocal(ctx context.Context, stdin io.Reader, name string, args ...string) ([]byte, error) {
cmd := exec.CommandContext(ctx, name, args...)
cmd.Stdin = stdin
out, err := cmd.Output()
if err != nil {
var stderr []byte
if exitErr := new(exec.ExitError); errors.As(err, &exitErr) {
stderr = exitErr.Stderr
}
return out, xerrors.Errorf(
"`%s %s` failed: stderr: %s\n\nstdout: %s\n\n%w",
name,
strings.Join(args, " "),
bytes.TrimSpace(stderr),
bytes.TrimSpace(out),
err,
)
}
return out, nil
}
// runRemoteSSH runs a command on a remote machine/workspace via SSH.
func runRemoteSSH(sshClient *gossh.Client, stdin io.Reader, cmd string) ([]byte, error) {
sess, err := sshClient.NewSession()
if err != nil {
return nil, xerrors.Errorf("create SSH session")
}
defer sess.Close()
stderr := bytes.NewBuffer(nil)
sess.Stdin = stdin
sess.Stderr = stderr
out, err := sess.Output(cmd)
if err != nil {
return out, xerrors.Errorf(
"`%s` failed: stderr: %s\n\nstdout: %s:\n\n%w",
cmd,
bytes.TrimSpace(stderr.Bytes()),
bytes.TrimSpace(out),
err,
)
}
return out, nil
}
func uploadGPGKeys(ctx context.Context, sshClient *gossh.Client) error {
// Check if the agent is running in the workspace already.
//
// Note: we don't support windows in the workspace for GPG forwarding so
// using shell commands is fine.
//
// Note: we sleep after killing the agent because it doesn't always die
// immediately.
agentSocketBytes, err := runRemoteSSH(sshClient, nil, `
set -eux
agent_socket=$(gpgconf --list-dir agent-socket)
echo "$agent_socket"
if [ -S "$agent_socket" ]; then
echo "agent socket exists, attempting to kill it" >&2
gpgconf --kill gpg-agent
rm -f "$agent_socket"
sleep 1
fi
test ! -S "$agent_socket"
`)
agentSocket := strings.TrimSpace(string(agentSocketBytes))
if err != nil {
return xerrors.Errorf("check if agent socket is running (check if %q exists): %w", agentSocket, err)
}
if agentSocket == "" {
return xerrors.Errorf("agent socket path is empty, check the output of `gpgconf --list-dir agent-socket`")
}
// Read the user's public keys and ownertrust from GPG.
pubKeyExport, err := runLocal(ctx, nil, "gpg", "--armor", "--export")
if err != nil {
return xerrors.Errorf("export local public keys from GPG: %w", err)
}
ownerTrustExport, err := runLocal(ctx, nil, "gpg", "--export-ownertrust")
if err != nil {
return xerrors.Errorf("export local ownertrust from GPG: %w", err)
}
// Import the public keys and ownertrust into the workspace.
_, err = runRemoteSSH(sshClient, bytes.NewReader(pubKeyExport), "gpg --import")
if err != nil {
return xerrors.Errorf("import public keys into workspace: %w", err)
}
_, err = runRemoteSSH(sshClient, bytes.NewReader(ownerTrustExport), "gpg --import-ownertrust")
if err != nil {
return xerrors.Errorf("import ownertrust into workspace: %w", err)
}
// Kill the agent in the workspace if it was started by one of the above
// commands.
_, err = runRemoteSSH(sshClient, nil, fmt.Sprintf("gpgconf --kill gpg-agent && rm -f %q", agentSocket))
if err != nil {
return xerrors.Errorf("kill existing agent in workspace: %w", err)
}
return nil
}
func localGPGExtraSocket(ctx context.Context) (string, error) {
localSocket, err := runLocal(ctx, nil, "gpgconf", "--list-dir", "agent-extra-socket")
if err != nil {
return "", xerrors.Errorf("get local GPG agent socket: %w", err)
}
return string(bytes.TrimSpace(localSocket)), nil
}
func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
remoteSocket, err := runRemoteSSH(sshClient, nil, "gpgconf --list-dir agent-socket")
if err != nil {
return "", xerrors.Errorf("get remote GPG agent socket: %w", err)
}
return string(bytes.TrimSpace(remoteSocket)), nil
}
// cookieAddr is a special net.Addr accepted by sshForward() which includes a
// cookie which is written to the connection before forwarding.
type cookieAddr struct {
net.Addr
cookie []byte
}
// sshForwardRemote starts forwarding connections from a remote listener to a
// local address via SSH in a goroutine.
//
// Accepts a `cookieAddr` as the local address.
func sshForwardRemote(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
if err != nil {
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
}
go func() {
for {
remoteConn, err := listener.Accept()
if err != nil {
if ctx.Err() == nil {
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
}
return
}
go func() {
defer remoteConn.Close()
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
if err != nil {
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
return
}
defer localConn.Close()
if c, ok := localAddr.(cookieAddr); ok {
_, err = localConn.Write(c.cookie)
if err != nil {
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
return
}
}
agent.Bicopy(ctx, localConn, remoteConn)
}()
}
}()
return listener, nil
}

View File

@ -5,9 +5,12 @@ package cli
import (
"context"
"io"
"net"
"os"
"os/signal"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/sys/unix"
)
@ -20,3 +23,26 @@ func listenWindowSize(ctx context.Context) <-chan os.Signal {
}()
return windowSize
}
func forwardGPGAgent(ctx context.Context, stderr io.Writer, sshClient *gossh.Client) (io.Closer, error) {
localSocket, err := localGPGExtraSocket(ctx)
if err != nil {
return nil, err
}
remoteSocket, err := remoteGPGAgentSocket(sshClient)
if err != nil {
return nil, err
}
localAddr := &net.UnixAddr{
Name: localSocket,
Net: "unix",
}
remoteAddr := &net.UnixAddr{
Name: remoteSocket,
Net: "unix",
}
return sshForwardRemote(ctx, stderr, sshClient, localAddr, remoteAddr)
}

View File

@ -1,15 +1,20 @@
package cli_test
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
@ -27,6 +32,7 @@ import (
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/pty"
"github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/testutil"
)
@ -226,7 +232,7 @@ func TestSSH(t *testing.T) {
})
// Start up ssh agent listening on unix socket.
tmpdir := t.TempDir()
tmpdir := tempDirUnixSocket(t)
agentSock := filepath.Join(tmpdir, "agent.sock")
l, err := net.Listen("unix", agentSock)
require.NoError(t, err)
@ -283,6 +289,224 @@ func TestSSH(t *testing.T) {
pty.WriteLine("exit")
<-cmdDone
})
//nolint:paralleltest // This test uses t.Setenv.
t.Run("ForwardGPG", func(t *testing.T) {
if runtime.GOOS == "windows" {
// While GPG forwarding from a Windows client works, we currently do
// not support forwarding to a Windows workspace. Our tests use the
// same platform for the "client" and "workspace" as they run in the
// same process.
t.Skip("Test not supported on windows")
}
// This key is for dean@coder.com.
const randPublicKeyFingerprint = "7BDFBA0CC7F5A96537C806C427BC6335EB5117F1"
const randPublicKey = `-----BEGIN PGP PUBLIC KEY BLOCK-----
mQINBF6SWkEBEADB8sAhBaT36VQ6HEhAmtKexLldu1HUdXNw16rdF+1wiBzSFfJN
aPeX4Y9iFIZgC2wU0wOjJ04BpioyOLtJngbThI5WpeoQ/1yQZOpnDaCMPPLp+uJ+
Gy4tMZYWQq21PukrFm3XDRGKjVN58QN6uCPb1S/YzteP8Epmq590GYIYLiAHnMt6
5iyxIFhXj/fq5Fddp2+efI7QWvNl2wTNnCaTziOSKYcbNmQpn9gy0WvKktWYtB8E
JJtWES0DzgCnDpm/hYx79Wkb+F7qY54y2uauDx+z97QXrON47lsIyGm8/T59ZfSd
/yrBqDLHYrHlt9RkFpAnBzO402y2eHsKTB6/EAHv9H2apxahyJlcxGbE5QE+fOJk
LdPlako0cSljz0g9Icesr2nZL0MhWwLnwk7DHkg/PUUijkbuR/TD9dti2/yOTFrf
Y7DdZpoZ0ZkcGu9lMh2vOTWc96RNCyIZfE5WNDKKo+u5Txzndsc/qIgKohwDSxTC
3hAulG5Wt05UeyHBEAAvGV2szG88VsGwd1juqXAbEzk+kLQzNyoQX188/4V4X+MV
pY9Wz7JudmQpB/3+YTcA/ziK/+wu3c2wNlr7gMZYMOwDWTLfW64nux7zHWDytrP0
HfgJIgqP7F7SnChpTFdb1hr1WDox99ZG+/eDkwxnuXYWm9xx5/crqQ0POQARAQAB
tClEZWFuIFNoZWF0aGVyICh3b3JrIGtleSkgPGRlYW5AY29kZXIuY29tPokCVAQT
AQgAPhYhBHvfugzH9allN8gGxCe8YzXrURfxBQJeklpBAhsDBQkJZgGABQsJCAcC
BhUKCQgLAgQWAgMBAh4BAheAAAoJECe8YzXrURfxIVkP/3UJMzvIjTNF63WiK4xk
TXlBbPKodnzUmAJ+8DVXmJMJpNsSI2czw6eFUXMcrT3JMlviOXhRWMLHr2FsQhyS
AJOQo0x9z7nntPIkvj96ihCdgRn7VN1WzaMwOOesGPr57StWLE84bg9/R0aSsxtX
LgfBCyNkv6FFlruhnw8+JdZJEjvIXQ9swvwD6L68ZLWIWcdnj/CjQmnmgFA+O4UO
SFXMUjklbrq8mJ0sAPUUATJK0SOTyqkZPkhqjlTZa8p0XoJF25trhwLhzDi4GPR6
SK/9SkqB/go9ZwkNZOjs2tP7eMExy4zQ21MFH09JMKQB7H5CG8GwdMwz4+VKc9aP
y9Ncova/p7Y8kJ7oQPWhACJT1jMP6620oC2N/7wwS0Vtc6E9LoPrfXC2TtvOA9qx
aOf6riWSjo8BEcXDuMtlW4g6IQFNd0+wcgcKrAd+vPLZnG4rtYL0Etdd1ymBT4pi
5E5uT8oUT9rLHX+2tD/E8SE5PzsaKEOJKzcOB8ESb3YBGic7+VvX/AuJuSFsuWnZ
FqAUENqfdz6+0dEJe1pfWyje+Q+o7B7u+ffMT4dOQOC8NfHFnz1kU+DA3VDE6xsu
3YN1L8KlYON92s9VWDA8VuvmU2d9pq5ysUeg133ftDSwj3X+5GYcBv4VFcSRCBW5
w0hDpMDun1t8xcXdo1LQ4R4NuQINBF6SWkEBEADF4Nrhlqc5M3Sz9sNHDJZR68zb
4CjkoOpYwsKj/ZCukzRCGKpT5Agn0zOycUjbAyCZVjREeIRRURyAhfpOmZY5yF6b
PD93+04OzWk1AaDRmMfvi1Crn/WUEVHIbDaisxDzNuAJgLrt93I/lOz06GczhCb6
sPBeKuaXCLl/5LSwTahGWsweeSCmfyrYsOc11T+SjdyWXWXEpzFNNIhvqiEoJCw3
IcdktTBJYuHsN4jh5kVemi/ttqRN3z7rBMKR1sPG3ux1MfCfSTSCeZLTN9eVvqm9
ne8brk8ZC6sdwlZ9IofPbmSaAh+F5Kfcnd3KjmyQ63t+8plpJ2YH3Fx6IwTwVEQ8
Ii3WQInTpBSPqf0EwnzRBvhYeKusRpcmX3JSmosLbd5uhvJdgotzuwZYzgay/6DL
OlwElZ//ecXNhU8iYmx1BwNuquvGcGVpkP5eaaT6O9qDznB7TT0xztfAK0LaAuRJ
HOFCc8iiHtQ4o0OkRhg/0KkUGBU5Iw5SIDimkgwJMtD3ZiYOqLaXS6kmmVw2u6YD
LB8rTpegz/tcX+4uyfnIZ28JCOYFTeaDT4FixFW2hrfo/VJzMI5IIv9XAAmtAiEU
f+CY2BT6kg9NkQuke0p4/W8yTaScapYZa5I2bzFpJJyzh1TKE6x3qcbBs9vVX+6E
vK4FflNwu9WSWojO2wARAQABiQI8BBgBCAAmFiEEe9+6DMf1qWU3yAbEJ7xjNetR
F/EFAl6SWkECGwwFCQlmAYAACgkQJ7xjNetRF/FpnQ//SIYePQzhvWj9drnT2krG
dUGSxCN0pA2UQZNkreAaKmyxn2/6xEdxYSz0iUEk+I0HKay+NLCxJ5PDoDBypFtM
f0yOnbWRObhim8HmED4JRw678G4hRU7KEN0L/9SUYlsBNbgr1xYM/CUX/Ih9NT+P
eApxs2VgjKii6m81nfBCFpWSxAs+TOnbshp8dlDZk9kxjFH9+h1ffgZjntqeyiWe
F1UE1Wh32MbJdtc2Y3mrA6i+7+3OXmqMHoiG1obhISgdpaCJ/ub3ywnAmeXSiAKE
IuS6CriR71Wqv8LMQ8kPM8On9Q26d1dsKKBnlFop9oexxf1AFsbbf9gkcgb+uNno
1Qr/R6l2H1TcV1gmiyQLzVnkgLRORosLvSlFrisrsLv9uTYYgcGvwKiU/o3PTdQg
fv0D7LB+a3C9KsCBFjihW3bTOcHKX2sAWEQXZMtKGf5aNTBmWQ+eKWUGpudXIvLE
od5lgfk9p8T1R50KDieG/+2X95zxFSYBoPRAfp7JNT7h+TZ55qUmQXZGI1VqhWiq
b6y/yqfI17JCm4oWpXYbgeruLuye2c/ptDc3S3d26hbWYiWKVT4bLtUGR0wuE6lS
DK0u4LK+mnrYfIvRDYJGx18/nbLpR+ivWLIssJT2Jyyj8w9+hk10XkODySNjHCxj
p7KeSZdlk47pMBGOfnvEmoQ=
=OxHv
-----END PGP PUBLIC KEY BLOCK-----`
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
gpgPath, err := exec.LookPath("gpg")
if err != nil {
t.Skip("gpg not found")
}
gpgConfPath, err := exec.LookPath("gpgconf")
if err != nil {
t.Skip("gpgconf not found")
}
gpgAgentPath, err := exec.LookPath("gpg-agent")
if err != nil {
t.Skip("gpg-agent not found")
}
// Setup GPG home directory on the "client".
gnupgHomeClient := tempDirUnixSocket(t)
t.Setenv("GNUPGHOME", gnupgHomeClient)
// Get the agent extra socket path.
var (
stdout = bytes.NewBuffer(nil)
stderr = bytes.NewBuffer(nil)
)
c := exec.CommandContext(ctx, gpgConfPath, "--list-dir", "agent-extra-socket")
c.Stdout = stdout
c.Stderr = stderr
err = c.Run()
require.NoError(t, err, "get extra socket path failed: %s", stderr.String())
extraSocketPath := strings.TrimSpace(stdout.String())
// Generate private key non-interactively.
genKeyScript := `
Key-Type: 1
Key-Length: 2048
Subkey-Type: 1
Subkey-Length: 2048
Name-Real: Coder Test
Name-Email: test@coder.com
Expire-Date: 0
%no-protection
`
c = exec.CommandContext(ctx, gpgPath, "--batch", "--gen-key")
c.Stdin = strings.NewReader(genKeyScript)
out, err := c.CombinedOutput()
require.NoError(t, err, "generate key failed: %s", out)
// Import a random public key.
stdin := strings.NewReader(randPublicKey + "\n")
c = exec.CommandContext(ctx, gpgPath, "--import", "-")
c.Stdin = stdin
out, err = c.CombinedOutput()
require.NoError(t, err, "import key failed: %s", out)
// Set ultimate trust on imported key.
stdin = strings.NewReader(randPublicKeyFingerprint + ":6:\n")
c = exec.CommandContext(ctx, gpgPath, "--import-ownertrust")
c.Stdin = stdin
out, err = c.CombinedOutput()
require.NoError(t, err, "import ownertrust failed: %s", out)
// Start the GPG agent.
agentCmd := exec.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath)
agentCmd.Env = append(agentCmd.Env, "GNUPGHOME="+gnupgHomeClient)
agentPTY, agentProc, err := pty.Start(agentCmd, pty.WithPTYOption(pty.WithGPGTTY()))
require.NoError(t, err, "launch agent failed")
defer func() {
_ = agentProc.Kill()
_ = agentPTY.Close()
}()
// Get the agent socket path in the "workspace".
gnupgHomeWorkspace := tempDirUnixSocket(t)
stdout = bytes.NewBuffer(nil)
stderr = bytes.NewBuffer(nil)
c = exec.CommandContext(ctx, gpgConfPath, "--list-dir", "agent-socket")
c.Env = append(c.Env, "GNUPGHOME="+gnupgHomeWorkspace)
c.Stdout = stdout
c.Stderr = stderr
err = c.Run()
require.NoError(t, err, "get agent socket path in workspace failed: %s", stderr.String())
workspaceAgentSocketPath := strings.TrimSpace(stdout.String())
require.NotEqual(t, extraSocketPath, workspaceAgentSocketPath, "socket path should be different")
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
agentClient := codersdk.New(client.URL)
agentClient.SetSessionToken(agentToken)
agentCloser := agent.New(agent.Options{
Client: agentClient,
EnvironmentVariables: map[string]string{
"GNUPGHOME": gnupgHomeWorkspace,
},
Logger: slogtest.Make(t, nil).Named("agent"),
})
defer agentCloser.Close()
cmd, root := clitest.New(t,
"ssh",
workspace.Name,
"--forward-gpg",
)
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
cmd.SetErr(pty.Output())
cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx)
assert.NoError(t, err, "ssh command failed")
})
// Prevent the test from hanging if the asserts below kill the test
// early. This will cause the command to exit with an error, which will
// let the t.Cleanup'd `<-done` inside of `tGo` exit and not hang.
// Without this, the test will hang forever on failure, preventing the
// real error from being printed.
t.Cleanup(cancel)
pty.WriteLine("echo hello 'world'")
pty.ExpectMatch("hello world")
// Check the GNUPGHOME was correctly inherited via shell.
pty.WriteLine("env && echo env-''-command-done")
match := pty.ExpectMatch("env--command-done")
require.Contains(t, match, "GNUPGHOME="+gnupgHomeWorkspace, match)
// Get the agent extra socket path in the "workspace" via shell.
pty.WriteLine("gpgconf --list-dir agent-socket && echo gpgconf-''-agentsocket-command-done")
pty.ExpectMatch(workspaceAgentSocketPath)
pty.ExpectMatch("gpgconf--agentsocket-command-done")
// List the keys in the "workspace".
pty.WriteLine("gpg --list-keys && echo gpg-''-listkeys-command-done")
listKeysOutput := pty.ExpectMatch("gpg--listkeys-command-done")
require.Contains(t, listKeysOutput, "[ultimate] Coder Test <test@coder.com>")
require.Contains(t, listKeysOutput, "[ultimate] Dean Sheather (work key) <dean@coder.com>")
// Try to sign something. This demonstrates that the forwarding is
// working as expected, since the workspace doesn't have access to the
// private key directly and must use the forwarded agent.
pty.WriteLine("echo 'hello world' | gpg --clearsign && echo gpg-''-sign-command-done")
pty.ExpectMatch("BEGIN PGP SIGNED MESSAGE")
pty.ExpectMatch("Hash:")
pty.ExpectMatch("hello world")
pty.ExpectMatch("gpg--sign-command-done")
// And we're done.
pty.WriteLine("exit")
<-cmdDone
})
}
// tGoContext runs fn in a goroutine passing a context that will be
@ -356,3 +580,26 @@ func (*stdioConn) SetReadDeadline(_ time.Time) error {
func (*stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
// tempDirUnixSocket returns a temporary directory that can safely hold unix
// sockets (probably).
//
// During tests on darwin we hit the max path length limit for unix sockets
// pretty easily in the default location, so this function uses /tmp instead to
// get shorter paths.
func tempDirUnixSocket(t *testing.T) string {
t.Helper()
if runtime.GOOS == "darwin" {
testName := strings.ReplaceAll(t.Name(), "/", "_")
dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("coder-test-%s-", testName))
require.NoError(t, err, "create temp dir for gpg test")
t.Cleanup(func() {
err := os.RemoveAll(dir)
assert.NoError(t, err, "remove temp dir", dir)
})
return dir
}
return t.TempDir()
}

View File

@ -4,9 +4,16 @@
package cli
import (
"bufio"
"context"
"io"
"net"
"os"
"strconv"
"time"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
)
func listenWindowSize(ctx context.Context) <-chan os.Signal {
@ -25,3 +32,74 @@ func listenWindowSize(ctx context.Context) <-chan os.Signal {
}()
return windowSize
}
func forwardGPGAgent(ctx context.Context, stderr io.Writer, sshClient *gossh.Client) (io.Closer, error) {
// Read TCP port and cookie from extra socket file. A gpg-agent socket
// file looks like the following:
//
// 49955
// abcdefghijklmnop
//
// The first line is the TCP port that gpg-agent is listening on, and
// the second line is a 16 byte cookie that MUST be sent as the first
// bytes of any connection to this port (otherwise the connection is
// closed by gpg-agent).
localSocket, err := localGPGExtraSocket(ctx)
if err != nil {
return nil, err
}
f, err := os.Open(localSocket)
if err != nil {
return nil, xerrors.Errorf("open gpg-agent-extra socket file %q: %w", localSocket, err)
}
// Scan lines from file to get port and cookie.
var (
port uint16
cookie []byte
scanner = bufio.NewScanner(f)
)
for i := 0; scanner.Scan(); i++ {
switch i {
case 0:
port64, err := strconv.ParseUint(scanner.Text(), 10, 16)
if err != nil {
return nil, xerrors.Errorf("parse gpg-agent-extra socket file %q: line 1: convert string to integer: %w", localSocket, err)
}
port = uint16(port64)
case 1:
cookie = scanner.Bytes()
if len(cookie) != 16 {
return nil, xerrors.Errorf("parse gpg-agent-extra socket file %q: line 2: expected 16 bytes, got %v bytes", localSocket, len(cookie))
}
default:
return nil, xerrors.Errorf("parse gpg-agent-extra socket file %q: file contains more than 2 lines", localSocket)
}
}
err = scanner.Err()
if err != nil {
return nil, xerrors.Errorf("parse gpg-agent-extra socket file: %q: %w", localSocket, err)
}
remoteSocket, err := remoteGPGAgentSocket(sshClient)
if err != nil {
return nil, err
}
localAddr := cookieAddr{
Addr: &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: int(port),
},
cookie: cookie,
}
remoteAddr := &net.UnixAddr{
Name: remoteSocket,
Net: "unix",
}
return sshForwardRemote(ctx, stderr, sshClient, localAddr, remoteAddr)
}

View File

@ -7,6 +7,14 @@ Flags:
-A, --forward-agent Specifies whether to forward the SSH agent
specified in $SSH_AUTH_SOCK.
Consumes $CODER_SSH_FORWARD_AGENT
-G, --forward-gpg Specifies whether to forward the GPG agent.
Unsupported on Windows workspaces, but supports all
clients. Requires gnupg (gpg, gpgconf) on both the
client and workspace. The GPG agent must already be
running locally and will not be started for you. If
a GPG agent is already running in the workspace, it
will be attempted to be killed.
Consumes $CODER_SSH_FORWARD_GPG
-h, --help help for ssh
--identity-agent string Specifies which identity agent to use (overrides
$SSH_AUTH_SOCK), forward agent must also be

View File

@ -61,8 +61,9 @@ type WithFlags interface {
type Option func(*ptyOptions)
type ptyOptions struct {
logger *log.Logger
sshReq *ssh.Pty
logger *log.Logger
sshReq *ssh.Pty
setGPGTTY bool
}
// WithSSHRequest applies the ssh.Pty request to the PTY.
@ -81,6 +82,14 @@ func WithLogger(logger *log.Logger) Option {
}
}
// WithGPGTTY sets the GPG_TTY environment variable to the PTY name. This only
// applies to non-Windows platforms.
func WithGPGTTY() Option {
return func(opts *ptyOptions) {
opts.setGPGTTY = true
}
}
// New constructs a new Pty.
func New(opts ...Option) (PTY, error) {
return newPty(opts...)

View File

@ -27,6 +27,9 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process
if opty.opts.sshReq != nil {
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name()))
}
if opty.opts.setGPGTTY {
cmd.Env = append(cmd.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name()))
}
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,

View File

@ -121,7 +121,7 @@ fatal() {
trap 'fatal "Script encountered an error"' ERR
cdroot
start_cmd API "" "${CODER_DEV_SHIM}" server --http-address 0.0.0.0:3000 --swagger-enable
start_cmd API "" "${CODER_DEV_SHIM}" server --http-address 0.0.0.0:3000 --swagger-enable --access-url "http://127.0.0.1:3000"
echo '== Waiting for Coder to become ready'
# Start the timeout in the background so interrupting this script