feat: Add config-ssh command (#735)

* feat: Add config-ssh command

Closes #254 and #499.

* Fix Windows support
This commit is contained in:
Kyle Carberry 2022-03-30 17:59:54 -05:00 committed by GitHub
parent 6ab1a681c4
commit 6612e3c9c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 554 additions and 115 deletions

View File

@ -1,5 +1,6 @@
{
"cSpell.words": [
"cliflag",
"cliui",
"coderd",
"coderdtest",

View File

@ -56,24 +56,24 @@ type agent struct {
sshServer *ssh.Server
}
func (s *agent) run(ctx context.Context) {
func (a *agent) run(ctx context.Context) {
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
peerListener, err = s.clientDialer(ctx, s.options)
peerListener, err = a.clientDialer(ctx, a.options)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
if s.isClosed() {
if a.isClosed() {
return
}
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
a.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
s.options.Logger.Info(context.Background(), "connected")
a.options.Logger.Info(context.Background(), "connected")
break
}
select {
@ -85,40 +85,40 @@ func (s *agent) run(ctx context.Context) {
for {
conn, err := peerListener.Accept()
if err != nil {
if s.isClosed() {
if a.isClosed() {
return
}
s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
s.run(ctx)
a.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
a.run(ctx)
return
}
s.closeMutex.Lock()
s.connCloseWait.Add(1)
s.closeMutex.Unlock()
go s.handlePeerConn(ctx, conn)
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
go a.handlePeerConn(ctx, conn)
}
}
func (s *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go func() {
<-conn.Closed()
s.connCloseWait.Done()
a.connCloseWait.Done()
}()
for {
channel, err := conn.Accept(ctx)
if err != nil {
if errors.Is(err, peer.ErrClosed) || s.isClosed() {
if errors.Is(err, peer.ErrClosed) || a.isClosed() {
return
}
s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
a.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
return
}
switch channel.Protocol() {
case "ssh":
s.sshServer.HandleConn(channel.NetConn())
a.sshServer.HandleConn(channel.NetConn())
default:
s.options.Logger.Warn(ctx, "unhandled protocol from channel",
a.options.Logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
slog.F("label", channel.Label()),
)
@ -126,7 +126,7 @@ func (s *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
}
}
func (s *agent) init(ctx context.Context) {
func (a *agent) init(ctx context.Context) {
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
@ -138,17 +138,17 @@ func (s *agent) init(ctx context.Context) {
if err != nil {
panic(err)
}
sshLogger := s.options.Logger.Named("ssh-server")
sshLogger := a.options.Logger.Named("ssh-server")
forwardHandler := &ssh.ForwardedTCPHandler{}
s.sshServer = &ssh.Server{
a.sshServer = &ssh.Server{
ChannelHandlers: ssh.DefaultChannelHandlers,
ConnectionFailedCallback: func(conn net.Conn, err error) {
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
},
Handler: func(session ssh.Session) {
err := s.handleSSHSession(session)
err := a.handleSSHSession(session)
if err != nil {
s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err))
a.options.Logger.Warn(ctx, "ssh session failed", slog.Error(err))
_ = session.Exit(1)
return
}
@ -177,35 +177,26 @@ func (s *agent) init(ctx context.Context) {
},
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return &gossh.ServerConfig{
Config: gossh.Config{
// "arcfour" is the fastest SSH cipher. We prioritize throughput
// over encryption here, because the WebRTC connection is already
// encrypted. If possible, we'd disable encryption entirely here.
Ciphers: []string{"arcfour"},
},
NoClientAuth: true,
}
},
}
go s.run(ctx)
go a.run(ctx)
}
func (*agent) handleSSHSession(session ssh.Session) error {
func (a *agent) handleSSHSession(session ssh.Session) error {
var (
command string
args = []string{}
err error
)
username := session.User()
if username == "" {
currentUser, err := user.Current()
if err != nil {
return xerrors.Errorf("get current user: %w", err)
}
username = currentUser.Username
currentUser, err := user.Current()
if err != nil {
return xerrors.Errorf("get current user: %w", err)
}
username := currentUser.Username
// gliderlabs/ssh returns a command slice of zero
// when a shell is requested.
@ -249,9 +240,9 @@ func (*agent) handleSSHSession(session ssh.Session) error {
}
go func() {
for win := range windowSize {
err := ptty.Resize(uint16(win.Width), uint16(win.Height))
err = ptty.Resize(uint16(win.Width), uint16(win.Height))
if err != nil {
panic(err)
a.options.Logger.Warn(context.Background(), "failed to resize tty", slog.Error(err))
}
}
}()
@ -286,24 +277,24 @@ func (*agent) handleSSHSession(session ssh.Session) error {
}
// isClosed returns whether the API is closed or not.
func (s *agent) isClosed() bool {
func (a *agent) isClosed() bool {
select {
case <-s.closed:
case <-a.closed:
return true
default:
return false
}
}
func (s *agent) Close() error {
s.closeMutex.Lock()
defer s.closeMutex.Unlock()
if s.isClosed() {
func (a *agent) Close() error {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
return nil
}
close(s.closed)
s.closeCancel()
_ = s.sshServer.Close()
s.connCloseWait.Wait()
close(a.closed)
a.closeCancel()
_ = a.sshServer.Close()
a.connCloseWait.Wait()
return nil
}

View File

@ -39,9 +39,6 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
return nil, xerrors.Errorf("ssh: %w", err)
}
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
Config: ssh.Config{
Ciphers: []string{"arcfour"},
},
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec

View File

@ -27,5 +27,5 @@ func Get(username string) (string, error) {
}
return parts[6], nil
}
return "", xerrors.New("user not found in /etc/passwd and $SHELL not set")
return "", xerrors.Errorf("user %q not found in /etc/passwd", username)
}

View File

@ -3,11 +3,11 @@ package cliui
import (
"context"
"fmt"
"io"
"sync"
"time"
"github.com/briandowns/spinner"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/coder/coder/codersdk"
@ -21,7 +21,7 @@ type AgentOptions struct {
}
// Agent displays a spinning indicator that waits for a workspace agent to connect.
func Agent(cmd *cobra.Command, opts AgentOptions) error {
func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
if opts.FetchInterval == 0 {
opts.FetchInterval = 500 * time.Millisecond
}
@ -29,7 +29,7 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
opts.WarnInterval = 30 * time.Second
}
var resourceMutex sync.Mutex
resource, err := opts.Fetch(cmd.Context())
resource, err := opts.Fetch(ctx)
if err != nil {
return xerrors.Errorf("fetch: %w", err)
}
@ -40,7 +40,8 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
opts.WarnInterval = 0
}
spin := spinner.New(spinner.CharSets[78], 100*time.Millisecond, spinner.WithColor("fgHiGreen"))
spin.Writer = cmd.OutOrStdout()
spin.Writer = writer
spin.ForceOutput = true
spin.Suffix = " Waiting for connection from " + Styles.Field.Render(resource.Type+"."+resource.Name) + "..."
spin.Start()
defer spin.Stop()
@ -51,7 +52,7 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
defer timer.Stop()
go func() {
select {
case <-cmd.Context().Done():
case <-ctx.Done():
return
case <-timer.C:
}
@ -63,17 +64,17 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
}
// This saves the cursor position, then defers clearing from the cursor
// position to the end of the screen.
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\033[s\r\033[2K%s\n\n", Styles.Paragraph.Render(Styles.Prompt.String()+message))
defer fmt.Fprintf(cmd.OutOrStdout(), "\033[u\033[J")
_, _ = fmt.Fprintf(writer, "\033[s\r\033[2K%s\n\n", Styles.Paragraph.Render(Styles.Prompt.String()+message))
defer fmt.Fprintf(writer, "\033[u\033[J")
}()
for {
select {
case <-cmd.Context().Done():
return cmd.Context().Err()
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
resourceMutex.Lock()
resource, err = opts.Fetch(cmd.Context())
resource, err = opts.Fetch(ctx)
if err != nil {
return xerrors.Errorf("fetch: %w", err)
}

View File

@ -20,7 +20,7 @@ func TestAgent(t *testing.T) {
ptty := ptytest.New(t)
cmd := &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error {
err := cliui.Agent(cmd, cliui.AgentOptions{
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{
WorkspaceName: "example",
Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) {
resource := codersdk.WorkspaceResource{

38
cli/cliui/log.go Normal file
View File

@ -0,0 +1,38 @@
package cliui
import (
"fmt"
"io"
"strings"
"github.com/charmbracelet/lipgloss"
)
// cliMessage provides a human-readable message for CLI errors and messages.
type cliMessage struct {
Level string
Style lipgloss.Style
Header string
Lines []string
}
// String formats the CLI message for consumption by a human.
func (m cliMessage) String() string {
var str strings.Builder
_, _ = fmt.Fprintf(&str, "%s\r\n",
Styles.Bold.Render(m.Header))
for _, line := range m.Lines {
_, _ = fmt.Fprintf(&str, " %s %s\r\n", m.Style.Render("|"), line)
}
return str.String()
}
// Warn writes a log to the writer provided.
func Warn(wtr io.Writer, header string, lines ...string) {
_, _ = fmt.Fprint(wtr, cliMessage{
Level: "warning",
Style: Styles.Warn,
Header: header,
Lines: lines,
}.String())
}

View File

@ -2,6 +2,7 @@ package cliui
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
@ -62,7 +63,11 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) {
var rawMessage json.RawMessage
err := json.NewDecoder(pipeReader).Decode(&rawMessage)
if err == nil {
line = string(rawMessage)
var buf bytes.Buffer
err = json.Compact(&buf, rawMessage)
if err == nil {
line = buf.String()
}
}
}
}

View File

@ -93,9 +93,7 @@ func TestPrompt(t *testing.T) {
ptty.WriteLine(`{
"test": "wow"
}`)
require.Equal(t, `{
"test": "wow"
}`, <-doneChan)
require.Equal(t, `{"test":"wow"}`, <-doneChan)
})
}

View File

@ -3,27 +3,27 @@ package cliui
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"sync"
"time"
"github.com/google/uuid"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
)
func WorkspaceBuild(cmd *cobra.Command, client *codersdk.Client, build uuid.UUID, before time.Time) error {
return ProvisionerJob(cmd, ProvisionerJobOptions{
func WorkspaceBuild(ctx context.Context, writer io.Writer, client *codersdk.Client, build uuid.UUID, before time.Time) error {
return ProvisionerJob(ctx, writer, ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build)
build, err := client.WorkspaceBuild(ctx, build)
return build.Job, err
},
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) {
return client.WorkspaceBuildLogsAfter(cmd.Context(), build, before)
return client.WorkspaceBuildLogsAfter(ctx, build, before)
},
})
}
@ -39,25 +39,25 @@ type ProvisionerJobOptions struct {
}
// ProvisionerJob renders a provisioner job with interactive cancellation.
func ProvisionerJob(cmd *cobra.Command, opts ProvisionerJobOptions) error {
func ProvisionerJob(ctx context.Context, writer io.Writer, opts ProvisionerJobOptions) error {
if opts.FetchInterval == 0 {
opts.FetchInterval = time.Second
}
ctx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
var (
currentStage = "Queued"
currentStageStartedAt = time.Now().UTC()
didLogBetweenStage = false
ctx, cancelFunc = context.WithCancel(cmd.Context())
errChan = make(chan error, 1)
job codersdk.ProvisionerJob
jobMutex sync.Mutex
)
defer cancelFunc()
printStage := func() {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), Styles.Prompt.Render("⧗")+"%s\n", Styles.Field.Render(currentStage))
_, _ = fmt.Fprintf(writer, Styles.Prompt.Render("⧗")+"%s\n", Styles.Field.Render(currentStage))
}
updateStage := func(stage string, startedAt time.Time) {
@ -70,7 +70,7 @@ func ProvisionerJob(cmd *cobra.Command, opts ProvisionerJobOptions) error {
if job.CompletedAt != nil && job.Status != codersdk.ProvisionerJobSucceeded {
mark = Styles.Crossmark
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), prefix+mark.String()+Styles.Placeholder.Render(" %s [%dms]")+"\n", currentStage, startedAt.Sub(currentStageStartedAt).Milliseconds())
_, _ = fmt.Fprintf(writer, prefix+mark.String()+Styles.Placeholder.Render(" %s [%dms]")+"\n", currentStage, startedAt.Sub(currentStageStartedAt).Milliseconds())
}
if stage == "" {
return
@ -116,7 +116,7 @@ func ProvisionerJob(cmd *cobra.Command, opts ProvisionerJobOptions) error {
return
}
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\033[2K\r\n"+Styles.FocusedPrompt.String()+Styles.Bold.Render("Gracefully canceling...")+"\n\n")
_, _ = fmt.Fprintf(writer, "\033[2K\r\n"+Styles.FocusedPrompt.String()+Styles.Bold.Render("Gracefully canceling...")+"\n\n")
err := opts.Cancel()
if err != nil {
errChan <- xerrors.Errorf("cancel: %w", err)
@ -183,7 +183,7 @@ func ProvisionerJob(cmd *cobra.Command, opts ProvisionerJobOptions) error {
jobMutex.Unlock()
continue
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s %s\n", Styles.Placeholder.Render(" "), output)
_, _ = fmt.Fprintf(writer, "%s %s\n", Styles.Placeholder.Render(" "), output)
didLogBetweenStage = true
jobMutex.Unlock()
}

View File

@ -126,7 +126,7 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
logs := make(chan codersdk.ProvisionerJobLog, 1)
cmd := &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error {
return cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{
return cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
FetchInterval: time.Millisecond,
Fetch: func() (codersdk.ProvisionerJob, error) {
jobLock.Lock()

View File

@ -1,26 +1,162 @@
package cli
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"github.com/cli/safeexec"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
)
// const sshStartToken = "# ------------START-CODER-----------"
// const sshStartMessage = `# This was generated by "coder config-ssh".
// #
// # To remove this blob, run:
// #
// # coder config-ssh --remove
// #
// # You should not hand-edit this section, unless you are deleting it.`
// const sshEndToken = "# ------------END-CODER------------"
const sshStartToken = "# ------------START-CODER-----------"
const sshStartMessage = `# This was generated by "coder config-ssh".
#
# To remove this blob, run:
#
# coder config-ssh --remove
#
# You should not hand-edit this section, unless you are deleting it.`
const sshEndToken = "# ------------END-CODER------------"
func configSSH() *cobra.Command {
var (
sshConfigFile string
)
cmd := &cobra.Command{
Use: "config-ssh",
RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient(cmd)
if err != nil {
return err
}
if strings.HasPrefix(sshConfigFile, "~/") {
dirname, _ := os.UserHomeDir()
sshConfigFile = filepath.Join(dirname, sshConfigFile[2:])
}
// Doesn't matter if this fails, because we write the file anyways.
sshConfigContentRaw, _ := os.ReadFile(sshConfigFile)
sshConfigContent := string(sshConfigContentRaw)
startIndex := strings.Index(sshConfigContent, sshStartToken)
endIndex := strings.Index(sshConfigContent, sshEndToken)
if startIndex != -1 && endIndex != -1 {
sshConfigContent = sshConfigContent[:startIndex-1] + sshConfigContent[endIndex+len(sshEndToken):]
}
workspaces, err := client.WorkspacesByUser(cmd.Context(), "")
if err != nil {
return err
}
if len(workspaces) == 0 {
return xerrors.New("You don't have any workspaces!")
}
binPath, err := currentBinPath(cmd)
if err != nil {
return err
}
sshConfigContent += "\n" + sshStartToken + "\n" + sshStartMessage + "\n\n"
sshConfigContentMutex := sync.Mutex{}
var errGroup errgroup.Group
for _, workspace := range workspaces {
workspace := workspace
errGroup.Go(func() error {
resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID)
if err != nil {
return err
}
resourcesWithAgents := make([]codersdk.WorkspaceResource, 0)
for _, resource := range resources {
if resource.Agent == nil {
continue
}
resourcesWithAgents = append(resourcesWithAgents, resource)
}
sshConfigContentMutex.Lock()
defer sshConfigContentMutex.Unlock()
if len(resourcesWithAgents) == 1 {
sshConfigContent += strings.Join([]string{
"Host coder." + workspace.Name,
"\tHostName coder." + workspace.Name,
fmt.Sprintf("\tProxyCommand %q ssh --stdio %s", binPath, workspace.Name),
"\tConnectTimeout=0",
"\tStrictHostKeyChecking=no",
}, "\n") + "\n"
}
return nil
})
}
err = errGroup.Wait()
if err != nil {
return err
}
sshConfigContent += "\n" + sshEndToken
err = os.MkdirAll(filepath.Dir(sshConfigFile), os.ModePerm)
if err != nil {
return err
}
err = os.WriteFile(sshConfigFile, []byte(sshConfigContent), os.ModePerm)
if err != nil {
return err
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "An auto-generated ssh config was written to %q\n", sshConfigFile)
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "You should now be able to ssh into your workspace")
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "For example, try running\n\n\t$ ssh coder.%s\n\n", workspaces[0].Name)
return nil
},
}
cliflag.StringVarP(cmd.Flags(), &sshConfigFile, "ssh-config-file", "", "CODER_SSH_CONFIG_FILE", "~/.ssh/config", "Specifies the path to an SSH config.")
return cmd
}
// currentBinPath returns the path to the coder binary suitable for use in ssh
// ProxyCommand.
func currentBinPath(cmd *cobra.Command) (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", xerrors.Errorf("get executable path: %w", err)
}
binName := filepath.Base(exePath)
// We use safeexec instead of os/exec because os/exec returns paths in
// the current working directory, which we will run into very often when
// looking for our own path.
pathPath, err := safeexec.LookPath(binName)
// On Windows, the coder-cli executable must be in $PATH for both Msys2/Git
// Bash and OpenSSH for Windows (used by Powershell and VS Code) to function
// correctly. Check if the current executable is in $PATH, and warn the user
// if it isn't.
if err != nil && runtime.GOOS == "windows" {
cliui.Warn(cmd.OutOrStdout(),
"The current executable is not in $PATH.",
"This may lead to problems connecting to your workspace via SSH.",
fmt.Sprintf("Please move %q to a location in your $PATH (such as System32) and run `%s config-ssh` again.", binName, binName),
)
// Return the exePath so SSH at least works outside of Msys2.
return exePath, nil
}
// Warn the user if the current executable is not the same as the one in
// $PATH.
if filepath.Clean(pathPath) != filepath.Clean(exePath) {
cliui.Warn(cmd.OutOrStdout(),
"The current executable path does not match the executable path found in $PATH.",
"This may cause issues connecting to your workspace via SSH.",
fmt.Sprintf("\tCurrent executable path: %q", exePath),
fmt.Sprintf("\tExecutable path in $PATH: %q", pathPath),
)
}
return binName, nil
}

41
cli/configssh_test.go Normal file
View File

@ -0,0 +1,41 @@
package cli_test
import (
"os"
"testing"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/pty/ptytest"
"github.com/stretchr/testify/require"
)
func TestConfigSSH(t *testing.T) {
t.Parallel()
t.Run("Create", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdtest.NewProvisionerDaemon(t, client)
version := coderdtest.CreateProjectVersion(t, client, user.OrganizationID, nil)
coderdtest.AwaitProjectVersionJob(t, client, version.ID)
project := coderdtest.CreateProject(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, "", project.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
tempFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
_ = tempFile.Close()
cmd, root := clitest.New(t, "config-ssh", "--ssh-config-file", tempFile.Name())
clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := cmd.Execute()
require.NoError(t, err)
}()
<-doneChan
})
}

View File

@ -125,7 +125,7 @@ func createValidProjectVersion(cmd *cobra.Command, client *codersdk.Client, orga
return nil, nil, err
}
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) {
version, err := client.ProjectVersion(cmd.Context(), version.ID)
return version.Job, err

View File

@ -2,18 +2,28 @@ package cli
import (
"context"
"io"
"net"
"os"
"time"
"github.com/mattn/go-isatty"
"github.com/pion/webrtc/v3"
"github.com/spf13/cobra"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"golang.org/x/crypto/ssh/terminal"
)
func ssh() *cobra.Command {
var (
stdio bool
)
cmd := &cobra.Command{
Use: "ssh <workspace> [resource]",
RunE: func(cmd *cobra.Command, args []string) error {
@ -25,8 +35,11 @@ func ssh() *cobra.Command {
if err != nil {
return err
}
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to ssh")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd, client, workspace.LatestBuild.ID, workspace.CreatedAt)
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return err
}
@ -67,7 +80,9 @@ func ssh() *cobra.Command {
}
return xerrors.Errorf("no sshable agent with address %q: %+v", resourceAddress, resourceKeys)
}
err = cliui.Agent(cmd, cliui.AgentOptions{
// OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) {
return client.WorkspaceResource(ctx, resource.ID)
@ -84,6 +99,17 @@ func ssh() *cobra.Command {
return err
}
defer conn.Close()
if stdio {
rawSSH, err := conn.SSH()
if err != nil {
return err
}
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}
sshClient, err := conn.SSHClient()
if err != nil {
return err
@ -94,9 +120,17 @@ func ssh() *cobra.Command {
return err
}
err = sshSession.RequestPty("xterm-256color", 128, 128, gossh.TerminalModes{
gossh.OCRNL: 1,
})
if isatty.IsTerminal(os.Stdout.Fd()) {
state, err := terminal.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return err
}
defer func() {
_ = terminal.Restore(int(os.Stdin.Fd()), state)
}()
}
err = sshSession.RequestPty("xterm-256color", 128, 128, gossh.TerminalModes{})
if err != nil {
return err
}
@ -115,6 +149,36 @@ func ssh() *cobra.Command {
return nil
},
}
cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.")
return cmd
}
type stdioConn struct {
io.Reader
io.Writer
}
func (*stdioConn) Close() (err error) {
return nil
}
func (*stdioConn) LocalAddr() net.Addr {
return nil
}
func (*stdioConn) RemoteAddr() net.Addr {
return nil
}
func (*stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (*stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (*stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}

View File

@ -1,10 +1,15 @@
package cli_test
import (
"io"
"net"
"runtime"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
@ -78,4 +83,113 @@ func TestSSH(t *testing.T) {
pty.WriteLine("exit")
<-doneChan
})
t.Run("Stdio", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdtest.NewProvisionerDaemon(t, client)
agentToken := uuid.NewString()
version := coderdtest.CreateProjectVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "dev",
Type: "google_compute_instance",
Agent: &proto.Agent{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: agentToken,
},
},
}},
},
},
}},
})
coderdtest.AwaitProjectVersionJob(t, client, version.ID)
project := coderdtest.CreateProject(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, "", project.ID)
go func() {
// Run this async so the SSH command has to wait for
// the build and agent to connect!
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = agentToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
}()
clientOutput, clientInput := io.Pipe()
serverOutput, serverInput := io.Pipe()
cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
cmd.SetIn(clientOutput)
cmd.SetOut(serverInput)
cmd.SetErr(io.Discard)
go func() {
defer close(doneChan)
err := cmd.Execute()
require.NoError(t, err)
}()
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
Reader: serverOutput,
Writer: clientInput,
}, "", &ssh.ClientConfig{
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
require.NoError(t, err)
sshClient := ssh.NewClient(conn, channels, requests)
session, err := sshClient.NewSession()
require.NoError(t, err)
command := "sh -c exit"
if runtime.GOOS == "windows" {
command = "cmd.exe /c exit"
}
err = session.Run(command)
require.NoError(t, err)
err = sshClient.Close()
require.NoError(t, err)
_ = clientOutput.Close()
<-doneChan
})
}
type stdioConn struct {
io.Reader
io.Writer
}
func (*stdioConn) Close() (err error) {
return nil
}
func (*stdioConn) LocalAddr() net.Addr {
return nil
}
func (*stdioConn) RemoteAddr() net.Addr {
return nil
}
func (*stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (*stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (*stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}

View File

@ -266,7 +266,7 @@ func start() *cobra.Command {
return xerrors.Errorf("delete workspace: %w", err)
}
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build.ID)
return build.Job, err
@ -313,7 +313,7 @@ func start() *cobra.Command {
cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder")
cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
// systemd uses the CACHE_DIRECTORY environment variable!
cliflag.StringVarP(root.Flags(), &cacheDir, "cache-dir", "", "CACHE_DIRECTORY", filepath.Join(os.TempDir(), ".coder-cache"), "Specifies a directory to cache binaries for provision operations.")
cliflag.StringVarP(root.Flags(), &cacheDir, "cache-dir", "", "CACHE_DIRECTORY", filepath.Join(os.TempDir(), "coder-cache"), "Specifies a directory to cache binaries for provision operations.")
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
@ -369,6 +369,11 @@ func createFirstUser(cmd *cobra.Command, client *codersdk.Client, cfg config.Roo
}
func newProvisionerDaemon(ctx context.Context, client *codersdk.Client, logger slog.Logger, cacheDir string) (*provisionerd.Server, error) {
err := os.MkdirAll(cacheDir, 0700)
if err != nil {
return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
}
terraformClient, terraformServer := provisionersdk.TransportPipe()
go func() {
err := terraform.Serve(ctx, &terraform.ServeOptions{

View File

@ -145,7 +145,7 @@ func workspaceCreate() *cobra.Command {
if err != nil {
return err
}
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), workspace.LatestBuild.ID)
return build.Job, err

View File

@ -32,7 +32,7 @@ func workspaceDelete() *cobra.Command {
if err != nil {
return err
}
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build.ID)
return build.Job, err

View File

@ -31,7 +31,7 @@ func workspaceStart() *cobra.Command {
if err != nil {
return err
}
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build.ID)
return build.Job, err

View File

@ -31,7 +31,7 @@ func workspaceStop() *cobra.Command {
if err != nil {
return err
}
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build.ID)
return build.Job, err

View File

@ -96,7 +96,7 @@ func main() {
job.Status = codersdk.ProvisionerJobSucceeded
}()
err := cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{
err := cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) {
return job, nil
},
@ -172,7 +172,7 @@ func main() {
time.Sleep(3 * time.Second)
resource.Agent.Status = codersdk.WorkspaceAgentConnected
}()
err := cliui.Agent(cmd, cliui.AgentOptions{
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{
WorkspaceName: "dev",
Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) {
return resource, nil

View File

@ -137,8 +137,6 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
return
}
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
@ -183,6 +181,23 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
}
return nil
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDWithoutAfter(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID.String() != latestBuild.ID.String() {
return xerrors.New("build is outdated")
}
return nil
}
defer func() {
disconnectedAt = sql.NullTime{
@ -197,6 +212,13 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = ensureLatestBuild()
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
@ -214,6 +236,12 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
}
}
}

View File

@ -22,7 +22,7 @@ Coder requires a Google Cloud Service Account to provision workspaces.
- Service Account User
3. Click on the created key, and navigate to the "Keys" tab.
4. Click "Add key", then "Create new key".
5. Generate a JSON private key, and paste the contents in \'\' quotes below.
5. Generate a JSON private key, and paste the contents below.
EOF
sensitive = true
}

View File

@ -22,7 +22,7 @@ Coder requires a Google Cloud Service Account to provision workspaces.
- Service Account User
3. Click on the created key, and navigate to the "Keys" tab.
4. Click "Add key", then "Create new key".
5. Generate a JSON private key, and paste the contents in \'\' quotes below.
5. Generate a JSON private key, and paste the contents below.
EOF
sensitive = true
}

5
go.mod
View File

@ -14,6 +14,9 @@ replace github.com/hashicorp/terraform-config-inspect => github.com/kylecarbs/te
// Required until https://github.com/chzyer/readline/pull/198 is merged.
replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8
// Required until https://github.com/briandowns/spinner/pull/136 is merged.
replace github.com/briandowns/spinner => github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e
// opencensus-go leaks a goroutine by default.
replace go.opencensus.io => github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b
@ -90,6 +93,8 @@ require (
storj.io/drpc v0.0.30
)
require github.com/cli/safeexec v1.0.0
require (
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/BurntSushi/toml v1.0.0 // indirect

6
go.sum
View File

@ -254,8 +254,6 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dR
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g=
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA=
github.com/briandowns/spinner v1.18.1 h1:yhQmQtM1zsqFsouh09Bk/jCjd50pC3EOGsh28gLVvwY=
github.com/briandowns/spinner v1.18.1/go.mod h1:mQak9GHqbspjC/5iUx3qMlIho8xBS/ppAL/hX5SmPJU=
github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk=
github.com/buger/jsonparser v0.0.0-20180808090653-f4dd9f5a6b44/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
@ -304,6 +302,8 @@ github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6D
github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I=
github.com/clbanning/mxj/v2 v2.5.5 h1:oT81vUeEiQQ/DcHbzSytRngP6Ky9O+L+0Bw0zSJag9E=
github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s=
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93 h1:QrGfkZDnMxcWHaYDdB7CmqS9i26OAnUj/xcus/abYkY=
github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93/go.mod h1:QiTe66jFdP7cUKMCCf/WrvDyYdtdmdZfVcdoLbzaKVY=
@ -1126,6 +1126,8 @@ github.com/kylecarbs/promptui v0.8.1-0.20201231190244-d8f2159af2b2 h1:MUREBTh4ky
github.com/kylecarbs/promptui v0.8.1-0.20201231190244-d8f2159af2b2/go.mod h1:n4zTdgP0vr0S3w7/O/g98U+e0gwLScEXGwov2nIKuGQ=
github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8 h1:Y7O3Z3YeNRtw14QrtHpevU4dSjCkov0J40MtQ7Nc0n8=
github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8/go.mod h1:n/KX1BZoN1m9EwoXkn/xAV4fd3k8c++gGBsgLONaPOY=
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e/go.mod h1:mQak9GHqbspjC/5iUx3qMlIho8xBS/ppAL/hX5SmPJU=
github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88 h1:tvG/qs5c4worwGyGnbbb4i/dYYLjpFwDMqcIT3awAf8=
github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88/go.mod h1:Z0Nnk4+3Cy89smEbrq+sl1bxc9198gIP4I7wcQF6Kqs=
github.com/kylecarbs/terraform-exec v0.15.1-0.20220202050609-a1ce7181b180 h1:yafC0pmxjs18fnO5RdKFLSItJIjYwGfSHTfcUvlZb3E=

View File

@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"github.com/awalterschulze/gographviz"
@ -87,7 +88,9 @@ func (t *terraform) Provision(stream proto.DRPCProvisioner_ProvisionStream) erro
})
}
}()
if t.cachePath != "" {
// Windows doesn't work with a plugin cache directory.
// The cause is unknown, but it should work.
if t.cachePath != "" && runtime.GOOS != "windows" {
err = terraform.SetEnv(map[string]string{
"TF_PLUGIN_CACHE_DIR": t.cachePath,
})

View File

@ -2,13 +2,14 @@ package terraform
import (
"context"
"os/exec"
"path/filepath"
"github.com/hashicorp/go-version"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/cli/safeexec"
"github.com/coder/coder/provisionersdk"
"github.com/hashicorp/hc-install/product"
@ -41,7 +42,7 @@ type ServeOptions struct {
// Serve starts a dRPC server on the provided transport speaking Terraform provisioner.
func Serve(ctx context.Context, options *ServeOptions) error {
if options.BinaryPath == "" {
binaryPath, err := exec.LookPath("terraform")
binaryPath, err := safeexec.LookPath("terraform")
if err != nil {
installer := &releases.ExactVersion{
InstallDir: options.CachePath,
@ -55,7 +56,16 @@ func Serve(ctx context.Context, options *ServeOptions) error {
}
options.BinaryPath = execPath
} else {
options.BinaryPath = binaryPath
// If the "coder" binary is in the same directory as
// the "terraform" binary, "terraform" is returned.
//
// We must resolve the absolute path for other processes
// to execute this properly!
absoluteBinary, err := filepath.Abs(binaryPath)
if err != nil {
return xerrors.Errorf("absolute: %w", err)
}
options.BinaryPath = absoluteBinary
}
}
return provisionersdk.Serve(ctx, &terraform{