mirror of https://github.com/coder/coder.git
feat: Improve experience with local SSH keys (#3835)
* feat: Improve experience with local SSH keys This change means that users can place SSH keys in the default locations for OpenSSH, like `~/.ssh/id_rsa` and it will be automatically picked up (as per a default OpenSSH experience). Fixes #3126 * fix: Ensure gitssh cleans up temporary file on interrupt Co-authored-by: Dean Sheather <dean@deansheather.com>
This commit is contained in:
parent
66ad86a755
commit
d0b02e581d
125
cli/gitssh.go
125
cli/gitssh.go
|
@ -1,9 +1,15 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -13,16 +19,30 @@ import (
|
|||
)
|
||||
|
||||
func gitssh() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
cmd := &cobra.Command{
|
||||
Use: "gitssh",
|
||||
Hidden: true,
|
||||
Short: `Wraps the "ssh" command and uses the coder gitssh key for authentication`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
env := os.Environ()
|
||||
|
||||
// Catch interrupt signals to ensure the temporary private
|
||||
// key file is cleaned up on most cases.
|
||||
ctx, stop := signal.NotifyContext(ctx, interruptSignals...)
|
||||
defer stop()
|
||||
|
||||
// Early check so errors are reported immediately.
|
||||
identityFiles, err := parseIdentityFilesForHost(ctx, args, env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := createAgentClient(cmd)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
key, err := client.AgentGitSSHKey(cmd.Context())
|
||||
key, err := client.AgentGitSSHKey(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get agent git ssh token: %w", err)
|
||||
}
|
||||
|
@ -44,8 +64,23 @@ func gitssh() *cobra.Command {
|
|||
return xerrors.Errorf("close temp gitsshkey file: %w", err)
|
||||
}
|
||||
|
||||
args = append([]string{"-i", privateKeyFile.Name()}, args...)
|
||||
c := exec.CommandContext(cmd.Context(), "ssh", args...)
|
||||
// Append our key, giving precedence to user keys. Note that
|
||||
// OpenSSH server are typically configured with MaxAuthTries
|
||||
// set to the default value of 6. This means that only the 6
|
||||
// first keys can be tried. However, we will assume that if
|
||||
// a user has configured 6+ keys for a host, they know what
|
||||
// they're doing. This behavior is critical if a server has
|
||||
// been configured with MaxAuthTries set to 1.
|
||||
identityFiles = append(identityFiles, privateKeyFile.Name())
|
||||
|
||||
var identityArgs []string
|
||||
for _, id := range identityFiles {
|
||||
identityArgs = append(identityArgs, "-i", id)
|
||||
}
|
||||
|
||||
args = append(identityArgs, args...)
|
||||
c := exec.CommandContext(ctx, "ssh", args...)
|
||||
c.Env = append(c.Env, env...)
|
||||
c.Stderr = cmd.ErrOrStderr()
|
||||
c.Stdout = cmd.OutOrStdout()
|
||||
c.Stdin = cmd.InOrStdin()
|
||||
|
@ -69,4 +104,86 @@ func gitssh() *cobra.Command {
|
|||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// fallbackIdentityFiles is the list of identity files SSH tries when
|
||||
// none have been defined for a host.
|
||||
var fallbackIdentityFiles = strings.Join([]string{
|
||||
"identityfile ~/.ssh/id_rsa",
|
||||
"identityfile ~/.ssh/id_dsa",
|
||||
"identityfile ~/.ssh/id_ecdsa",
|
||||
"identityfile ~/.ssh/id_ecdsa_sk",
|
||||
"identityfile ~/.ssh/id_ed25519",
|
||||
"identityfile ~/.ssh/id_ed25519_sk",
|
||||
"identityfile ~/.ssh/id_xmss",
|
||||
}, "\n")
|
||||
|
||||
// parseIdentityFilesForHost uses ssh -G to discern what SSH keys have
|
||||
// been enabled for the host (via the users SSH config) and returns a
|
||||
// list of existing identity files.
|
||||
//
|
||||
// We do this because when no keys are defined for a host, SSH uses
|
||||
// fallback keys (see above). However, by passing `-i` to attach our
|
||||
// private key, we're effectively disabling the fallback keys.
|
||||
//
|
||||
// Example invocation:
|
||||
//
|
||||
// ssh -G -o SendEnv=GIT_PROTOCOL git@github.com git-upload-pack 'coder/coder'
|
||||
//
|
||||
// The extra arguments work without issue and lets us run the command
|
||||
// as-is without stripping out the excess (git-upload-pack 'coder/coder').
|
||||
func parseIdentityFilesForHost(ctx context.Context, args, env []string) (identityFiles []string, error error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get user home dir failed: %w", err)
|
||||
}
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
var r io.Reader = &outBuf
|
||||
|
||||
args = append([]string{"-G"}, args...)
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
cmd.Env = append(cmd.Env, env...)
|
||||
cmd.Stdout = &outBuf
|
||||
cmd.Stderr = io.Discard
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
// If ssh -G failed, the SSH version is likely too old, fallback
|
||||
// to using the default identity files.
|
||||
r = strings.NewReader(fallbackIdentityFiles)
|
||||
}
|
||||
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
line := s.Text()
|
||||
if strings.HasPrefix(line, "identityfile ") {
|
||||
id := strings.TrimPrefix(line, "identityfile ")
|
||||
if strings.HasPrefix(id, "~/") {
|
||||
id = home + id[1:]
|
||||
}
|
||||
// OpenSSH on Windows is weird, it supports using (and does
|
||||
// use) mixed \ and / in paths.
|
||||
//
|
||||
// Example: C:\Users\ZeroCool/.ssh/known_hosts
|
||||
//
|
||||
// To check the file existence in Go, though, we want to use
|
||||
// proper Windows paths.
|
||||
// OpenSSH is amazing, this will work on Windows too:
|
||||
// C:\Users\ZeroCool/.ssh/id_rsa
|
||||
id = filepath.FromSlash(id)
|
||||
|
||||
// Only include the identity file if it exists.
|
||||
if _, err := os.Stat(id); err == nil {
|
||||
identityFiles = append(identityFiles, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := s.Err(); err != nil {
|
||||
// This should never happen, the check is for completeness.
|
||||
return nil, xerrors.Errorf("scan ssh output: %w", err)
|
||||
}
|
||||
|
||||
return identityFiles, nil
|
||||
}
|
||||
|
|
|
@ -2,8 +2,16 @@ package cli_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
|
@ -17,99 +25,245 @@ import (
|
|||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, string, gossh.PublicKey) {
|
||||
t.Helper()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer t.Cleanup(cancel) // Defer so that cancel is the first cleanup.
|
||||
|
||||
// get user public key
|
||||
keypair, err := client.GitSSHKey(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
//nolint:dogsled
|
||||
pubkey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(keypair.PublicKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
// setup template
|
||||
agentToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "somename",
|
||||
Type: "someinstance",
|
||||
Agents: []*proto.Agent{{
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: agentToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// start workspace agent
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false")
|
||||
agentClient := client
|
||||
clitest.SetupConfig(t, agentClient, root)
|
||||
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
t.Cleanup(func() { require.NoError(t, <-errC) })
|
||||
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
return agentClient, agentToken, pubkey
|
||||
}
|
||||
|
||||
func serveSSHForGitSSH(t *testing.T, handler func(ssh.Session), pubkeys ...gossh.PublicKey) *net.TCPAddr {
|
||||
t.Helper()
|
||||
|
||||
// start ssh server
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = l.Close() })
|
||||
|
||||
serveOpts := []ssh.Option{
|
||||
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
for _, pubkey := range pubkeys {
|
||||
if ssh.KeysEqual(pubkey, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}),
|
||||
}
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
// as long as we get a successful session we don't care if the server errors
|
||||
errC <- ssh.Serve(l, handler, serveOpts...)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
_ = l.Close() // Ensure server shutdown.
|
||||
<-errC
|
||||
})
|
||||
|
||||
// start ssh session
|
||||
addr, ok := l.Addr().(*net.TCPAddr)
|
||||
require.True(t, ok)
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
func writePrivateKeyToFile(t *testing.T, name string, key *ecdsa.PrivateKey) {
|
||||
t.Helper()
|
||||
|
||||
b, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
b = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: b,
|
||||
})
|
||||
|
||||
err = os.WriteFile(name, b, 0o600)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGitSSH(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Dial", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// get user public key
|
||||
keypair, err := client.GitSSHKey(context.Background(), codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
publicKey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(keypair.PublicKey))
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// setup template
|
||||
agentToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "somename",
|
||||
Type: "someinstance",
|
||||
Agents: []*proto.Agent{{
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: agentToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// start workspace agent
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false")
|
||||
agentClient := client
|
||||
clitest.SetupConfig(t, agentClient, root)
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
agentErrC := make(chan error)
|
||||
go func() {
|
||||
agentErrC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
dialer, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
// start ssh server
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
return ssh.KeysEqual(publicKey, key)
|
||||
})
|
||||
client, token, pubkey := prepareTestGitSSH(ctx, t)
|
||||
var inc int64
|
||||
sshErrC := make(chan error)
|
||||
go func() {
|
||||
// as long as we get a successful session we don't care if the server errors
|
||||
_ = ssh.Serve(l, func(s ssh.Session) {
|
||||
atomic.AddInt64(&inc, 1)
|
||||
t.Log("got authenticated session")
|
||||
sshErrC <- s.Exit(0)
|
||||
}, publicKeyOption)
|
||||
}()
|
||||
errC := make(chan error, 1)
|
||||
addr := serveSSHForGitSSH(t, func(s ssh.Session) {
|
||||
atomic.AddInt64(&inc, 1)
|
||||
t.Log("got authenticated session")
|
||||
select {
|
||||
case errC <- s.Exit(0):
|
||||
default:
|
||||
t.Error("error channel is full")
|
||||
}
|
||||
}, pubkey)
|
||||
|
||||
// start ssh session
|
||||
addr, ok := l.Addr().(*net.TCPAddr)
|
||||
require.True(t, ok)
|
||||
// set to agent config dir
|
||||
gitsshCmd, _ := clitest.New(t, "gitssh", "--agent-url", agentClient.URL.String(), "--agent-token", agentToken, "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "-o", "IdentitiesOnly=yes", "127.0.0.1")
|
||||
err = gitsshCmd.ExecuteContext(context.Background())
|
||||
cmd, _ := clitest.New(t,
|
||||
"gitssh",
|
||||
"--agent-url", client.URL.String(),
|
||||
"--agent-token", token,
|
||||
"--",
|
||||
fmt.Sprintf("-p%d", addr.Port),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "IdentitiesOnly=yes",
|
||||
"127.0.0.1",
|
||||
)
|
||||
err := cmd.ExecuteContext(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, inc)
|
||||
|
||||
err = <-sshErrC
|
||||
require.NoError(t, err, "error in ssh session exit")
|
||||
|
||||
cancelFunc()
|
||||
err = <-agentErrC
|
||||
err = <-errC
|
||||
require.NoError(t, err, "error in agent execute")
|
||||
})
|
||||
|
||||
t.Run("Local SSH Keys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
home := t.TempDir()
|
||||
sshdir := filepath.Join(home, ".ssh")
|
||||
err := os.MkdirAll(sshdir, 0o700)
|
||||
require.NoError(t, err)
|
||||
|
||||
idFile := filepath.Join(sshdir, "id_ed25519")
|
||||
privkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
localPubkey, err := gossh.NewPublicKey(&privkey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
writePrivateKeyToFile(t, idFile, privkey)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
client, token, coderPubkey := prepareTestGitSSH(ctx, t)
|
||||
|
||||
authkey := make(chan gossh.PublicKey, 1)
|
||||
addr := serveSSHForGitSSH(t, func(s ssh.Session) {
|
||||
t.Logf("authenticated with: %s", gossh.MarshalAuthorizedKey(s.PublicKey()))
|
||||
select {
|
||||
case authkey <- s.PublicKey():
|
||||
default:
|
||||
t.Error("authkey channel is full")
|
||||
}
|
||||
}, localPubkey, coderPubkey)
|
||||
|
||||
// Create a new config which sets an identity file.
|
||||
config := filepath.Join(sshdir, "config")
|
||||
knownHosts := filepath.Join(sshdir, "known_hosts")
|
||||
err = os.WriteFile(config, []byte(strings.Join([]string{
|
||||
"Host mytest",
|
||||
" HostName 127.0.0.1",
|
||||
fmt.Sprintf(" Port %d", addr.Port),
|
||||
" StrictHostKeyChecking no",
|
||||
" UserKnownHostsFile=" + knownHosts,
|
||||
" IdentitiesOnly yes",
|
||||
" IdentityFile=" + idFile,
|
||||
}, "\n")), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
cmdArgs := []string{
|
||||
"gitssh",
|
||||
"--agent-url", client.URL.String(),
|
||||
"--agent-token", token,
|
||||
"--",
|
||||
"-F", config,
|
||||
"mytest",
|
||||
}
|
||||
// Test authentication via local private key.
|
||||
cmd, _ := clitest.New(t, cmdArgs...)
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
err = cmd.ExecuteContext(ctx)
|
||||
require.NoError(t, err)
|
||||
select {
|
||||
case key := <-authkey:
|
||||
require.Equal(t, localPubkey, key)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for auth")
|
||||
}
|
||||
|
||||
// Delete the local private key.
|
||||
err = os.Remove(idFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// With the local file deleted, the coder key should be used.
|
||||
cmd, _ = clitest.New(t, cmdArgs...)
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
err = cmd.ExecuteContext(ctx)
|
||||
require.NoError(t, err)
|
||||
select {
|
||||
case key := <-authkey:
|
||||
require.Equal(t, coderPubkey, key)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for auth")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue