feat: convert entire CLI to clibase (#6491)

I'm sorry.
This commit is contained in:
Ammar Bandukwala 2023-03-23 17:42:20 -05:00 committed by GitHub
parent b71b8daa21
commit 2bd6d2908e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
345 changed files with 9965 additions and 9082 deletions

View File

@ -302,7 +302,7 @@ jobs:
echo "cover=false" >> $GITHUB_OUTPUT echo "cover=false" >> $GITHUB_OUTPUT
fi fi
gotestsum --junitfile="gotests.xml" --packages="./..." -- -parallel=8 -timeout=5m -short -failfast $COVERAGE_FLAGS gotestsum --junitfile="gotests.xml" --packages="./..." -- -parallel=8 -timeout=7m -short -failfast $COVERAGE_FLAGS
- uses: actions/upload-artifact@v3 - uses: actions/upload-artifact@v3
if: success() || failure() if: success() || failure()

View File

@ -501,8 +501,6 @@ docs/admin/prometheus.md: scripts/metricsdocgen/main.go scripts/metricsdocgen/me
yarn run format:write:only ../docs/admin/prometheus.md yarn run format:write:only ../docs/admin/prometheus.md
docs/cli.md: scripts/clidocgen/main.go $(GO_SRC_FILES) docs/manifest.json docs/cli.md: scripts/clidocgen/main.go $(GO_SRC_FILES) docs/manifest.json
# TODO(@ammario): re-enable server.md once we finish clibase migration.
ls ./docs/cli/*.md | grep -vP "\/coder_server" | xargs rm
BASE_PATH="." go run ./scripts/clidocgen BASE_PATH="." go run ./scripts/clidocgen
cd site cd site
yarn run format:write:only ../docs/cli.md ../docs/cli/*.md ../docs/manifest.json yarn run format:write:only ../docs/cli.md ../docs/cli/*.md ../docs/manifest.json
@ -519,7 +517,7 @@ coderd/apidoc/swagger.json: $(shell find ./scripts/apidocgen $(FIND_EXCLUSIONS)
update-golden-files: cli/testdata/.gen-golden helm/tests/testdata/.gen-golden update-golden-files: cli/testdata/.gen-golden helm/tests/testdata/.gen-golden
.PHONY: update-golden-files .PHONY: update-golden-files
cli/testdata/.gen-golden: $(wildcard cli/testdata/*.golden) $(GO_SRC_FILES) cli/testdata/.gen-golden: $(wildcard cli/testdata/*.golden) $(wildcard cli/*.tpl) $(GO_SRC_FILES)
go test ./cli -run=TestCommandHelp -update go test ./cli -run=TestCommandHelp -update
touch "$@" touch "$@"

View File

@ -16,7 +16,6 @@ import (
"time" "time"
"cloud.google.com/go/compute/metadata" "cloud.google.com/go/compute/metadata"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
@ -25,11 +24,11 @@ import (
"github.com/coder/coder/agent" "github.com/coder/coder/agent"
"github.com/coder/coder/agent/reaper" "github.com/coder/coder/agent/reaper"
"github.com/coder/coder/buildinfo" "github.com/coder/coder/buildinfo"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/codersdk/agentsdk" "github.com/coder/coder/codersdk/agentsdk"
) )
func workspaceAgent() *cobra.Command { func (r *RootCmd) workspaceAgent() *clibase.Cmd {
var ( var (
auth string auth string
logDir string logDir string
@ -37,22 +36,15 @@ func workspaceAgent() *cobra.Command {
noReap bool noReap bool
sshMaxTimeout time.Duration sshMaxTimeout time.Duration
) )
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "agent", Use: "agent",
Short: `Starts the Coder workspace agent.`,
// This command isn't useful to manually execute. // This command isn't useful to manually execute.
Hidden: true, Hidden: true,
RunE: func(cmd *cobra.Command, _ []string) error { Handler: func(inv *clibase.Invocation) error {
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(inv.Context())
defer cancel() defer cancel()
rawURL, err := cmd.Flags().GetString(varAgentURL)
if err != nil {
return xerrors.Errorf("CODER_AGENT_URL must be set: %w", err)
}
coderURL, err := url.Parse(rawURL)
if err != nil {
return xerrors.Errorf("parse %q: %w", rawURL, err)
}
agentPorts := map[int]string{} agentPorts := map[int]string{}
isLinux := runtime.GOOS == "linux" isLinux := runtime.GOOS == "linux"
@ -65,7 +57,7 @@ func workspaceAgent() *cobra.Command {
MaxSize: 5, // MB MaxSize: 5, // MB
} }
defer logWriter.Close() defer logWriter.Close()
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug) logger := slog.Make(sloghuman.Sink(inv.Stderr), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug)
logger.Info(ctx, "spawning reaper process") logger.Info(ctx, "spawning reaper process")
// Do not start a reaper on the child process. It's important // Do not start a reaper on the child process. It's important
@ -107,15 +99,15 @@ func workspaceAgent() *cobra.Command {
logWriter := &closeWriter{w: ljLogger} logWriter := &closeWriter{w: ljLogger}
defer logWriter.Close() defer logWriter.Close()
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug) logger := slog.Make(sloghuman.Sink(inv.Stderr), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug)
version := buildinfo.Version() version := buildinfo.Version()
logger.Info(ctx, "starting agent", logger.Info(ctx, "starting agent",
slog.F("url", coderURL), slog.F("url", r.agentURL),
slog.F("auth", auth), slog.F("auth", auth),
slog.F("version", version), slog.F("version", version),
) )
client := agentsdk.New(coderURL) client := agentsdk.New(r.agentURL)
client.SDK.Logger = logger client.SDK.Logger = logger
// Set a reasonable timeout so requests can't hang forever! // Set a reasonable timeout so requests can't hang forever!
// The timeout needs to be reasonably long, because requests // The timeout needs to be reasonably long, because requests
@ -139,7 +131,7 @@ func workspaceAgent() *cobra.Command {
var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error) var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error)
switch auth { switch auth {
case "token": case "token":
token, err := cmd.Flags().GetString(varAgentToken) token, err := inv.ParsedFlags().GetString(varAgentToken)
if err != nil { if err != nil {
return xerrors.Errorf("CODER_AGENT_TOKEN must be set for token auth: %w", err) return xerrors.Errorf("CODER_AGENT_TOKEN must be set for token auth: %w", err)
} }
@ -220,11 +212,44 @@ func workspaceAgent() *cobra.Command {
}, },
} }
cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AGENT_AUTH", "token", "Specify the authentication type to use for the agent") cmd.Options = clibase.OptionSet{
cliflag.StringVarP(cmd.Flags(), &logDir, "log-dir", "", "CODER_AGENT_LOG_DIR", os.TempDir(), "Specify the location for the agent log files") {
cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.") Flag: "auth",
cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.") Default: "token",
cliflag.DurationVarP(cmd.Flags(), &sshMaxTimeout, "ssh-max-timeout", "", "CODER_AGENT_SSH_MAX_TIMEOUT", time.Duration(0), "Specify the max timeout for a SSH connection") Description: "Specify the authentication type to use for the agent.",
Env: "CODER_AGENT_AUTH",
Value: clibase.StringOf(&auth),
},
{
Flag: "log-dir",
Default: os.TempDir(),
Description: "Specify the location for the agent log files.",
Env: "CODER_AGENT_LOG_DIR",
Value: clibase.StringOf(&logDir),
},
{
Flag: "pprof-address",
Default: "127.0.0.1:6060",
Env: "CODER_AGENT_PPROF_ADDRESS",
Value: clibase.StringOf(&pprofAddress),
Description: "The address to serve pprof.",
},
{
Flag: "no-reap",
Env: "",
Description: "Do not start a process reaper.",
Value: clibase.BoolOf(&noReap),
},
{
Flag: "ssh-max-timeout",
Default: "0",
Env: "CODER_AGENT_SSH_MAX_TIMEOUT",
Description: "Specify the max timeout for a SSH connection.",
Value: clibase.DurationOf(&sshMaxTimeout),
},
}
return cmd return cmd
} }

View File

@ -16,7 +16,7 @@ import (
"github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil" "github.com/coder/coder/pty/ptytest"
) )
func TestWorkspaceAgent(t *testing.T) { func TestWorkspaceAgent(t *testing.T) {
@ -40,24 +40,20 @@ func TestWorkspaceAgent(t *testing.T) {
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
logDir := t.TempDir() logDir := t.TempDir()
cmd, _ := clitest.New(t, inv, _ := clitest.New(t,
"agent", "agent",
"--auth", "token", "--auth", "token",
"--agent-token", authToken, "--agent-token", authToken,
"--agent-url", client.URL.String(), "--agent-url", client.URL.String(),
"--log-dir", logDir, "--log-dir", logDir,
) )
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
errC := make(chan error, 1)
go func() {
errC <- cmd.ExecuteContext(ctx)
}()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
cancel() pty := ptytest.New(t).Attach(inv)
err := <-errC
require.NoError(t, err) clitest.Start(t, inv)
pty.ExpectMatch("starting agent")
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
info, err := os.Stat(filepath.Join(logDir, "coder-agent.log")) info, err := os.Stat(filepath.Join(logDir, "coder-agent.log"))
require.NoError(t, err) require.NoError(t, err)
@ -96,16 +92,14 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String()) inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "azure-client", metadataClient),
)
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
errC := make(chan error) clitest.Start(t, inv)
go func() {
// A linting error occurs for weakly typing the context value here.
//nolint // The above seems reasonable for a one-off test.
ctx := context.WithValue(ctx, "azure-client", metadataClient)
errC <- cmd.ExecuteContext(ctx)
}()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
workspace, err := client.Workspace(ctx, workspace.ID) workspace, err := client.Workspace(ctx, workspace.ID)
require.NoError(t, err) require.NoError(t, err)
@ -117,9 +111,6 @@ func TestWorkspaceAgent(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer dialer.Close() defer dialer.Close()
require.True(t, dialer.AwaitReachable(context.Background())) require.True(t, dialer.AwaitReachable(context.Background()))
cancelFunc()
err = <-errC
require.NoError(t, err)
}) })
t.Run("AWS", func(t *testing.T) { t.Run("AWS", func(t *testing.T) {
@ -154,36 +145,29 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String()) inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
ctx, cancelFunc := context.WithCancel(context.Background()) inv = inv.WithContext(
defer cancelFunc() //nolint:revive,staticcheck
errC := make(chan error) context.WithValue(inv.Context(), "aws-client", metadataClient),
go func() { )
// A linting error occurs for weakly typing the context value here. clitest.Start(t, inv)
//nolint // The above seems reasonable for a one-off test.
ctx := context.WithValue(ctx, "aws-client", metadataClient)
errC <- cmd.ExecuteContext(ctx)
}()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
workspace, err := client.Workspace(ctx, workspace.ID) workspace, err := client.Workspace(inv.Context(), workspace.ID)
require.NoError(t, err) require.NoError(t, err)
resources := workspace.LatestBuild.Resources resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version) assert.NotEmpty(t, resources[0].Agents[0].Version)
} }
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) dialer, err := client.DialWorkspaceAgent(inv.Context(), resources[0].Agents[0].ID, nil)
require.NoError(t, err) require.NoError(t, err)
defer dialer.Close() defer dialer.Close()
require.True(t, dialer.AwaitReachable(context.Background())) require.True(t, dialer.AwaitReachable(context.Background()))
cancelFunc()
err = <-errC
require.NoError(t, err)
}) })
t.Run("GoogleCloud", func(t *testing.T) { t.Run("GoogleCloud", func(t *testing.T) {
t.Parallel() t.Parallel()
instanceID := "instanceidentifier" instanceID := "instanceidentifier"
validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false) validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
client := coderdtest.New(t, &coderdtest.Options{ client := coderdtest.New(t, &coderdtest.Options{
GoogleTokenValidator: validator, GoogleTokenValidator: validator,
IncludeProvisionerDaemon: true, IncludeProvisionerDaemon: true,
@ -212,16 +196,18 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String()) inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
ctx, cancelFunc := context.WithCancel(context.Background()) ptytest.New(t).Attach(inv)
defer cancelFunc() clitest.SetupConfig(t, client, cfg)
errC := make(chan error) clitest.Start(t,
go func() { inv.WithContext(
// A linting error occurs for weakly typing the context value here. //nolint:revive,staticcheck
//nolint // The above seems reasonable for a one-off test. context.WithValue(context.Background(), "gcp-client", metadataClient),
ctx := context.WithValue(ctx, "gcp-client", metadata) ),
errC <- cmd.ExecuteContext(ctx) )
}()
ctx := inv.Context()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
workspace, err := client.Workspace(ctx, workspace.ID) workspace, err := client.Workspace(ctx, workspace.ID)
require.NoError(t, err) require.NoError(t, err)
@ -248,9 +234,5 @@ func TestWorkspaceAgent(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = uuid.Parse(strings.TrimSpace(string(token))) _, err = uuid.Parse(strings.TrimSpace(string(token)))
require.NoError(t, err) require.NoError(t, err)
cancelFunc()
err = <-errC
require.NoError(t, err)
}) })
} }

View File

@ -1,10 +1,6 @@
// Package clibase offers an all-in-one solution for a highly configurable CLI // Package clibase offers an all-in-one solution for a highly configurable CLI
// application. Within Coder, we use it for our `server` subcommand, which // application. Within Coder, we use it for all of our subcommands, which
// demands more functionality than cobra/viper can offer. // demands more functionality than cobra/viber offers.
//
// We will extend its usage to the rest of our application, completely replacing
// cobra/viper. It's also a candidate to be broken out into its own open-source
// library, so we avoid deep coupling with Coder concepts.
// //
// The Command interface is loosely based on the chi middleware pattern and // The Command interface is loosely based on the chi middleware pattern and
// http.Handler/HandlerFunc. // http.Handler/HandlerFunc.

View File

@ -3,11 +3,15 @@ package clibase
import ( import (
"context" "context"
"errors" "errors"
"flag"
"fmt"
"io" "io"
"os" "os"
"strings" "strings"
"unicode"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"golang.org/x/exp/slices"
"golang.org/x/xerrors" "golang.org/x/xerrors"
) )
@ -47,14 +51,70 @@ type Cmd struct {
HelpHandler HandlerFunc HelpHandler HandlerFunc
} }
// AddSubcommands adds the given subcommands, setting their
// Parent field automatically.
func (c *Cmd) AddSubcommands(cmds ...*Cmd) {
for _, cmd := range cmds {
cmd.Parent = c
c.Children = append(c.Children, cmd)
}
}
// Walk calls fn for the command and all its children. // Walk calls fn for the command and all its children.
func (c *Cmd) Walk(fn func(*Cmd)) { func (c *Cmd) Walk(fn func(*Cmd)) {
fn(c) fn(c)
for _, child := range c.Children { for _, child := range c.Children {
child.Parent = c
child.Walk(fn) child.Walk(fn)
} }
} }
// PrepareAll performs initialization and linting on the command and all its children.
func (c *Cmd) PrepareAll() error {
if c.Use == "" {
return xerrors.New("command must have a Use field so that it has a name")
}
var merr error
slices.SortFunc(c.Options, func(a, b Option) bool {
return a.Flag < b.Flag
})
for _, opt := range c.Options {
if opt.Name == "" {
switch {
case opt.Flag != "":
opt.Name = opt.Flag
case opt.Env != "":
opt.Name = opt.Env
case opt.YAML != "":
opt.Name = opt.YAML
default:
merr = errors.Join(merr, xerrors.Errorf("option must have a Name, Flag, Env or YAML field"))
}
}
if opt.Description != "" {
// Enforce that description uses sentence form.
if unicode.IsLower(rune(opt.Description[0])) {
merr = errors.Join(merr, xerrors.Errorf("option %q description should start with a capital letter", opt.Name))
}
if !strings.HasSuffix(opt.Description, ".") {
merr = errors.Join(merr, xerrors.Errorf("option %q description should end with a period", opt.Name))
}
}
}
slices.SortFunc(c.Children, func(a, b *Cmd) bool {
return a.Name() < b.Name()
})
for _, child := range c.Children {
child.Parent = c
err := child.PrepareAll()
if err != nil {
merr = errors.Join(merr, xerrors.Errorf("command %v: %w", child.Name(), err))
}
}
return merr
}
// Name returns the first word in the Use string. // Name returns the first word in the Use string.
func (c *Cmd) Name() string { func (c *Cmd) Name() string {
return strings.Split(c.Use, " ")[0] return strings.Split(c.Use, " ")[0]
@ -64,7 +124,6 @@ func (c *Cmd) Name() string {
// as seen on the command line. // as seen on the command line.
func (c *Cmd) FullName() string { func (c *Cmd) FullName() string {
var names []string var names []string
if c.Parent != nil { if c.Parent != nil {
names = append(names, c.Parent.FullName()) names = append(names, c.Parent.FullName())
} }
@ -77,7 +136,7 @@ func (c *Cmd) FullName() string {
func (c *Cmd) FullUsage() string { func (c *Cmd) FullUsage() string {
var uses []string var uses []string
if c.Parent != nil { if c.Parent != nil {
uses = append(uses, c.Parent.FullUsage()) uses = append(uses, c.Parent.FullName())
} }
uses = append(uses, c.Use) uses = append(uses, c.Use)
return strings.Join(uses, " ") return strings.Join(uses, " ")
@ -115,28 +174,17 @@ type Invocation struct {
// fields with OS defaults. // fields with OS defaults.
func (i *Invocation) WithOS() *Invocation { func (i *Invocation) WithOS() *Invocation {
return i.with(func(i *Invocation) { return i.with(func(i *Invocation) {
if i.Stdout == nil {
i.Stdout = os.Stdout i.Stdout = os.Stdout
}
if i.Stderr == nil {
i.Stderr = os.Stderr i.Stderr = os.Stderr
}
if i.Stdin == nil {
i.Stdin = os.Stdin i.Stdin = os.Stdin
}
if i.Args == nil {
i.Args = os.Args[1:] i.Args = os.Args[1:]
}
if i.Environ == nil {
i.Environ = ParseEnviron(os.Environ(), "") i.Environ = ParseEnviron(os.Environ(), "")
}
}) })
} }
func (i *Invocation) Context() context.Context { func (i *Invocation) Context() context.Context {
if i.ctx == nil { if i.ctx == nil {
// Consider returning context.Background() instead? return context.Background()
panic("context not set, has WithContext() or Run() been called?")
} }
return i.ctx return i.ctx
} }
@ -155,6 +203,18 @@ type runState struct {
flagParseErr error flagParseErr error
} }
func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
fs2 := pflag.NewFlagSet("", pflag.ContinueOnError)
fs2.Usage = func() {}
fs.VisitAll(func(f *pflag.Flag) {
if f.Name == without {
return
}
fs2.AddFlag(f)
})
return fs2
}
// run recursively executes the command and its children. // run recursively executes the command and its children.
// allArgs is wired through the stack so that global flags can be accepted // allArgs is wired through the stack so that global flags can be accepted
// anywhere in the command invocation. // anywhere in the command invocation.
@ -164,6 +224,23 @@ func (i *Invocation) run(state *runState) error {
return xerrors.Errorf("setting defaults: %w", err) return xerrors.Errorf("setting defaults: %w", err)
} }
// If we set the Default of an array but later see a flag for it, we
// don't want to append, we want to replace. So, we need to keep the state
// of defaulted array options.
defaultedArrays := make(map[string]int)
for _, opt := range i.Command.Options {
sv, ok := opt.Value.(pflag.SliceValue)
if !ok {
continue
}
if opt.Flag == "" {
continue
}
defaultedArrays[opt.Flag] = len(sv.GetSlice())
}
err = i.Command.Options.ParseEnv(i.Environ) err = i.Command.Options.ParseEnv(i.Environ)
if err != nil { if err != nil {
return xerrors.Errorf("parsing env: %w", err) return xerrors.Errorf("parsing env: %w", err)
@ -173,6 +250,7 @@ func (i *Invocation) run(state *runState) error {
children := make(map[string]*Cmd) children := make(map[string]*Cmd)
for _, child := range i.Command.Children { for _, child := range i.Command.Children {
child.Parent = i.Command
for _, name := range append(child.Aliases, child.Name()) { for _, name := range append(child.Aliases, child.Name()) {
if _, ok := children[name]; ok { if _, ok := children[name]; ok {
return xerrors.Errorf("duplicate command name: %s", name) return xerrors.Errorf("duplicate command name: %s", name)
@ -187,7 +265,15 @@ func (i *Invocation) run(state *runState) error {
i.parsedFlags.Usage = func() {} i.parsedFlags.Usage = func() {}
} }
i.parsedFlags.AddFlagSet(i.Command.Options.FlagSet()) // If we find a duplicate flag, we want the deeper command's flag to override
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
// have to create a copy of the flagset without a value.
i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
if i.parsedFlags.Lookup(f.Name) != nil {
i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name)
}
i.parsedFlags.AddFlag(f)
})
var parsedArgs []string var parsedArgs []string
@ -196,24 +282,38 @@ func (i *Invocation) run(state *runState) error {
// so we check the error after looking for a child command. // so we check the error after looking for a child command.
state.flagParseErr = i.parsedFlags.Parse(state.allArgs) state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
parsedArgs = i.parsedFlags.Args() parsedArgs = i.parsedFlags.Args()
i.parsedFlags.VisitAll(func(f *pflag.Flag) {
i, ok := defaultedArrays[f.Name]
if !ok {
return
}
if !f.Changed {
return
}
sv, ok := f.Value.(pflag.SliceValue)
if !ok {
panic("defaulted array option is not a slice value")
}
err := sv.Replace(sv.GetSlice()[i:])
if err != nil {
panic(err)
}
})
} }
// Run child command if found (next child only) // Run child command if found (next child only)
// We must do subcommand detection after flag parsing so we don't mistake flag // We must do subcommand detection after flag parsing so we don't mistake flag
// values for subcommand names. // values for subcommand names.
if len(parsedArgs) > 0 { if len(parsedArgs) > state.commandDepth {
nextArg := parsedArgs[0] nextArg := parsedArgs[state.commandDepth]
if child, ok := children[nextArg]; ok { if child, ok := children[nextArg]; ok {
child.Parent = i.Command child.Parent = i.Command
i.Command = child i.Command = child
state.commandDepth++ state.commandDepth++
err = i.run(state) return i.run(state)
if err != nil {
return xerrors.Errorf(
"subcommand %s: %w", child.Name(), err,
)
}
return nil
} }
} }
@ -266,11 +366,27 @@ func (i *Invocation) run(state *runState) error {
err = mw(i.Command.Handler)(i) err = mw(i.Command.Handler)(i)
if err != nil { if err != nil {
return xerrors.Errorf("running command %s: %w", i.Command.FullName(), err) return &RunCommandError{
Cmd: i.Command,
Err: err,
}
} }
return nil return nil
} }
type RunCommandError struct {
Cmd *Cmd
Err error
}
func (e *RunCommandError) Unwrap() error {
return e.Err
}
func (e *RunCommandError) Error() string {
return fmt.Sprintf("running command %q: %+v", e.Cmd.FullName(), e.Err)
}
// findArg returns the index of the first occurrence of arg in args, skipping // findArg returns the index of the first occurrence of arg in args, skipping
// over all flags. // over all flags.
func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) { func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
@ -314,10 +430,21 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
// If two command share a flag name, the first command wins. // If two command share a flag name, the first command wins.
// //
//nolint:revive //nolint:revive
func (i *Invocation) Run() error { func (i *Invocation) Run() (err error) {
return i.run(&runState{ defer func() {
// Pflag is panicky, so additional context is helpful in tests.
if flag.Lookup("test.v") == nil {
return
}
if r := recover(); r != nil {
err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r)
panic(err)
}
}()
err = i.run(&runState{
allArgs: i.Args, allArgs: i.Args,
}) })
return err
} }
// WithContext returns a copy of the Invocation with the given context. // WithContext returns a copy of the Invocation with the given context.
@ -378,6 +505,9 @@ func RequireRangeArgs(start, end int) MiddlewareFunc {
case start == end && got != start: case start == end && got != start:
switch start { switch start {
case 0: case 0:
if len(i.Command.Children) > 0 {
return xerrors.Errorf("unrecognized subcommand %q", i.Args[0])
}
return xerrors.Errorf("wanted no args but got %v %v", got, i.Args) return xerrors.Errorf("wanted no args but got %v %v", got, i.Args)
default: default:
return xerrors.Errorf( return xerrors.Errorf(

View File

@ -213,6 +213,66 @@ func TestCommand(t *testing.T) {
}) })
} }
func TestCommand_DeepNest(t *testing.T) {
t.Parallel()
cmd := &clibase.Cmd{
Use: "1",
Children: []*clibase.Cmd{
{
Use: "2",
Children: []*clibase.Cmd{
{
Use: "3",
Handler: func(i *clibase.Invocation) error {
i.Stdout.Write([]byte("3"))
return nil
},
},
},
},
},
}
inv := cmd.Invoke("2", "3")
stdio := fakeIO(inv)
err := inv.Run()
require.NoError(t, err)
require.Equal(t, "3", stdio.Stdout.String())
}
func TestCommand_FlagOverride(t *testing.T) {
t.Parallel()
var flag string
cmd := &clibase.Cmd{
Use: "1",
Options: clibase.OptionSet{
{
Flag: "f",
Value: clibase.DiscardValue,
},
},
Children: []*clibase.Cmd{
{
Use: "2",
Options: clibase.OptionSet{
{
Flag: "f",
Value: clibase.StringOf(&flag),
},
},
Handler: func(i *clibase.Invocation) error {
return nil
},
},
},
}
err := cmd.Invoke("2", "--f", "mhmm").Run()
require.NoError(t, err)
require.Equal(t, "mhmm", flag)
}
func TestCommand_MiddlewareOrder(t *testing.T) { func TestCommand_MiddlewareOrder(t *testing.T) {
t.Parallel() t.Parallel()
@ -252,7 +312,7 @@ func TestCommand_RawArgs(t *testing.T) {
cmd := func() *clibase.Cmd { cmd := func() *clibase.Cmd {
return &clibase.Cmd{ return &clibase.Cmd{
Use: "root", Use: "root",
Options: []clibase.Option{ Options: clibase.OptionSet{
{ {
Name: "password", Name: "password",
Flag: "password", Flag: "password",
@ -366,3 +426,80 @@ func TestCommand_ContextCancels(t *testing.T) {
require.Error(t, gotCtx.Err()) require.Error(t, gotCtx.Err())
} }
func TestCommand_Help(t *testing.T) {
t.Parallel()
cmd := func() *clibase.Cmd {
return &clibase.Cmd{
Use: "root",
HelpHandler: (func(i *clibase.Invocation) error {
i.Stdout.Write([]byte("abdracadabra"))
return nil
}),
Handler: (func(i *clibase.Invocation) error {
return xerrors.New("should not be called")
}),
}
}
t.Run("NoHandler", func(t *testing.T) {
t.Parallel()
c := cmd()
c.HelpHandler = nil
err := c.Invoke("--help").Run()
require.Error(t, err)
})
t.Run("Long", func(t *testing.T) {
t.Parallel()
inv := cmd().Invoke("--help")
stdio := fakeIO(inv)
err := inv.Run()
require.NoError(t, err)
require.Contains(t, stdio.Stdout.String(), "abdracadabra")
})
t.Run("Short", func(t *testing.T) {
t.Parallel()
inv := cmd().Invoke("-h")
stdio := fakeIO(inv)
err := inv.Run()
require.NoError(t, err)
require.Contains(t, stdio.Stdout.String(), "abdracadabra")
})
}
func TestCommand_SliceFlags(t *testing.T) {
t.Parallel()
cmd := func(want ...string) *clibase.Cmd {
var got []string
return &clibase.Cmd{
Use: "root",
Options: clibase.OptionSet{
{
Name: "arr",
Flag: "arr",
Default: "bad,bad,bad",
Value: clibase.StringArrayOf(&got),
},
},
Handler: (func(i *clibase.Invocation) error {
require.Equal(t, want, got)
return nil
}),
}
}
err := cmd("good", "good", "good").Invoke("--arr", "good", "--arr", "good", "--arr", "good").Run()
require.NoError(t, err)
err = cmd("bad", "bad", "bad").Invoke().Run()
require.NoError(t, err)
}

View File

@ -44,6 +44,11 @@ func (e Environ) Lookup(name string) (string, bool) {
return "", false return "", false
} }
func (e Environ) Get(name string) string {
v, _ := e.Lookup(name)
return v
}
func (e *Environ) Set(name, value string) { func (e *Environ) Set(name, value string) {
for i, v := range *e { for i, v := range *e {
if v.Name == name { if v.Name == name {

View File

@ -77,7 +77,7 @@ func (s *OptionSet) FlagSet() *pflag.FlagSet {
val := opt.Value val := opt.Value
if val == nil { if val == nil {
val = &DiscardValue{} val = DiscardValue
} }
fs.AddFlag(&pflag.Flag{ fs.AddFlag(&pflag.Flag{

View File

@ -35,10 +35,10 @@ func TestOptionSet_ParseFlags(t *testing.T) {
require.EqualValues(t, "f", workspaceName) require.EqualValues(t, "f", workspaceName)
}) })
t.Run("Strings", func(t *testing.T) { t.Run("StringArray", func(t *testing.T) {
t.Parallel() t.Parallel()
var names clibase.Strings var names clibase.StringArray
os := clibase.OptionSet{ os := clibase.OptionSet{
clibase.Option{ clibase.Option{
@ -49,7 +49,10 @@ func TestOptionSet_ParseFlags(t *testing.T) {
}, },
} }
err := os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"}) err := os.SetDefaults()
require.NoError(t, err)
err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"})
require.NoError(t, err) require.NoError(t, err)
require.EqualValues(t, []string{"foo", "bar"}, names) require.EqualValues(t, []string{"foo", "bar"}, names)
}) })

View File

@ -109,26 +109,26 @@ func (String) Type() string {
return "string" return "string"
} }
var _ pflag.SliceValue = &Strings{} var _ pflag.SliceValue = &StringArray{}
// Strings is a slice of strings that implements pflag.Value and pflag.SliceValue. // StringArray is a slice of strings that implements pflag.Value and pflag.SliceValue.
type Strings []string type StringArray []string
func StringsOf(ss *[]string) *Strings { func StringArrayOf(ss *[]string) *StringArray {
return (*Strings)(ss) return (*StringArray)(ss)
} }
func (s *Strings) Append(v string) error { func (s *StringArray) Append(v string) error {
*s = append(*s, v) *s = append(*s, v)
return nil return nil
} }
func (s *Strings) Replace(vals []string) error { func (s *StringArray) Replace(vals []string) error {
*s = vals *s = vals
return nil return nil
} }
func (s *Strings) GetSlice() []string { func (s *StringArray) GetSlice() []string {
return *s return *s
} }
@ -145,7 +145,7 @@ func writeAsCSV(vals []string) string {
return sb.String() return sb.String()
} }
func (s *Strings) Set(v string) error { func (s *StringArray) Set(v string) error {
ss, err := readAsCSV(v) ss, err := readAsCSV(v)
if err != nil { if err != nil {
return err return err
@ -154,16 +154,16 @@ func (s *Strings) Set(v string) error {
return nil return nil
} }
func (s Strings) String() string { func (s StringArray) String() string {
return writeAsCSV([]string(s)) return writeAsCSV([]string(s))
} }
func (s Strings) Value() []string { func (s StringArray) Value() []string {
return []string(s) return []string(s)
} }
func (Strings) Type() string { func (StringArray) Type() string {
return "strings" return "string-array"
} }
type Duration time.Duration type Duration time.Duration
@ -287,7 +287,7 @@ func (hp *HostPort) UnmarshalJSON(b []byte) error {
} }
func (*HostPort) Type() string { func (*HostPort) Type() string {
return "bind-address" return "host:port"
} }
var ( var (
@ -344,16 +344,50 @@ func (s *Struct[T]) UnmarshalJSON(b []byte) error {
// DiscardValue does nothing but implements the pflag.Value interface. // DiscardValue does nothing but implements the pflag.Value interface.
// It's useful in cases where you want to accept an option, but access the // It's useful in cases where you want to accept an option, but access the
// underlying value directly instead of through the Option methods. // underlying value directly instead of through the Option methods.
type DiscardValue struct{} var DiscardValue discardValue
func (DiscardValue) Set(string) error { type discardValue struct{}
func (discardValue) Set(string) error {
return nil return nil
} }
func (DiscardValue) String() string { func (discardValue) String() string {
return "" return ""
} }
func (DiscardValue) Type() string { func (discardValue) Type() string {
return "discard" return "discard"
} }
var _ pflag.Value = (*Enum)(nil)
type Enum struct {
Choices []string
Value *string
}
func EnumOf(v *string, choices ...string) *Enum {
return &Enum{
Choices: choices,
Value: v,
}
}
func (e *Enum) Set(v string) error {
for _, c := range e.Choices {
if v == c {
*e.Value = v
return nil
}
}
return xerrors.Errorf("invalid choice: %s, should be one of %v", v, e.Choices)
}
func (e *Enum) Type() string {
return fmt.Sprintf("enum[%v]", strings.Join(e.Choices, "|"))
}
func (e *Enum) String() string {
return *e.Value
}

View File

@ -38,7 +38,7 @@ func TestOption_ToYAML(t *testing.T) {
Name: "Workspace Name", Name: "Workspace Name",
Value: &workspaceName, Value: &workspaceName,
Default: "billie", Default: "billie",
Description: "The workspace's name", Description: "The workspace's name.",
Group: &clibase.Group{Name: "Names"}, Group: &clibase.Group{Name: "Names"},
YAML: "workspaceName", YAML: "workspaceName",
}, },

View File

@ -1,185 +0,0 @@
// Package cliflag extends flagset with environment variable defaults.
//
// Usage:
//
// cliflag.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
//
// Will produce the following usage docs:
//
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
package cliflag
import (
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/coder/coder/cli/cliui"
)
// IsSetBool returns the value of the boolean flag if it is set.
// It returns false if the flag isn't set or if any error occurs attempting
// to parse the value of the flag.
func IsSetBool(cmd *cobra.Command, name string) bool {
val, ok := IsSet(cmd, name)
if !ok {
return false
}
b, err := strconv.ParseBool(val)
return err == nil && b
}
// IsSet returns the string value of the flag and whether it was set.
func IsSet(cmd *cobra.Command, name string) (string, bool) {
flag := cmd.Flag(name)
if flag == nil {
return "", false
}
return flag.Value.String(), flag.Changed
}
// String sets a string flag on the given flag set.
func String(flagset *pflag.FlagSet, name, shorthand, env, def, usage string) {
v, ok := os.LookupEnv(env)
if !ok || v == "" {
v = def
}
flagset.StringP(name, shorthand, v, fmtUsage(usage, env))
}
// StringVarP sets a string flag on the given flag set.
func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
v, ok := os.LookupEnv(env)
if !ok || v == "" {
v = def
}
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
}
func StringArray(flagset *pflag.FlagSet, name, shorthand, env string, def []string, usage string) {
v, ok := os.LookupEnv(env)
if !ok || v == "" {
if v == "" {
def = []string{}
} else {
def = strings.Split(v, ",")
}
}
flagset.StringArrayP(name, shorthand, def, fmtUsage(usage, env))
}
func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
val, ok := os.LookupEnv(env)
if ok {
if val == "" {
def = []string{}
} else {
def = strings.Split(val, ",")
}
}
flagset.StringArrayVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
}
// Uint8VarP sets a uint8 flag on the given flag set.
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
val, ok := os.LookupEnv(env)
if !ok || val == "" {
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
vi64, err := strconv.ParseUint(val, 10, 8)
if err != nil {
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
flagset.Uint8VarP(ptr, name, shorthand, uint8(vi64), fmtUsage(usage, env))
}
// IntVarP sets a uint8 flag on the given flag set.
func IntVarP(flagset *pflag.FlagSet, ptr *int, name string, shorthand string, env string, def int, usage string) {
val, ok := os.LookupEnv(env)
if !ok || val == "" {
flagset.IntVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
vi64, err := strconv.ParseUint(val, 10, 8)
if err != nil {
flagset.IntVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
flagset.IntVarP(ptr, name, shorthand, int(vi64), fmtUsage(usage, env))
}
func Bool(flagset *pflag.FlagSet, name, shorthand, env string, def bool, usage string) {
val, ok := os.LookupEnv(env)
if !ok || val == "" {
flagset.BoolP(name, shorthand, def, fmtUsage(usage, env))
return
}
valb, err := strconv.ParseBool(val)
if err != nil {
flagset.BoolP(name, shorthand, def, fmtUsage(usage, env))
return
}
flagset.BoolP(name, shorthand, valb, fmtUsage(usage, env))
}
// BoolVarP sets a bool flag on the given flag set.
func BoolVarP(flagset *pflag.FlagSet, ptr *bool, name string, shorthand string, env string, def bool, usage string) {
val, ok := os.LookupEnv(env)
if !ok || val == "" {
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
valb, err := strconv.ParseBool(val)
if err != nil {
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
flagset.BoolVarP(ptr, name, shorthand, valb, fmtUsage(usage, env))
}
// DurationVarP sets a time.Duration flag on the given flag set.
func DurationVarP(flagset *pflag.FlagSet, ptr *time.Duration, name string, shorthand string, env string, def time.Duration, usage string) {
val, ok := os.LookupEnv(env)
if !ok || val == "" {
flagset.DurationVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
valb, err := time.ParseDuration(val)
if err != nil {
flagset.DurationVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
flagset.DurationVarP(ptr, name, shorthand, valb, fmtUsage(usage, env))
}
func fmtUsage(u string, env string) string {
if env != "" {
// Avoid double dotting.
dot := "."
if strings.HasSuffix(u, ".") {
dot = ""
}
u = fmt.Sprintf("%s%s\n"+cliui.Styles.Placeholder.Render("Consumes $%s"), u, dot, env)
}
return u
}

View File

@ -1,277 +0,0 @@
package cliflag_test
import (
"fmt"
"strconv"
"testing"
"time"
"github.com/spf13/pflag"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cryptorand"
)
// Testcliflag cannot run in parallel because it uses t.Setenv.
//
//nolint:paralleltest
func TestCliflag(t *testing.T) {
t.Run("StringDefault", func(t *testing.T) {
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.String(10)
cliflag.String(flagset, name, shorthand, env, def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env))
})
t.Run("StringEnvVar", func(t *testing.T) {
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.String(10)
cliflag.String(flagset, name, shorthand, env, def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, envValue, got)
})
t.Run("StringVarPDefault", func(t *testing.T) {
var ptr string
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.String(10)
cliflag.StringVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env))
})
t.Run("StringVarPEnvVar", func(t *testing.T) {
var ptr string
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.String(10)
cliflag.StringVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, envValue, got)
})
t.Run("EmptyEnvVar", func(t *testing.T) {
var ptr string
flagset, name, shorthand, _, usage := randomFlag()
def, _ := cryptorand.String(10)
cliflag.StringVarP(flagset, &ptr, name, shorthand, "", def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.NotContains(t, flagset.FlagUsages(), "Consumes")
})
t.Run("StringArrayDefault", func(t *testing.T) {
var ptr []string
flagset, name, shorthand, env, usage := randomFlag()
def := []string{"hello"}
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetStringArray(name)
require.NoError(t, err)
require.Equal(t, def, got)
})
t.Run("StringArrayEnvVar", func(t *testing.T) {
var ptr []string
flagset, name, shorthand, env, usage := randomFlag()
t.Setenv(env, "wow,test")
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage)
got, err := flagset.GetStringArray(name)
require.NoError(t, err)
require.Equal(t, []string{"wow", "test"}, got)
})
t.Run("StringArrayEnvVarEmpty", func(t *testing.T) {
var ptr []string
flagset, name, shorthand, env, usage := randomFlag()
t.Setenv(env, "")
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage)
got, err := flagset.GetStringArray(name)
require.NoError(t, err)
require.Equal(t, []string{}, got)
})
t.Run("UInt8Default", func(t *testing.T) {
var ptr uint8
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.Int63n(10)
cliflag.Uint8VarP(flagset, &ptr, name, shorthand, env, uint8(def), usage)
got, err := flagset.GetUint8(name)
require.NoError(t, err)
require.Equal(t, uint8(def), got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env))
})
t.Run("UInt8EnvVar", func(t *testing.T) {
var ptr uint8
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.Int63n(10)
t.Setenv(env, strconv.FormatUint(uint64(envValue), 10))
def, _ := cryptorand.Int()
cliflag.Uint8VarP(flagset, &ptr, name, shorthand, env, uint8(def), usage)
got, err := flagset.GetUint8(name)
require.NoError(t, err)
require.Equal(t, uint8(envValue), got)
})
t.Run("UInt8FailParse", func(t *testing.T) {
var ptr uint8
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.Int63n(10)
cliflag.Uint8VarP(flagset, &ptr, name, shorthand, env, uint8(def), usage)
got, err := flagset.GetUint8(name)
require.NoError(t, err)
require.Equal(t, uint8(def), got)
})
t.Run("IntDefault", func(t *testing.T) {
var ptr int
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.Int63n(10)
cliflag.IntVarP(flagset, &ptr, name, shorthand, env, int(def), usage)
got, err := flagset.GetInt(name)
require.NoError(t, err)
require.Equal(t, int(def), got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env))
})
t.Run("IntEnvVar", func(t *testing.T) {
var ptr int
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.Int63n(10)
t.Setenv(env, strconv.FormatUint(uint64(envValue), 10))
def, _ := cryptorand.Int()
cliflag.IntVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetInt(name)
require.NoError(t, err)
require.Equal(t, int(envValue), got)
})
t.Run("IntFailParse", func(t *testing.T) {
var ptr int
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.Int63n(10)
cliflag.IntVarP(flagset, &ptr, name, shorthand, env, int(def), usage)
got, err := flagset.GetInt(name)
require.NoError(t, err)
require.Equal(t, int(def), got)
})
t.Run("BoolDefault", func(t *testing.T) {
var ptr bool
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.Bool()
cliflag.BoolVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetBool(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env))
})
t.Run("BoolEnvVar", func(t *testing.T) {
var ptr bool
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.Bool()
t.Setenv(env, strconv.FormatBool(envValue))
def, _ := cryptorand.Bool()
cliflag.BoolVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetBool(name)
require.NoError(t, err)
require.Equal(t, envValue, got)
})
t.Run("BoolFailParse", func(t *testing.T) {
var ptr bool
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.Bool()
cliflag.BoolVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetBool(name)
require.NoError(t, err)
require.Equal(t, def, got)
})
t.Run("DurationDefault", func(t *testing.T) {
var ptr time.Duration
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.Duration()
cliflag.DurationVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetDuration(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env))
})
t.Run("DurationEnvVar", func(t *testing.T) {
var ptr time.Duration
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.Duration()
t.Setenv(env, envValue.String())
def, _ := cryptorand.Duration()
cliflag.DurationVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetDuration(name)
require.NoError(t, err)
require.Equal(t, envValue, got)
})
t.Run("DurationFailParse", func(t *testing.T) {
var ptr time.Duration
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.Duration()
cliflag.DurationVarP(flagset, &ptr, name, shorthand, env, def, usage)
got, err := flagset.GetDuration(name)
require.NoError(t, err)
require.Equal(t, def, got)
})
}
func randomFlag() (*pflag.FlagSet, string, string, string, string) {
fsname, _ := cryptorand.String(10)
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
name, _ := cryptorand.String(10)
shorthand, _ := cryptorand.String(1)
env, _ := cryptorand.String(10)
usage, _ := cryptorand.String(10)
return flagset, name, shorthand, env, usage
}

View File

@ -10,14 +10,16 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/cli" "github.com/coder/coder/cli"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/config" "github.com/coder/coder/cli/config"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisioner/echo"
@ -26,8 +28,13 @@ import (
// New creates a CLI instance with a configuration pointed to a // New creates a CLI instance with a configuration pointed to a
// temporary testing directory. // temporary testing directory.
func New(t *testing.T, args ...string) (*cobra.Command, config.Root) { func New(t *testing.T, args ...string) (*clibase.Invocation, config.Root) {
return NewWithSubcommands(t, cli.AGPL(), args...) var root cli.RootCmd
cmd, err := root.Command(root.AGPL())
require.NoError(t, err)
return NewWithCommand(t, cmd, args...)
} }
type logWriter struct { type logWriter struct {
@ -46,19 +53,21 @@ func (l *logWriter) Write(p []byte) (n int, err error) {
return len(p), nil return len(p), nil
} }
func NewWithSubcommands( func NewWithCommand(
t *testing.T, subcommands []*cobra.Command, args ...string, t *testing.T, cmd *clibase.Cmd, args ...string,
) (*cobra.Command, config.Root) { ) (*clibase.Invocation, config.Root) {
cmd := cli.Root(subcommands) configDir := config.Root(t.TempDir())
dir := t.TempDir() i := &clibase.Invocation{
root := config.Root(dir) Command: cmd,
cmd.SetArgs(append([]string{"--global-config", dir}, args...)) Args: append([]string{"--global-config", string(configDir)}, args...),
Stdin: io.LimitReader(nil, 0),
Stdout: (&logWriter{prefix: "stdout", t: t}),
Stderr: (&logWriter{prefix: "stderr", t: t}),
}
t.Logf("invoking command: %s %s", cmd.Name(), strings.Join(i.Args, " "))
// These can be overridden by the test. // These can be overridden by the test.
cmd.SetOut(&logWriter{prefix: "stdout", t: t}) return i, configDir
cmd.SetErr(&logWriter{prefix: "stderr", t: t})
return cmd, root
} }
// SetupConfig applies the URL and SessionToken of the client to the config. // SetupConfig applies the URL and SessionToken of the client to the config.
@ -120,31 +129,111 @@ func extractTar(t *testing.T, data []byte, directory string) {
// Start runs the command in a goroutine and cleans it up when // Start runs the command in a goroutine and cleans it up when
// the test completed. // the test completed.
func Start(ctx context.Context, t *testing.T, cmd *cobra.Command) { func Start(t *testing.T, inv *clibase.Invocation) {
t.Helper() t.Helper()
closeCh := make(chan struct{}) closeCh := make(chan struct{})
deadline, hasDeadline := ctx.Deadline()
if !hasDeadline {
// We don't want to wait the full 5 minutes for a test to time out.
deadline = time.Now().Add(testutil.WaitMedium)
}
ctx, cancel := context.WithDeadline(ctx, deadline)
go func() { go func() {
defer cancel()
defer close(closeCh) defer close(closeCh)
err := cmd.ExecuteContext(ctx) err := StartWithWaiter(t, inv).Wait()
if ctx.Err() == nil { switch {
case errors.Is(err, context.Canceled):
return
default:
assert.NoError(t, err) assert.NoError(t, err)
} }
}() }()
t.Cleanup(func() {
<-closeCh
})
}
// Run runs the command and asserts that there is no error.
func Run(t *testing.T, inv *clibase.Invocation) {
t.Helper()
err := inv.Run()
require.NoError(t, err)
}
type ErrorWaiter struct {
waitOnce sync.Once
cachedError error
c <-chan error
t *testing.T
}
func (w *ErrorWaiter) Wait() error {
w.waitOnce.Do(func() {
var ok bool
w.cachedError, ok = <-w.c
if !ok {
panic("unexpoected channel close")
}
})
return w.cachedError
}
func (w *ErrorWaiter) RequireSuccess() {
require.NoError(w.t, w.Wait())
}
func (w *ErrorWaiter) RequireError() {
require.Error(w.t, w.Wait())
}
func (w *ErrorWaiter) RequireContains(s string) {
require.ErrorContains(w.t, w.Wait(), s)
}
func (w *ErrorWaiter) RequireIs(want error) {
require.ErrorIs(w.t, w.Wait(), want)
}
func (w *ErrorWaiter) RequireAs(want interface{}) {
require.ErrorAs(w.t, w.Wait(), want)
}
// StartWithWaiter runs the command in a goroutine but returns the error
// instead of asserting it. This is useful for testing error cases.
func StartWithWaiter(t *testing.T, inv *clibase.Invocation) *ErrorWaiter {
t.Helper()
errCh := make(chan error, 1)
var cleaningUp atomic.Bool
var (
ctx = inv.Context()
cancel func()
)
if _, ok := ctx.Deadline(); !ok {
ctx, cancel = context.WithDeadline(ctx, time.Now().Add(testutil.WaitMedium))
} else {
ctx, cancel = context.WithCancel(inv.Context())
}
inv = inv.WithContext(ctx)
go func() {
defer close(errCh)
err := inv.Run()
if cleaningUp.Load() && errors.Is(err, context.DeadlineExceeded) {
// If we're cleaning up, this error is likely related to the
// CLI teardown process. E.g., the server could be slow to shut
// down Postgres.
t.Logf("command %q timed out during test cleanup", inv.Command.FullName())
}
errCh <- err
}()
// Don't exit test routine until server is done. // Don't exit test routine until server is done.
t.Cleanup(func() { t.Cleanup(func() {
cancel() cancel()
<-closeCh cleaningUp.Store(true)
<-errCh
}) })
return &ErrorWaiter{c: errCh, t: t}
} }

View File

@ -18,13 +18,9 @@ func TestCli(t *testing.T) {
t.Parallel() t.Parallel()
clitest.CreateTemplateVersionSource(t, nil) clitest.CreateTemplateVersionSource(t, nil)
client := coderdtest.New(t, nil) client := coderdtest.New(t, nil)
cmd, config := clitest.New(t) i, config := clitest.New(t)
clitest.SetupConfig(t, client, config) clitest.SetupConfig(t, client, config)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(i)
cmd.SetIn(pty.Input()) clitest.Start(t, i)
cmd.SetOut(pty.Output())
go func() {
_ = cmd.Execute()
}()
pty.ExpectMatch("coder") pty.ExpectMatch("coder")
} }

View File

@ -5,11 +5,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/atomic" "go.uber.org/atomic"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"github.com/coder/coder/pty/ptytest" "github.com/coder/coder/pty/ptytest"
@ -24,9 +24,9 @@ func TestAgent(t *testing.T) {
var disconnected atomic.Bool var disconnected atomic.Bool
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, _ []string) error { Handler: func(inv *clibase.Invocation) error {
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
WorkspaceName: "example", WorkspaceName: "example",
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
agent := codersdk.WorkspaceAgent{ agent := codersdk.WorkspaceAgent{
@ -44,12 +44,13 @@ func TestAgent(t *testing.T) {
return err return err
}, },
} }
cmd.SetOutput(ptty.Output())
cmd.SetIn(ptty.Input()) inv := cmd.Invoke()
ptty.Attach(inv)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
ptty.ExpectMatchContext(ctx, "lost connection") ptty.ExpectMatchContext(ctx, "lost connection")
@ -66,9 +67,9 @@ func TestAgent_TimeoutWithTroubleshootingURL(t *testing.T) {
wantURL := "https://coder.com/troubleshoot" wantURL := "https://coder.com/troubleshoot"
var connected, timeout atomic.Bool var connected, timeout atomic.Bool
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, _ []string) error { Handler: func(inv *clibase.Invocation) error {
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
WorkspaceName: "example", WorkspaceName: "example",
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
agent := codersdk.WorkspaceAgent{ agent := codersdk.WorkspaceAgent{
@ -91,11 +92,12 @@ func TestAgent_TimeoutWithTroubleshootingURL(t *testing.T) {
}, },
} }
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd.SetOutput(ptty.Output())
cmd.SetIn(ptty.Input()) inv := cmd.Invoke()
ptty.Attach(inv)
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
done <- cmd.ExecuteContext(ctx) done <- inv.WithContext(ctx).Run()
}() }()
ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting") ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting")
timeout.Store(true) timeout.Store(true)
@ -115,9 +117,10 @@ func TestAgent_StartupTimeout(t *testing.T) {
var status, state atomic.String var status, state atomic.String
setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) } setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) }
setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) } setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) }
cmd := &cobra.Command{
RunE: func(cmd *cobra.Command, _ []string) error { cmd := &clibase.Cmd{
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ Handler: func(inv *clibase.Invocation) error {
err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
WorkspaceName: "example", WorkspaceName: "example",
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
agent := codersdk.WorkspaceAgent{ agent := codersdk.WorkspaceAgent{
@ -144,11 +147,12 @@ func TestAgent_StartupTimeout(t *testing.T) {
} }
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd.SetOutput(ptty.Output())
cmd.SetIn(ptty.Input()) inv := cmd.Invoke()
ptty.Attach(inv)
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
done <- cmd.ExecuteContext(ctx) done <- inv.WithContext(ctx).Run()
}() }()
setStatus(codersdk.WorkspaceAgentConnecting) setStatus(codersdk.WorkspaceAgentConnecting)
ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting") ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting")
@ -173,9 +177,9 @@ func TestAgent_StartErrorExit(t *testing.T) {
var status, state atomic.String var status, state atomic.String
setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) } setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) }
setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) } setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) }
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, _ []string) error { Handler: func(inv *clibase.Invocation) error {
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
WorkspaceName: "example", WorkspaceName: "example",
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
agent := codersdk.WorkspaceAgent{ agent := codersdk.WorkspaceAgent{
@ -202,11 +206,12 @@ func TestAgent_StartErrorExit(t *testing.T) {
} }
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd.SetOutput(ptty.Output())
cmd.SetIn(ptty.Input()) inv := cmd.Invoke()
ptty.Attach(inv)
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
done <- cmd.ExecuteContext(ctx) done <- inv.WithContext(ctx).Run()
}() }()
setStatus(codersdk.WorkspaceAgentConnected) setStatus(codersdk.WorkspaceAgentConnected)
setState(codersdk.WorkspaceAgentLifecycleStarting) setState(codersdk.WorkspaceAgentLifecycleStarting)
@ -228,9 +233,9 @@ func TestAgent_NoWait(t *testing.T) {
var status, state atomic.String var status, state atomic.String
setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) } setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) }
setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) } setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) }
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, _ []string) error { Handler: func(inv *clibase.Invocation) error {
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
WorkspaceName: "example", WorkspaceName: "example",
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
agent := codersdk.WorkspaceAgent{ agent := codersdk.WorkspaceAgent{
@ -257,11 +262,12 @@ func TestAgent_NoWait(t *testing.T) {
} }
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd.SetOutput(ptty.Output())
cmd.SetIn(ptty.Input()) inv := cmd.Invoke()
ptty.Attach(inv)
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
done <- cmd.ExecuteContext(ctx) done <- inv.WithContext(ctx).Run()
}() }()
setStatus(codersdk.WorkspaceAgentConnecting) setStatus(codersdk.WorkspaceAgentConnecting)
ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting") ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting")
@ -270,19 +276,19 @@ func TestAgent_NoWait(t *testing.T) {
require.NoError(t, <-done, "created - should exit early") require.NoError(t, <-done, "created - should exit early")
setState(codersdk.WorkspaceAgentLifecycleStarting) setState(codersdk.WorkspaceAgentLifecycleStarting)
go func() { done <- cmd.ExecuteContext(ctx) }() go func() { done <- inv.WithContext(ctx).Run() }()
require.NoError(t, <-done, "starting - should exit early") require.NoError(t, <-done, "starting - should exit early")
setState(codersdk.WorkspaceAgentLifecycleStartTimeout) setState(codersdk.WorkspaceAgentLifecycleStartTimeout)
go func() { done <- cmd.ExecuteContext(ctx) }() go func() { done <- inv.WithContext(ctx).Run() }()
require.NoError(t, <-done, "start timeout - should exit early") require.NoError(t, <-done, "start timeout - should exit early")
setState(codersdk.WorkspaceAgentLifecycleStartError) setState(codersdk.WorkspaceAgentLifecycleStartError)
go func() { done <- cmd.ExecuteContext(ctx) }() go func() { done <- inv.WithContext(ctx).Run() }()
require.NoError(t, <-done, "start error - should exit early") require.NoError(t, <-done, "start error - should exit early")
setState(codersdk.WorkspaceAgentLifecycleReady) setState(codersdk.WorkspaceAgentLifecycleReady)
go func() { done <- cmd.ExecuteContext(ctx) }() go func() { done <- inv.WithContext(ctx).Run() }()
require.NoError(t, <-done, "ready - should exit early") require.NoError(t, <-done, "ready - should exit early")
} }
@ -297,9 +303,9 @@ func TestAgent_LoginBeforeReadyEnabled(t *testing.T) {
var status, state atomic.String var status, state atomic.String
setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) } setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) }
setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) } setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) }
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, _ []string) error { Handler: func(inv *clibase.Invocation) error {
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
WorkspaceName: "example", WorkspaceName: "example",
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
agent := codersdk.WorkspaceAgent{ agent := codersdk.WorkspaceAgent{
@ -325,12 +331,13 @@ func TestAgent_LoginBeforeReadyEnabled(t *testing.T) {
}, },
} }
inv := cmd.Invoke()
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd.SetOutput(ptty.Output()) ptty.Attach(inv)
cmd.SetIn(ptty.Input())
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
done <- cmd.ExecuteContext(ctx) done <- inv.WithContext(ctx).Run()
}() }()
setStatus(codersdk.WorkspaceAgentConnecting) setStatus(codersdk.WorkspaceAgentConnecting)
ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting") ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting")
@ -339,18 +346,18 @@ func TestAgent_LoginBeforeReadyEnabled(t *testing.T) {
require.NoError(t, <-done, "created - should exit early") require.NoError(t, <-done, "created - should exit early")
setState(codersdk.WorkspaceAgentLifecycleStarting) setState(codersdk.WorkspaceAgentLifecycleStarting)
go func() { done <- cmd.ExecuteContext(ctx) }() go func() { done <- inv.WithContext(ctx).Run() }()
require.NoError(t, <-done, "starting - should exit early") require.NoError(t, <-done, "starting - should exit early")
setState(codersdk.WorkspaceAgentLifecycleStartTimeout) setState(codersdk.WorkspaceAgentLifecycleStartTimeout)
go func() { done <- cmd.ExecuteContext(ctx) }() go func() { done <- inv.WithContext(ctx).Run() }()
require.NoError(t, <-done, "start timeout - should exit early") require.NoError(t, <-done, "start timeout - should exit early")
setState(codersdk.WorkspaceAgentLifecycleStartError) setState(codersdk.WorkspaceAgentLifecycleStartError)
go func() { done <- cmd.ExecuteContext(ctx) }() go func() { done <- inv.WithContext(ctx).Run() }()
require.NoError(t, <-done, "start error - should exit early") require.NoError(t, <-done, "start error - should exit early")
setState(codersdk.WorkspaceAgentLifecycleReady) setState(codersdk.WorkspaceAgentLifecycleReady)
go func() { done <- cmd.ExecuteContext(ctx) }() go func() { done <- inv.WithContext(ctx).Run() }()
require.NoError(t, <-done, "ready - should exit early") require.NoError(t, <-done, "ready - should exit early")
} }

View File

@ -53,6 +53,8 @@ var Styles = struct {
FocusedPrompt: defaultStyles.FocusedPrompt.Foreground(lipgloss.Color("#651fff")), FocusedPrompt: defaultStyles.FocusedPrompt.Foreground(lipgloss.Color("#651fff")),
Fuchsia: defaultStyles.SelectedMenuItem.Copy(), Fuchsia: defaultStyles.SelectedMenuItem.Copy(),
Logo: defaultStyles.Logo.SetString("Coder"), Logo: defaultStyles.Logo.SetString("Coder"),
Warn: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#04B575", Dark: "#ECFD65"}), Warn: lipgloss.NewStyle().Foreground(
lipgloss.AdaptiveColor{Light: "#04B575", Dark: "#ECFD65"},
),
Wrap: lipgloss.NewStyle().Width(80), Wrap: lipgloss.NewStyle().Width(80),
} }

View File

@ -7,9 +7,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"github.com/coder/coder/pty/ptytest" "github.com/coder/coder/pty/ptytest"
@ -23,10 +23,10 @@ func TestGitAuth(t *testing.T) {
defer cancel() defer cancel()
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
var fetched atomic.Bool var fetched atomic.Bool
return cliui.GitAuth(cmd.Context(), cmd.OutOrStdout(), cliui.GitAuthOptions{ return cliui.GitAuth(inv.Context(), inv.Stdout, cliui.GitAuthOptions{
Fetch: func(ctx context.Context) ([]codersdk.TemplateVersionGitAuth, error) { Fetch: func(ctx context.Context) ([]codersdk.TemplateVersionGitAuth, error) {
defer fetched.Store(true) defer fetched.Store(true)
return []codersdk.TemplateVersionGitAuth{{ return []codersdk.TemplateVersionGitAuth{{
@ -40,12 +40,14 @@ func TestGitAuth(t *testing.T) {
}) })
}, },
} }
cmd.SetOutput(ptty.Output())
cmd.SetIn(ptty.Input()) inv := cmd.Invoke().WithContext(ctx)
ptty.Attach(inv)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
ptty.ExpectMatchContext(ctx, "You must authenticate with") ptty.ExpectMatchContext(ctx, "You must authenticate with")

View File

@ -10,17 +10,22 @@ import (
// cliMessage provides a human-readable message for CLI errors and messages. // cliMessage provides a human-readable message for CLI errors and messages.
type cliMessage struct { type cliMessage struct {
Level string
Style lipgloss.Style Style lipgloss.Style
Header string Header string
Prefix string
Lines []string Lines []string
} }
// String formats the CLI message for consumption by a human. // String formats the CLI message for consumption by a human.
func (m cliMessage) String() string { func (m cliMessage) String() string {
var str strings.Builder var str strings.Builder
_, _ = fmt.Fprintf(&str, "%s\r\n",
Styles.Bold.Render(m.Header)) if m.Prefix != "" {
_, _ = str.WriteString(m.Style.Bold(true).Render(m.Prefix))
}
_, _ = str.WriteString(m.Style.Bold(false).Render(m.Header))
_, _ = str.WriteString("\r\n")
for _, line := range m.Lines { for _, line := range m.Lines {
_, _ = fmt.Fprintf(&str, " %s %s\r\n", m.Style.Render("|"), line) _, _ = fmt.Fprintf(&str, " %s %s\r\n", m.Style.Render("|"), line)
} }
@ -30,9 +35,42 @@ func (m cliMessage) String() string {
// Warn writes a log to the writer provided. // Warn writes a log to the writer provided.
func Warn(wtr io.Writer, header string, lines ...string) { func Warn(wtr io.Writer, header string, lines ...string) {
_, _ = fmt.Fprint(wtr, cliMessage{ _, _ = fmt.Fprint(wtr, cliMessage{
Level: "warning",
Style: Styles.Warn, Style: Styles.Warn,
Prefix: "WARN: ",
Header: header, Header: header,
Lines: lines, Lines: lines,
}.String()) }.String())
} }
// Warn writes a formatted log to the writer provided.
func Warnf(wtr io.Writer, fmtStr string, args ...interface{}) {
Warn(wtr, fmt.Sprintf(fmtStr, args...))
}
// Info writes a log to the writer provided.
func Info(wtr io.Writer, header string, lines ...string) {
_, _ = fmt.Fprint(wtr, cliMessage{
Header: header,
Lines: lines,
}.String())
}
// Infof writes a formatted log to the writer provided.
func Infof(wtr io.Writer, fmtStr string, args ...interface{}) {
Info(wtr, fmt.Sprintf(fmtStr, args...))
}
// Error writes a log to the writer provided.
func Error(wtr io.Writer, header string, lines ...string) {
_, _ = fmt.Fprint(wtr, cliMessage{
Style: Styles.Error,
Prefix: "ERROR: ",
Header: header,
Lines: lines,
}.String())
}
// Errorf writes a formatted log to the writer provided.
func Errorf(wtr io.Writer, fmtStr string, args ...interface{}) {
Error(wtr, fmt.Sprintf(fmtStr, args...))
}

View File

@ -6,13 +6,14 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
) )
type OutputFormat interface { type OutputFormat interface {
ID() string ID() string
AttachFlags(cmd *cobra.Command) AttachOptions(opts *clibase.OptionSet)
Format(ctx context.Context, data any) (string, error) Format(ctx context.Context, data any) (string, error)
} }
@ -45,11 +46,11 @@ func NewOutputFormatter(formats ...OutputFormat) *OutputFormatter {
} }
} }
// AttachFlags attaches the --output flag to the given command, and any // AttachOptions attaches the --output flag to the given command, and any
// additional flags required by the output formatters. // additional flags required by the output formatters.
func (f *OutputFormatter) AttachFlags(cmd *cobra.Command) { func (f *OutputFormatter) AttachOptions(opts *clibase.OptionSet) {
for _, format := range f.formats { for _, format := range f.formats {
format.AttachFlags(cmd) format.AttachOptions(opts)
} }
formatNames := make([]string, 0, len(f.formats)) formatNames := make([]string, 0, len(f.formats))
@ -57,7 +58,15 @@ func (f *OutputFormatter) AttachFlags(cmd *cobra.Command) {
formatNames = append(formatNames, format.ID()) formatNames = append(formatNames, format.ID())
} }
cmd.Flags().StringVarP(&f.formatID, "output", "o", f.formats[0].ID(), "Output format. Available formats: "+strings.Join(formatNames, ", ")) *opts = append(*opts,
clibase.Option{
Flag: "output",
FlagShorthand: "o",
Default: f.formats[0].ID(),
Value: clibase.StringOf(&f.formatID),
Description: "Output format. Available formats: " + strings.Join(formatNames, ", ") + ".",
},
)
} }
// Format formats the given data using the format specified by the --output // Format formats the given data using the format specified by the --output
@ -118,9 +127,17 @@ func (*tableFormat) ID() string {
return "table" return "table"
} }
// AttachFlags implements OutputFormat. // AttachOptions implements OutputFormat.
func (f *tableFormat) AttachFlags(cmd *cobra.Command) { func (f *tableFormat) AttachOptions(opts *clibase.OptionSet) {
cmd.Flags().StringSliceVarP(&f.columns, "column", "c", f.defaultColumns, "Columns to display in table output. Available columns: "+strings.Join(f.allColumns, ", ")) *opts = append(*opts,
clibase.Option{
Flag: "column",
FlagShorthand: "c",
Default: strings.Join(f.defaultColumns, ","),
Value: clibase.StringArrayOf(&f.columns),
Description: "Columns to display in table output. Available columns: " + strings.Join(f.allColumns, ", ") + ".",
},
)
} }
// Format implements OutputFormat. // Format implements OutputFormat.
@ -142,8 +159,8 @@ func (jsonFormat) ID() string {
return "json" return "json"
} }
// AttachFlags implements OutputFormat. // AttachOptions implements OutputFormat.
func (jsonFormat) AttachFlags(_ *cobra.Command) {} func (jsonFormat) AttachOptions(_ *clibase.OptionSet) {}
// Format implements OutputFormat. // Format implements OutputFormat.
func (jsonFormat) Format(_ context.Context, data any) (string, error) { func (jsonFormat) Format(_ context.Context, data any) (string, error) {

View File

@ -6,15 +6,15 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
) )
type format struct { type format struct {
id string id string
attachFlagsFn func(cmd *cobra.Command) attachOptionsFn func(opts *clibase.OptionSet)
formatFn func(ctx context.Context, data any) (string, error) formatFn func(ctx context.Context, data any) (string, error)
} }
@ -24,9 +24,9 @@ func (f *format) ID() string {
return f.id return f.id
} }
func (f *format) AttachFlags(cmd *cobra.Command) { func (f *format) AttachOptions(opts *clibase.OptionSet) {
if f.attachFlagsFn != nil { if f.attachOptionsFn != nil {
f.attachFlagsFn(cmd) f.attachOptionsFn(opts)
} }
} }
@ -82,8 +82,14 @@ func Test_OutputFormatter(t *testing.T) {
cliui.JSONFormat(), cliui.JSONFormat(),
&format{ &format{
id: "foo", id: "foo",
attachFlagsFn: func(cmd *cobra.Command) { attachOptionsFn: func(opts *clibase.OptionSet) {
cmd.Flags().StringP("foo", "f", "", "foo flag 1234") opts.Add(clibase.Option{
Name: "foo",
Flag: "foo",
FlagShorthand: "f",
Value: clibase.DiscardValue,
Description: "foo flag 1234",
})
}, },
formatFn: func(_ context.Context, _ any) (string, error) { formatFn: func(_ context.Context, _ any) (string, error) {
atomic.AddInt64(&called, 1) atomic.AddInt64(&called, 1)
@ -92,13 +98,15 @@ func Test_OutputFormatter(t *testing.T) {
}, },
) )
cmd := &cobra.Command{} cmd := &clibase.Cmd{}
f.AttachFlags(cmd) f.AttachOptions(&cmd.Options)
selected, err := cmd.Flags().GetString("output") fs := cmd.Options.FlagSet()
selected, err := fs.GetString("output")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "json", selected) require.Equal(t, "json", selected)
usage := cmd.Flags().FlagUsages() usage := fs.FlagUsages()
require.Contains(t, usage, "Available formats: json, foo") require.Contains(t, usage, "Available formats: json, foo")
require.Contains(t, usage, "foo flag 1234") require.Contains(t, usage, "foo flag 1234")
@ -112,13 +120,13 @@ func Test_OutputFormatter(t *testing.T) {
require.Equal(t, data, got) require.Equal(t, data, got)
require.EqualValues(t, 0, atomic.LoadInt64(&called)) require.EqualValues(t, 0, atomic.LoadInt64(&called))
require.NoError(t, cmd.Flags().Set("output", "foo")) require.NoError(t, fs.Set("output", "foo"))
out, err = f.Format(ctx, data) out, err = f.Format(ctx, data)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "foo", out) require.Equal(t, "foo", out)
require.EqualValues(t, 1, atomic.LoadInt64(&called)) require.EqualValues(t, 1, atomic.LoadInt64(&called))
require.NoError(t, cmd.Flags().Set("output", "bar")) require.NoError(t, fs.Set("output", "bar"))
out, err = f.Format(ctx, data) out, err = f.Format(ctx, data)
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "bar") require.ErrorContains(t, err, "bar")

View File

@ -5,16 +5,15 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/coderd/parameter" "github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func ParameterSchema(cmd *cobra.Command, parameterSchema codersdk.ParameterSchema) (string, error) { func ParameterSchema(inv *clibase.Invocation, parameterSchema codersdk.ParameterSchema) (string, error) {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), Styles.Bold.Render("var."+parameterSchema.Name)) _, _ = fmt.Fprintln(inv.Stdout, Styles.Bold.Render("var."+parameterSchema.Name))
if parameterSchema.Description != "" { if parameterSchema.Description != "" {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+strings.TrimSpace(strings.Join(strings.Split(parameterSchema.Description, "\n"), "\n "))+"\n") _, _ = fmt.Fprintln(inv.Stdout, " "+strings.TrimSpace(strings.Join(strings.Split(parameterSchema.Description, "\n"), "\n "))+"\n")
} }
var err error var err error
@ -28,15 +27,15 @@ func ParameterSchema(cmd *cobra.Command, parameterSchema codersdk.ParameterSchem
var value string var value string
if len(options) > 0 { if len(options) > 0 {
// Move the cursor up a single line for nicer display! // Move the cursor up a single line for nicer display!
_, _ = fmt.Fprint(cmd.OutOrStdout(), "\033[1A") _, _ = fmt.Fprint(inv.Stdout, "\033[1A")
value, err = Select(cmd, SelectOptions{ value, err = Select(inv, SelectOptions{
Options: options, Options: options,
Default: parameterSchema.DefaultSourceValue, Default: parameterSchema.DefaultSourceValue,
HideSearch: true, HideSearch: true,
}) })
if err == nil { if err == nil {
_, _ = fmt.Fprintln(cmd.OutOrStdout()) _, _ = fmt.Fprintln(inv.Stdout)
_, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+Styles.Prompt.String()+Styles.Field.Render(value)) _, _ = fmt.Fprintln(inv.Stdout, " "+Styles.Prompt.String()+Styles.Field.Render(value))
} }
} else { } else {
text := "Enter a value" text := "Enter a value"
@ -45,7 +44,7 @@ func ParameterSchema(cmd *cobra.Command, parameterSchema codersdk.ParameterSchem
} }
text += ":" text += ":"
value, err = Prompt(cmd, PromptOptions{ value, err = Prompt(inv, PromptOptions{
Text: Styles.Bold.Render(text), Text: Styles.Bold.Render(text),
}) })
value = strings.TrimSpace(value) value = strings.TrimSpace(value)
@ -62,17 +61,17 @@ func ParameterSchema(cmd *cobra.Command, parameterSchema codersdk.ParameterSchem
return value, nil return value, nil
} }
func RichParameter(cmd *cobra.Command, templateVersionParameter codersdk.TemplateVersionParameter) (string, error) { func RichParameter(inv *clibase.Invocation, templateVersionParameter codersdk.TemplateVersionParameter) (string, error) {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), Styles.Bold.Render(templateVersionParameter.Name)) _, _ = fmt.Fprintln(inv.Stdout, Styles.Bold.Render(templateVersionParameter.Name))
if templateVersionParameter.DescriptionPlaintext != "" { if templateVersionParameter.DescriptionPlaintext != "" {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+strings.TrimSpace(strings.Join(strings.Split(templateVersionParameter.DescriptionPlaintext, "\n"), "\n "))+"\n") _, _ = fmt.Fprintln(inv.Stdout, " "+strings.TrimSpace(strings.Join(strings.Split(templateVersionParameter.DescriptionPlaintext, "\n"), "\n "))+"\n")
} }
var err error var err error
var value string var value string
if templateVersionParameter.Type == "list(string)" { if templateVersionParameter.Type == "list(string)" {
// Move the cursor up a single line for nicer display! // Move the cursor up a single line for nicer display!
_, _ = fmt.Fprint(cmd.OutOrStdout(), "\033[1A") _, _ = fmt.Fprint(inv.Stdout, "\033[1A")
var options []string var options []string
err = json.Unmarshal([]byte(templateVersionParameter.DefaultValue), &options) err = json.Unmarshal([]byte(templateVersionParameter.DefaultValue), &options)
@ -80,29 +79,29 @@ func RichParameter(cmd *cobra.Command, templateVersionParameter codersdk.Templat
return "", err return "", err
} }
values, err := MultiSelect(cmd, options) values, err := MultiSelect(inv, options)
if err == nil { if err == nil {
v, err := json.Marshal(&values) v, err := json.Marshal(&values)
if err != nil { if err != nil {
return "", err return "", err
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout()) _, _ = fmt.Fprintln(inv.Stdout)
_, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+Styles.Prompt.String()+Styles.Field.Render(strings.Join(values, ", "))) _, _ = fmt.Fprintln(inv.Stdout, " "+Styles.Prompt.String()+Styles.Field.Render(strings.Join(values, ", ")))
value = string(v) value = string(v)
} }
} else if len(templateVersionParameter.Options) > 0 { } else if len(templateVersionParameter.Options) > 0 {
// Move the cursor up a single line for nicer display! // Move the cursor up a single line for nicer display!
_, _ = fmt.Fprint(cmd.OutOrStdout(), "\033[1A") _, _ = fmt.Fprint(inv.Stdout, "\033[1A")
var richParameterOption *codersdk.TemplateVersionParameterOption var richParameterOption *codersdk.TemplateVersionParameterOption
richParameterOption, err = RichSelect(cmd, RichSelectOptions{ richParameterOption, err = RichSelect(inv, RichSelectOptions{
Options: templateVersionParameter.Options, Options: templateVersionParameter.Options,
Default: templateVersionParameter.DefaultValue, Default: templateVersionParameter.DefaultValue,
HideSearch: true, HideSearch: true,
}) })
if err == nil { if err == nil {
_, _ = fmt.Fprintln(cmd.OutOrStdout()) _, _ = fmt.Fprintln(inv.Stdout)
_, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+Styles.Prompt.String()+Styles.Field.Render(richParameterOption.Name)) _, _ = fmt.Fprintln(inv.Stdout, " "+Styles.Prompt.String()+Styles.Field.Render(richParameterOption.Name))
value = richParameterOption.Value value = richParameterOption.Value
} }
} else { } else {
@ -112,7 +111,7 @@ func RichParameter(cmd *cobra.Command, templateVersionParameter codersdk.Templat
} }
text += ":" text += ":"
value, err = Prompt(cmd, PromptOptions{ value, err = Prompt(inv, PromptOptions{
Text: Styles.Bold.Render(text), Text: Styles.Bold.Render(text),
Validate: func(value string) error { Validate: func(value string) error {
return validateRichPrompt(value, templateVersionParameter) return validateRichPrompt(value, templateVersionParameter)

View File

@ -11,8 +11,9 @@ import (
"github.com/bgentry/speakeasy" "github.com/bgentry/speakeasy"
"github.com/mattn/go-isatty" "github.com/mattn/go-isatty"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
) )
// PromptOptions supply a set of options to the prompt. // PromptOptions supply a set of options to the prompt.
@ -26,8 +27,16 @@ type PromptOptions struct {
const skipPromptFlag = "yes" const skipPromptFlag = "yes"
func AllowSkipPrompt(cmd *cobra.Command) { // SkipPromptOption adds a "--yes/-y" flag to the cmd that can be used to skip
cmd.Flags().BoolP(skipPromptFlag, "y", false, "Bypass prompts") // prompts.
func SkipPromptOption() clibase.Option {
return clibase.Option{
Flag: skipPromptFlag,
FlagShorthand: "y",
Description: "Bypass prompts.",
// Discard
Value: clibase.BoolOf(new(bool)),
}
} }
const ( const (
@ -36,17 +45,17 @@ const (
) )
// Prompt asks the user for input. // Prompt asks the user for input.
func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) { func Prompt(inv *clibase.Invocation, opts PromptOptions) (string, error) {
// If the cmd has a "yes" flag for skipping confirm prompts, honor it. // If the cmd has a "yes" flag for skipping confirm prompts, honor it.
// If it's not a "Confirm" prompt, then don't skip. As the default value of // If it's not a "Confirm" prompt, then don't skip. As the default value of
// "yes" makes no sense. // "yes" makes no sense.
if opts.IsConfirm && cmd.Flags().Lookup(skipPromptFlag) != nil { if opts.IsConfirm && inv.ParsedFlags().Lookup(skipPromptFlag) != nil {
if skip, _ := cmd.Flags().GetBool(skipPromptFlag); skip { if skip, _ := inv.ParsedFlags().GetBool(skipPromptFlag); skip {
return ConfirmYes, nil return ConfirmYes, nil
} }
} }
_, _ = fmt.Fprint(cmd.OutOrStdout(), Styles.FocusedPrompt.String()+opts.Text+" ") _, _ = fmt.Fprint(inv.Stdout, Styles.FocusedPrompt.String()+opts.Text+" ")
if opts.IsConfirm { if opts.IsConfirm {
if len(opts.Default) == 0 { if len(opts.Default) == 0 {
opts.Default = ConfirmYes opts.Default = ConfirmYes
@ -58,19 +67,24 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) {
} else { } else {
renderedNo = Styles.Bold.Render(ConfirmNo) renderedNo = Styles.Bold.Render(ConfirmNo)
} }
_, _ = fmt.Fprint(cmd.OutOrStdout(), Styles.Placeholder.Render("("+renderedYes+Styles.Placeholder.Render("/"+renderedNo+Styles.Placeholder.Render(") ")))) _, _ = fmt.Fprint(inv.Stdout, Styles.Placeholder.Render("("+renderedYes+Styles.Placeholder.Render("/"+renderedNo+Styles.Placeholder.Render(") "))))
} else if opts.Default != "" { } else if opts.Default != "" {
_, _ = fmt.Fprint(cmd.OutOrStdout(), Styles.Placeholder.Render("("+opts.Default+") ")) _, _ = fmt.Fprint(inv.Stdout, Styles.Placeholder.Render("("+opts.Default+") "))
} }
interrupt := make(chan os.Signal, 1) interrupt := make(chan os.Signal, 1)
if inv.Stdin == nil {
panic("inv.Stdin is nil")
}
errCh := make(chan error, 1) errCh := make(chan error, 1)
lineCh := make(chan string) lineCh := make(chan string)
go func() { go func() {
var line string var line string
var err error var err error
inFile, isInputFile := cmd.InOrStdin().(*os.File) inFile, isInputFile := inv.Stdin.(*os.File)
if opts.Secret && isInputFile && isatty.IsTerminal(inFile.Fd()) { if opts.Secret && isInputFile && isatty.IsTerminal(inFile.Fd()) {
// we don't install a signal handler here because speakeasy has its own // we don't install a signal handler here because speakeasy has its own
line, err = speakeasy.Ask("") line, err = speakeasy.Ask("")
@ -78,7 +92,7 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) {
signal.Notify(interrupt, os.Interrupt) signal.Notify(interrupt, os.Interrupt)
defer signal.Stop(interrupt) defer signal.Stop(interrupt)
reader := bufio.NewReader(cmd.InOrStdin()) reader := bufio.NewReader(inv.Stdin)
line, err = reader.ReadString('\n') line, err = reader.ReadString('\n')
// Check if the first line beings with JSON object or array chars. // Check if the first line beings with JSON object or array chars.
@ -96,7 +110,10 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) {
if line == "" { if line == "" {
line = opts.Default line = opts.Default
} }
lineCh <- line select {
case <-inv.Context().Done():
case lineCh <- line:
}
}() }()
select { select {
@ -109,16 +126,16 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) {
if opts.Validate != nil { if opts.Validate != nil {
err := opts.Validate(line) err := opts.Validate(line)
if err != nil { if err != nil {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), defaultStyles.Error.Render(err.Error())) _, _ = fmt.Fprintln(inv.Stdout, defaultStyles.Error.Render(err.Error()))
return Prompt(cmd, opts) return Prompt(inv, opts)
} }
} }
return line, nil return line, nil
case <-cmd.Context().Done(): case <-inv.Context().Done():
return "", cmd.Context().Err() return "", inv.Context().Err()
case <-interrupt: case <-interrupt:
// Print a newline so that any further output starts properly on a new line. // Print a newline so that any further output starts properly on a new line.
_, _ = fmt.Fprintln(cmd.OutOrStdout()) _, _ = fmt.Fprintln(inv.Stdout)
return "", Canceled return "", Canceled
} }
} }

View File

@ -8,10 +8,10 @@ import (
"os/exec" "os/exec"
"testing" "testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/pty" "github.com/coder/coder/pty"
"github.com/coder/coder/pty/ptytest" "github.com/coder/coder/pty/ptytest"
@ -77,9 +77,9 @@ func TestPrompt(t *testing.T) {
resp, err := newPrompt(ptty, cliui.PromptOptions{ resp, err := newPrompt(ptty, cliui.PromptOptions{
Text: "ShouldNotSeeThis", Text: "ShouldNotSeeThis",
IsConfirm: true, IsConfirm: true,
}, func(cmd *cobra.Command) { }, func(inv *clibase.Invocation) {
cliui.AllowSkipPrompt(cmd) inv.Command.Options = append(inv.Command.Options, cliui.SkipPromptOption())
cmd.SetArgs([]string{"-y"}) inv.Args = []string{"-y"}
}) })
assert.NoError(t, err) assert.NoError(t, err)
doneChan <- resp doneChan <- resp
@ -145,23 +145,25 @@ func TestPrompt(t *testing.T) {
}) })
} }
func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, cmdOpt func(cmd *cobra.Command)) (string, error) { func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *clibase.Invocation)) (string, error) {
value := "" value := ""
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
var err error var err error
value, err = cliui.Prompt(cmd, opts) value, err = cliui.Prompt(inv, opts)
return err return err
}, },
} }
inv := cmd.Invoke()
// Optionally modify the cmd // Optionally modify the cmd
if cmdOpt != nil { if invOpt != nil {
cmdOpt(cmd) invOpt(inv)
} }
cmd.SetOut(ptty.Output()) inv.Stdout = ptty.Output()
cmd.SetErr(ptty.Output()) inv.Stderr = ptty.Output()
cmd.SetIn(ptty.Input()) inv.Stdin = ptty.Input()
return value, cmd.ExecuteContext(context.Background()) return value, inv.WithContext(context.Background()).Run()
} }
func TestPasswordTerminalState(t *testing.T) { func TestPasswordTerminalState(t *testing.T) {
@ -208,13 +210,17 @@ func TestPasswordTerminalState(t *testing.T) {
// nolint:unused // nolint:unused
func passwordHelper() { func passwordHelper() {
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Run: func(cmd *cobra.Command, args []string) { Handler: func(inv *clibase.Invocation) error {
cliui.Prompt(cmd, cliui.PromptOptions{ cliui.Prompt(inv, cliui.PromptOptions{
Text: "Password:", Text: "Password:",
Secret: true, Secret: true,
}) })
return nil
}, },
} }
cmd.ExecuteContext(context.Background()) err := cmd.Invoke().WithOS().Run()
if err != nil {
panic(err)
}
} }

View File

@ -9,9 +9,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
@ -125,9 +125,9 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
} }
jobLock := sync.Mutex{} jobLock := sync.Mutex{}
logs := make(chan codersdk.ProvisionerJobLog, 1) logs := make(chan codersdk.ProvisionerJobLog, 1)
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
return cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{ return cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{
FetchInterval: time.Millisecond, FetchInterval: time.Millisecond,
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
jobLock.Lock() jobLock.Lock()
@ -145,13 +145,14 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
}) })
}, },
} }
inv := cmd.Invoke()
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd.SetOutput(ptty.Output()) ptty.Attach(inv)
cmd.SetIn(ptty.Input())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
err := cmd.ExecuteContext(context.Background()) err := inv.WithContext(context.Background()).Run()
if err != nil { if err != nil {
assert.ErrorIs(t, err, cliui.Canceled) assert.ErrorIs(t, err, cliui.Canceled)
} }

View File

@ -8,9 +8,9 @@ import (
"github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal" "github.com/AlecAivazis/survey/v2/terminal"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
@ -68,7 +68,7 @@ type RichSelectOptions struct {
} }
// RichSelect displays a list of user options including name and description. // RichSelect displays a list of user options including name and description.
func RichSelect(cmd *cobra.Command, richOptions RichSelectOptions) (*codersdk.TemplateVersionParameterOption, error) { func RichSelect(inv *clibase.Invocation, richOptions RichSelectOptions) (*codersdk.TemplateVersionParameterOption, error) {
opts := make([]string, len(richOptions.Options)) opts := make([]string, len(richOptions.Options))
for i, option := range richOptions.Options { for i, option := range richOptions.Options {
line := option.Name line := option.Name
@ -78,7 +78,7 @@ func RichSelect(cmd *cobra.Command, richOptions RichSelectOptions) (*codersdk.Te
opts[i] = line opts[i] = line
} }
selected, err := Select(cmd, SelectOptions{ selected, err := Select(inv, SelectOptions{
Options: opts, Options: opts,
Default: richOptions.Default, Default: richOptions.Default,
Size: richOptions.Size, Size: richOptions.Size,
@ -97,7 +97,7 @@ func RichSelect(cmd *cobra.Command, richOptions RichSelectOptions) (*codersdk.Te
} }
// Select displays a list of user options. // Select displays a list of user options.
func Select(cmd *cobra.Command, opts SelectOptions) (string, error) { func Select(inv *clibase.Invocation, opts SelectOptions) (string, error) {
// The survey library used *always* fails when testing on Windows, // The survey library used *always* fails when testing on Windows,
// as it requires a live TTY (can't be a conpty). We should fork // as it requires a live TTY (can't be a conpty). We should fork
// this library to add a dummy fallback, that simply reads/writes // this library to add a dummy fallback, that simply reads/writes
@ -123,17 +123,17 @@ func Select(cmd *cobra.Command, opts SelectOptions) (string, error) {
is.Help.Text = "" is.Help.Text = ""
} }
}), survey.WithStdio(fileReadWriter{ }), survey.WithStdio(fileReadWriter{
Reader: cmd.InOrStdin(), Reader: inv.Stdin,
}, fileReadWriter{ }, fileReadWriter{
Writer: cmd.OutOrStdout(), Writer: inv.Stdout,
}, cmd.OutOrStdout())) }, inv.Stdout))
if errors.Is(err, terminal.InterruptErr) { if errors.Is(err, terminal.InterruptErr) {
return value, Canceled return value, Canceled
} }
return value, err return value, err
} }
func MultiSelect(cmd *cobra.Command, items []string) ([]string, error) { func MultiSelect(inv *clibase.Invocation, items []string) ([]string, error) {
// Similar hack is applied to Select() // Similar hack is applied to Select()
if flag.Lookup("test.v") != nil { if flag.Lookup("test.v") != nil {
return items, nil return items, nil
@ -146,10 +146,10 @@ func MultiSelect(cmd *cobra.Command, items []string) ([]string, error) {
var values []string var values []string
err := survey.AskOne(prompt, &values, survey.WithStdio(fileReadWriter{ err := survey.AskOne(prompt, &values, survey.WithStdio(fileReadWriter{
Reader: cmd.InOrStdin(), Reader: inv.Stdin,
}, fileReadWriter{ }, fileReadWriter{
Writer: cmd.OutOrStdout(), Writer: inv.Stdout,
}, cmd.OutOrStdout())) }, inv.Stdout))
if errors.Is(err, terminal.InterruptErr) { if errors.Is(err, terminal.InterruptErr) {
return nil, Canceled return nil, Canceled
} }

View File

@ -1,13 +1,12 @@
package cliui_test package cliui_test
import ( import (
"context"
"testing" "testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"github.com/coder/coder/pty/ptytest" "github.com/coder/coder/pty/ptytest"
@ -32,16 +31,16 @@ func TestSelect(t *testing.T) {
func newSelect(ptty *ptytest.PTY, opts cliui.SelectOptions) (string, error) { func newSelect(ptty *ptytest.PTY, opts cliui.SelectOptions) (string, error) {
value := "" value := ""
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
var err error var err error
value, err = cliui.Select(cmd, opts) value, err = cliui.Select(inv, opts)
return err return err
}, },
} }
cmd.SetOutput(ptty.Output()) inv := cmd.Invoke()
cmd.SetIn(ptty.Input()) ptty.Attach(inv)
return value, cmd.ExecuteContext(context.Background()) return value, inv.Run()
} }
func TestRichSelect(t *testing.T) { func TestRichSelect(t *testing.T) {
@ -56,11 +55,11 @@ func TestRichSelect(t *testing.T) {
{ {
Name: "A-Name", Name: "A-Name",
Value: "A-Value", Value: "A-Value",
Description: "A-Description", Description: "A-Description.",
}, { }, {
Name: "B-Name", Name: "B-Name",
Value: "B-Value", Value: "B-Value",
Description: "B-Description", Description: "B-Description.",
}, },
}, },
}) })
@ -73,18 +72,18 @@ func TestRichSelect(t *testing.T) {
func newRichSelect(ptty *ptytest.PTY, opts cliui.RichSelectOptions) (string, error) { func newRichSelect(ptty *ptytest.PTY, opts cliui.RichSelectOptions) (string, error) {
value := "" value := ""
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
richOption, err := cliui.RichSelect(cmd, opts) richOption, err := cliui.RichSelect(inv, opts)
if err == nil { if err == nil {
value = richOption.Value value = richOption.Value
} }
return err return err
}, },
} }
cmd.SetOutput(ptty.Output()) inv := cmd.Invoke()
cmd.SetIn(ptty.Input()) ptty.Attach(inv)
return value, cmd.ExecuteContext(context.Background()) return value, inv.Run()
} }
func TestMultiSelect(t *testing.T) { func TestMultiSelect(t *testing.T) {
@ -106,16 +105,16 @@ func TestMultiSelect(t *testing.T) {
func newMultiSelect(ptty *ptytest.PTY, items []string) ([]string, error) { func newMultiSelect(ptty *ptytest.PTY, items []string) ([]string, error) {
var values []string var values []string
cmd := &cobra.Command{ cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
selectedItems, err := cliui.MultiSelect(cmd, items) selectedItems, err := cliui.MultiSelect(inv, items)
if err == nil { if err == nil {
values = selectedItems values = selectedItems
} }
return err return err
}, },
} }
cmd.SetOutput(ptty.Output()) inv := cmd.Invoke()
cmd.SetIn(ptty.Input()) ptty.Attach(inv)
return values, cmd.ExecuteContext(context.Background()) return values, inv.Run()
} }

View File

@ -6,6 +6,7 @@ import (
"path/filepath" "path/filepath"
"github.com/kirsle/configdir" "github.com/kirsle/configdir"
"golang.org/x/xerrors"
) )
const ( const (
@ -15,36 +16,53 @@ const (
// Root represents the configuration directory. // Root represents the configuration directory.
type Root string type Root string
// mustNotBeEmpty prevents us from accidentally writing configuration to the
// current directory. This is primarily valuable in development, where we may
// accidentally use an empty root.
func (r Root) mustNotEmpty() {
if r == "" {
panic("config root must not be empty")
}
}
func (r Root) Session() File { func (r Root) Session() File {
r.mustNotEmpty()
return File(filepath.Join(string(r), "session")) return File(filepath.Join(string(r), "session"))
} }
// ReplicaID is a unique identifier for the Coder server. // ReplicaID is a unique identifier for the Coder server.
func (r Root) ReplicaID() File { func (r Root) ReplicaID() File {
r.mustNotEmpty()
return File(filepath.Join(string(r), "replica_id")) return File(filepath.Join(string(r), "replica_id"))
} }
func (r Root) URL() File { func (r Root) URL() File {
r.mustNotEmpty()
return File(filepath.Join(string(r), "url")) return File(filepath.Join(string(r), "url"))
} }
func (r Root) Organization() File { func (r Root) Organization() File {
r.mustNotEmpty()
return File(filepath.Join(string(r), "organization")) return File(filepath.Join(string(r), "organization"))
} }
func (r Root) DotfilesURL() File { func (r Root) DotfilesURL() File {
r.mustNotEmpty()
return File(filepath.Join(string(r), "dotfilesurl")) return File(filepath.Join(string(r), "dotfilesurl"))
} }
func (r Root) PostgresPath() string { func (r Root) PostgresPath() string {
r.mustNotEmpty()
return filepath.Join(string(r), "postgres") return filepath.Join(string(r), "postgres")
} }
func (r Root) PostgresPassword() File { func (r Root) PostgresPassword() File {
r.mustNotEmpty()
return File(filepath.Join(r.PostgresPath(), "password")) return File(filepath.Join(r.PostgresPath(), "password"))
} }
func (r Root) PostgresPort() File { func (r Root) PostgresPort() File {
r.mustNotEmpty()
return File(filepath.Join(r.PostgresPath(), "port")) return File(filepath.Join(r.PostgresPath(), "port"))
} }
@ -53,16 +71,25 @@ type File string
// Delete deletes the file. // Delete deletes the file.
func (f File) Delete() error { func (f File) Delete() error {
if f == "" {
return xerrors.Errorf("empty file path")
}
return os.Remove(string(f)) return os.Remove(string(f))
} }
// Write writes the string to the file. // Write writes the string to the file.
func (f File) Write(s string) error { func (f File) Write(s string) error {
if f == "" {
return xerrors.Errorf("empty file path")
}
return write(string(f), 0o600, []byte(s)) return write(string(f), 0o600, []byte(s))
} }
// Read reads the file to a string. // Read reads the file to a string.
func (f File) Read() (string, error) { func (f File) Read() (string, error) {
if f == "" {
return "", xerrors.Errorf("empty file path")
}
byt, err := read(string(f)) byt, err := read(string(f))
return string(byt), err return string(byt), err
} }

View File

@ -18,12 +18,11 @@ import (
"github.com/cli/safeexec" "github.com/cli/safeexec"
"github.com/pkg/diff" "github.com/pkg/diff"
"github.com/pkg/diff/write" "github.com/pkg/diff/write"
"github.com/spf13/cobra"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
@ -170,7 +169,7 @@ func sshPrepareWorkspaceConfigs(ctx context.Context, client *codersdk.Client) (r
} }
} }
func configSSH() *cobra.Command { func (r *RootCmd) configSSH() *clibase.Cmd {
var ( var (
sshConfigFile string sshConfigFile string
sshConfigOpts sshConfigOptions sshConfigOpts sshConfigOptions
@ -179,11 +178,12 @@ func configSSH() *cobra.Command {
skipProxyCommand bool skipProxyCommand bool
userHostPrefix string userHostPrefix string
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "config-ssh", Use: "config-ssh",
Short: "Add an SSH Host entry for your workspaces \"ssh coder.workspace\"", Short: "Add an SSH Host entry for your workspaces \"ssh coder.workspace\"",
Example: formatExamples( Long: formatExamples(
example{ example{
Description: "You can use -o (or --ssh-option) so set SSH options to be used for all your workspaces", Description: "You can use -o (or --ssh-option) so set SSH options to be used for all your workspaces",
Command: "coder config-ssh -o ForwardAgent=yes", Command: "coder config-ssh -o ForwardAgent=yes",
@ -193,21 +193,18 @@ func configSSH() *cobra.Command {
Command: "coder config-ssh --dry-run", Command: "coder config-ssh --dry-run",
}, },
), ),
Args: cobra.ExactArgs(0), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, _ []string) error { clibase.RequireNArgs(0),
ctx := cmd.Context() r.InitClient(client),
client, err := CreateClient(cmd) ),
if err != nil { Handler: func(inv *clibase.Invocation) error {
return err recvWorkspaceConfigs := sshPrepareWorkspaceConfigs(inv.Context(), client)
}
recvWorkspaceConfigs := sshPrepareWorkspaceConfigs(ctx, client) out := inv.Stdout
out := cmd.OutOrStdout()
if dryRun { if dryRun {
// Print everything except diff to stderr so // Print everything except diff to stderr so
// that it's possible to capture the diff. // that it's possible to capture the diff.
out = cmd.OutOrStderr() out = inv.Stderr
} }
coderBinary, err := currentBinPath(out) coderBinary, err := currentBinPath(out)
if err != nil { if err != nil {
@ -218,7 +215,7 @@ func configSSH() *cobra.Command {
return xerrors.Errorf("escape coder binary for ssh failed: %w", err) return xerrors.Errorf("escape coder binary for ssh failed: %w", err)
} }
root := createConfig(cmd) root := r.createConfig()
escapedGlobalConfig, err := sshConfigExecEscape(string(root)) escapedGlobalConfig, err := sshConfigExecEscape(string(root))
if err != nil { if err != nil {
return xerrors.Errorf("escape global config for ssh failed: %w", err) return xerrors.Errorf("escape global config for ssh failed: %w", err)
@ -278,7 +275,7 @@ func configSSH() *cobra.Command {
oldOptsMsg = fmt.Sprintf("\n\n Previous options:\n * %s", strings.Join(oldOpts, "\n * ")) oldOptsMsg = fmt.Sprintf("\n\n Previous options:\n * %s", strings.Join(oldOpts, "\n * "))
} }
line, err := cliui.Prompt(cmd, cliui.PromptOptions{ line, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("New options differ from previous options:%s%s\n\n Use new options?", newOptsMsg, oldOptsMsg), Text: fmt.Sprintf("New options differ from previous options:%s%s\n\n Use new options?", newOptsMsg, oldOptsMsg),
IsConfirm: true, IsConfirm: true,
}) })
@ -292,7 +289,7 @@ func configSSH() *cobra.Command {
changes = append(changes, "Use new SSH options") changes = append(changes, "Use new SSH options")
} }
// Only print when prompts are shown. // Only print when prompts are shown.
if yes, _ := cmd.Flags().GetBool("yes"); !yes { if yes, _ := inv.ParsedFlags().GetBool("yes"); !yes {
_, _ = fmt.Fprint(out, "\n") _, _ = fmt.Fprint(out, "\n")
} }
} }
@ -317,7 +314,7 @@ func configSSH() *cobra.Command {
return xerrors.Errorf("fetch workspace configs failed: %w", err) return xerrors.Errorf("fetch workspace configs failed: %w", err)
} }
coderdConfig, err := client.SSHConfiguration(ctx) coderdConfig, err := client.SSHConfiguration(inv.Context())
if err != nil { if err != nil {
// If the error is 404, this deployment does not support // If the error is 404, this deployment does not support
// this endpoint yet. Do not error, just assume defaults. // this endpoint yet. Do not error, just assume defaults.
@ -417,21 +414,21 @@ func configSSH() *cobra.Command {
if dryRun { if dryRun {
_, _ = fmt.Fprintf(out, "Dry run, the following changes would be made to your SSH configuration:\n\n * %s\n\n", strings.Join(changes, "\n * ")) _, _ = fmt.Fprintf(out, "Dry run, the following changes would be made to your SSH configuration:\n\n * %s\n\n", strings.Join(changes, "\n * "))
color := isTTYOut(cmd) color := isTTYOut(inv)
diff, err := diffBytes(sshConfigFile, configRaw, configModified, color) diff, err := diffBytes(sshConfigFile, configRaw, configModified, color)
if err != nil { if err != nil {
return xerrors.Errorf("diff failed: %w", err) return xerrors.Errorf("diff failed: %w", err)
} }
if len(diff) > 0 { if len(diff) > 0 {
// Write diff to stdout. // Write diff to stdout.
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s", diff) _, _ = fmt.Fprintf(inv.Stdout, "%s", diff)
} }
return nil return nil
} }
if len(changes) > 0 { if len(changes) > 0 {
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("The following changes will be made to your SSH configuration:\n\n * %s\n\n Continue?", strings.Join(changes, "\n * ")), Text: fmt.Sprintf("The following changes will be made to your SSH configuration:\n\n * %s\n\n Continue?", strings.Join(changes, "\n * ")),
IsConfirm: true, IsConfirm: true,
}) })
@ -439,7 +436,7 @@ func configSSH() *cobra.Command {
return nil return nil
} }
// Only print when prompts are shown. // Only print when prompts are shown.
if yes, _ := cmd.Flags().GetBool("yes"); !yes { if yes, _ := inv.ParsedFlags().GetBool("yes"); !yes {
_, _ = fmt.Fprint(out, "\n") _, _ = fmt.Fprint(out, "\n")
} }
} }
@ -449,6 +446,7 @@ func configSSH() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("write ssh config failed: %w", err) return xerrors.Errorf("write ssh config failed: %w", err)
} }
_, _ = fmt.Fprintf(out, "Updated %q\n", sshConfigFile)
} }
if len(workspaceConfigs) > 0 { if len(workspaceConfigs) > 0 {
@ -460,14 +458,50 @@ func configSSH() *cobra.Command {
return nil return nil
}, },
} }
cliflag.StringVarP(cmd.Flags(), &sshConfigFile, "ssh-config-file", "", "CODER_SSH_CONFIG_FILE", sshDefaultConfigFileName, "Specifies the path to an SSH config.")
cmd.Flags().StringArrayVarP(&sshConfigOpts.sshOptions, "ssh-option", "o", []string{}, "Specifies additional SSH options to embed in each host stanza.") cmd.Options = clibase.OptionSet{
cmd.Flags().BoolVarP(&dryRun, "dry-run", "n", false, "Perform a trial run with no changes made, showing a diff at the end.") {
cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.") Flag: "ssh-config-file",
_ = cmd.Flags().MarkHidden("skip-proxy-command") Env: "CODER_SSH_CONFIG_FILE",
cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.") Default: sshDefaultConfigFileName,
cmd.Flags().StringVarP(&userHostPrefix, "ssh-host-prefix", "", "", "Override the default host prefix.") Description: "Specifies the path to an SSH config.",
cliui.AllowSkipPrompt(cmd) Value: clibase.StringOf(&sshConfigFile),
},
{
Flag: "ssh-option",
FlagShorthand: "o",
Env: "CODER_SSH_CONFIG_OPTS",
Description: "Specifies additional SSH options to embed in each host stanza.",
Value: clibase.StringArrayOf(&sshConfigOpts.sshOptions),
},
{
Flag: "dry-run",
FlagShorthand: "n",
Env: "CODER_SSH_DRY_RUN",
Description: "Perform a trial run with no changes made, showing a diff at the end.",
Value: clibase.BoolOf(&dryRun),
},
{
Flag: "skip-proxy-command",
Env: "CODER_SSH_SKIP_PROXY_COMMAND",
Description: "Specifies whether the ProxyCommand option should be skipped. Useful for testing.",
Value: clibase.BoolOf(&skipProxyCommand),
Hidden: true,
},
{
Flag: "use-previous-options",
Env: "CODER_SSH_USE_PREVIOUS_OPTIONS",
Description: "Specifies whether or not to keep options from previous run of config-ssh.",
Value: clibase.BoolOf(&usePreviousOpts),
},
{
Flag: "ssh-host-prefix",
Env: "",
Description: "Override the default host prefix.",
Value: clibase.StringOf(&userHostPrefix),
},
cliui.SkipPromptOption(),
}
return cmd return cmd
} }

View File

@ -149,21 +149,17 @@ func TestConfigSSH(t *testing.T) {
tcpAddr, valid := listener.Addr().(*net.TCPAddr) tcpAddr, valid := listener.Addr().(*net.TCPAddr)
require.True(t, valid) require.True(t, valid)
cmd, root := clitest.New(t, "config-ssh", inv, root := clitest.New(t, "config-ssh",
"--ssh-option", "HostName "+tcpAddr.IP.String(), "--ssh-option", "HostName "+tcpAddr.IP.String(),
"--ssh-option", "Port "+strconv.Itoa(tcpAddr.Port), "--ssh-option", "Port "+strconv.Itoa(tcpAddr.Port),
"--ssh-config-file", sshConfigFile, "--ssh-config-file", sshConfigFile,
"--skip-proxy-command") "--skip-proxy-command")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
go func() {
defer close(doneChan) waiter := clitest.StartWithWaiter(t, inv)
err := cmd.Execute()
assert.NoError(t, err)
}()
matches := []struct { matches := []struct {
match, write string match, write string
@ -175,7 +171,7 @@ func TestConfigSSH(t *testing.T) {
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
<-doneChan waiter.RequireSuccess()
fileContents, err := os.ReadFile(sshConfigFile) fileContents, err := os.ReadFile(sshConfigFile)
require.NoError(t, err, "read ssh config file") require.NoError(t, err, "read ssh config file")
@ -187,7 +183,7 @@ func TestConfigSSH(t *testing.T) {
pty = ptytest.New(t) pty = ptytest.New(t)
// Set HOME because coder config is included from ~/.ssh/coder. // Set HOME because coder config is included from ~/.ssh/coder.
sshCmd.Env = append(sshCmd.Env, fmt.Sprintf("HOME=%s", home)) sshCmd.Env = append(sshCmd.Env, fmt.Sprintf("HOME=%s", home))
sshCmd.Stderr = pty.Output() inv.Stderr = pty.Output()
data, err := sshCmd.Output() data, err := sshCmd.Output()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "test", strings.TrimSpace(string(data))) require.Equal(t, "test", strings.TrimSpace(string(data)))
@ -586,14 +582,14 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) {
"--ssh-config-file", sshConfigName, "--ssh-config-file", sshConfigName,
} }
args = append(args, tt.args...) args = append(args, tt.args...)
cmd, root := clitest.New(t, args...) inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
done := tGo(t, func() { done := tGo(t, func() {
err := cmd.Execute() err := inv.Run()
if !tt.wantErr { if !tt.wantErr {
assert.NoError(t, err) assert.NoError(t, err)
} else { } else {
@ -703,17 +699,13 @@ func TestConfigSSH_Hostnames(t *testing.T) {
sshConfigFile := sshConfigFileName(t) sshConfigFile := sshConfigFileName(t)
cmd, root := clitest.New(t, "config-ssh", "--ssh-config-file", sshConfigFile) inv, root := clitest.New(t, "config-ssh", "--ssh-config-file", sshConfigFile)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
go func() { clitest.Start(t, inv)
defer close(doneChan)
err := cmd.Execute()
assert.NoError(t, err)
}()
matches := []struct { matches := []struct {
match, write string match, write string
@ -725,7 +717,7 @@ func TestConfigSSH_Hostnames(t *testing.T) {
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
<-doneChan pty.ExpectMatch("Updated")
var expectedHosts []string var expectedHosts []string
for _, hostnamePattern := range tt.expected { for _, hostnamePattern := range tt.expected {

View File

@ -6,17 +6,16 @@ import (
"io" "io"
"time" "time"
"github.com/spf13/cobra"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func create() *cobra.Command { func (r *RootCmd) create() *clibase.Cmd {
var ( var (
parameterFile string parameterFile string
richParameterFile string richParameterFile string
@ -25,30 +24,27 @@ func create() *cobra.Command {
stopAfter time.Duration stopAfter time.Duration
workspaceName string workspaceName string
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "create [name]", Use: "create [name]",
Short: "Create a workspace", Short: "Create a workspace",
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.Chain(r.InitClient(client)),
client, err := CreateClient(cmd) Handler: func(inv *clibase.Invocation) error {
organization, err := CurrentOrganization(inv, client)
if err != nil { if err != nil {
return err return err
} }
organization, err := CurrentOrganization(cmd, client) if len(inv.Args) >= 1 {
if err != nil { workspaceName = inv.Args[0]
return err
}
if len(args) >= 1 {
workspaceName = args[0]
} }
if workspaceName == "" { if workspaceName == "" {
workspaceName, err = cliui.Prompt(cmd, cliui.PromptOptions{ workspaceName, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Specify a name for your workspace:", Text: "Specify a name for your workspace:",
Validate: func(workspaceName string) error { Validate: func(workspaceName string) error {
_, err = client.WorkspaceByOwnerAndName(cmd.Context(), codersdk.Me, workspaceName, codersdk.WorkspaceOptions{}) _, err = client.WorkspaceByOwnerAndName(inv.Context(), codersdk.Me, workspaceName, codersdk.WorkspaceOptions{})
if err == nil { if err == nil {
return xerrors.Errorf("A workspace already exists named %q!", workspaceName) return xerrors.Errorf("A workspace already exists named %q!", workspaceName)
} }
@ -60,16 +56,16 @@ func create() *cobra.Command {
} }
} }
_, err = client.WorkspaceByOwnerAndName(cmd.Context(), codersdk.Me, workspaceName, codersdk.WorkspaceOptions{}) _, err = client.WorkspaceByOwnerAndName(inv.Context(), codersdk.Me, workspaceName, codersdk.WorkspaceOptions{})
if err == nil { if err == nil {
return xerrors.Errorf("A workspace already exists named %q!", workspaceName) return xerrors.Errorf("A workspace already exists named %q!", workspaceName)
} }
var template codersdk.Template var template codersdk.Template
if templateName == "" { if templateName == "" {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Wrap.Render("Select a template below to preview the provisioned infrastructure:")) _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Wrap.Render("Select a template below to preview the provisioned infrastructure:"))
templates, err := client.TemplatesByOrganization(cmd.Context(), organization.ID) templates, err := client.TemplatesByOrganization(inv.Context(), organization.ID)
if err != nil { if err != nil {
return err return err
} }
@ -98,7 +94,7 @@ func create() *cobra.Command {
} }
// Move the cursor up a single line for nicer display! // Move the cursor up a single line for nicer display!
option, err := cliui.Select(cmd, cliui.SelectOptions{ option, err := cliui.Select(inv, cliui.SelectOptions{
Options: templateNames, Options: templateNames,
HideSearch: true, HideSearch: true,
}) })
@ -108,7 +104,7 @@ func create() *cobra.Command {
template = templateByName[option] template = templateByName[option]
} else { } else {
template, err = client.TemplateByName(cmd.Context(), organization.ID, templateName) template, err = client.TemplateByName(inv.Context(), organization.ID, templateName)
if err != nil { if err != nil {
return xerrors.Errorf("get template by name: %w", err) return xerrors.Errorf("get template by name: %w", err)
} }
@ -123,7 +119,7 @@ func create() *cobra.Command {
schedSpec = ptr.Ref(sched.String()) schedSpec = ptr.Ref(sched.String())
} }
buildParams, err := prepWorkspaceBuild(cmd, client, prepWorkspaceBuildArgs{ buildParams, err := prepWorkspaceBuild(inv, client, prepWorkspaceBuildArgs{
Template: template, Template: template,
ExistingParams: []codersdk.Parameter{}, ExistingParams: []codersdk.Parameter{},
ParameterFile: parameterFile, ParameterFile: parameterFile,
@ -131,10 +127,10 @@ func create() *cobra.Command {
NewWorkspaceName: workspaceName, NewWorkspaceName: workspaceName,
}) })
if err != nil { if err != nil {
return err return xerrors.Errorf("prepare build: %w", err)
} }
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm create?", Text: "Confirm create?",
IsConfirm: true, IsConfirm: true,
}) })
@ -149,7 +145,7 @@ func create() *cobra.Command {
ttlMillis = &template.MaxTTLMillis ttlMillis = &template.MaxTTLMillis
} }
workspace, err := client.CreateWorkspace(cmd.Context(), organization.ID, codersdk.Me, codersdk.CreateWorkspaceRequest{ workspace, err := client.CreateWorkspace(inv.Context(), organization.ID, codersdk.Me, codersdk.CreateWorkspaceRequest{
TemplateID: template.ID, TemplateID: template.ID,
Name: workspaceName, Name: workspaceName,
AutostartSchedule: schedSpec, AutostartSchedule: schedSpec,
@ -158,25 +154,53 @@ func create() *cobra.Command {
RichParameterValues: buildParams.richParameters, RichParameterValues: buildParams.richParameters,
}) })
if err != nil { if err != nil {
return err return xerrors.Errorf("create workspace: %w", err)
} }
err = cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStdout(), client, workspace.LatestBuild.ID) err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID)
if err != nil { if err != nil {
return err return xerrors.Errorf("watch build: %w", err)
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThe %s workspace has been created at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) _, _ = fmt.Fprintf(inv.Stdout, "\nThe %s workspace has been created at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp)))
return nil return nil
}, },
} }
cmd.Options = append(cmd.Options,
clibase.Option{
Flag: "template",
FlagShorthand: "t",
Env: "CODER_TEMPLATE_NAME",
Description: "Specify a template name.",
Value: clibase.StringOf(&templateName),
},
clibase.Option{
Flag: "parameter-file",
Env: "CODER_PARAMETER_FILE",
Description: "Specify a file path with parameter values.",
Value: clibase.StringOf(&parameterFile),
},
clibase.Option{
Flag: "rich-parameter-file",
Env: "CODER_RICH_PARAMETER_FILE",
Description: "Specify a file path with values for rich parameters defined in the template.",
Value: clibase.StringOf(&richParameterFile),
},
clibase.Option{
Flag: "start-at",
Env: "CODER_WORKSPACE_START_AT",
Description: "Specify the workspace autostart schedule. Check coder schedule start --help for the syntax.",
Value: clibase.StringOf(&startAt),
},
clibase.Option{
Flag: "stop-after",
Env: "CODER_WORKSPACE_STOP_AFTER",
Description: "Specify a duration after which the workspace should shut down (e.g. 8h).",
Value: clibase.DurationOf(&stopAfter),
},
cliui.SkipPromptOption(),
)
cliui.AllowSkipPrompt(cmd)
cliflag.StringVarP(cmd.Flags(), &templateName, "template", "t", "CODER_TEMPLATE_NAME", "", "Specify a template name.")
cliflag.StringVarP(cmd.Flags(), &parameterFile, "parameter-file", "", "CODER_PARAMETER_FILE", "", "Specify a file path with parameter values.")
cliflag.StringVarP(cmd.Flags(), &richParameterFile, "rich-parameter-file", "", "CODER_RICH_PARAMETER_FILE", "", "Specify a file path with values for rich parameters defined in the template.")
cliflag.StringVarP(cmd.Flags(), &startAt, "start-at", "", "CODER_WORKSPACE_START_AT", "", "Specify the workspace autostart schedule. Check `coder schedule start --help` for the syntax.")
cliflag.DurationVarP(cmd.Flags(), &stopAfter, "stop-after", "", "CODER_WORKSPACE_STOP_AFTER", 0, "Specify a duration after which the workspace should shut down (e.g. 8h).")
return cmd return cmd
} }
@ -200,8 +224,8 @@ type buildParameters struct {
// prepWorkspaceBuild will ensure a workspace build will succeed on the latest template version. // prepWorkspaceBuild will ensure a workspace build will succeed on the latest template version.
// Any missing params will be prompted to the user. It supports legacy and rich parameters. // Any missing params will be prompted to the user. It supports legacy and rich parameters.
func prepWorkspaceBuild(cmd *cobra.Command, client *codersdk.Client, args prepWorkspaceBuildArgs) (*buildParameters, error) { func prepWorkspaceBuild(inv *clibase.Invocation, client *codersdk.Client, args prepWorkspaceBuildArgs) (*buildParameters, error) {
ctx := cmd.Context() ctx := inv.Context()
var useRichParameters bool var useRichParameters bool
if len(args.ExistingRichParams) > 0 && len(args.RichParameterFile) > 0 { if len(args.ExistingRichParams) > 0 && len(args.RichParameterFile) > 0 {
@ -233,7 +257,7 @@ func prepWorkspaceBuild(cmd *cobra.Command, client *codersdk.Client, args prepWo
useParamFile := false useParamFile := false
if args.ParameterFile != "" { if args.ParameterFile != "" {
useParamFile = true useParamFile = true
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n") _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n")
parameterMapFromFile, err = createParameterMapFromFile(args.ParameterFile) parameterMapFromFile, err = createParameterMapFromFile(args.ParameterFile)
if err != nil { if err != nil {
return nil, err return nil, err
@ -247,7 +271,7 @@ PromptParamLoop:
continue continue
} }
if !disclaimerPrinted { if !disclaimerPrinted {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("This template has customizable parameters. Values can be changed after create, but may have unintended side effects (like data loss).")+"\r\n") _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("This template has customizable parameters. Values can be changed after create, but may have unintended side effects (like data loss).")+"\r\n")
disclaimerPrinted = true disclaimerPrinted = true
} }
@ -262,7 +286,7 @@ PromptParamLoop:
} }
} }
parameterValue, err := getParameterValueFromMapOrInput(cmd, parameterMapFromFile, parameterSchema) parameterValue, err := getParameterValueFromMapOrInput(inv, parameterMapFromFile, parameterSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -276,11 +300,11 @@ PromptParamLoop:
} }
if disclaimerPrinted { if disclaimerPrinted {
_, _ = fmt.Fprintln(cmd.OutOrStdout()) _, _ = fmt.Fprintln(inv.Stdout)
} }
// Rich parameters // Rich parameters
templateVersionParameters, err := client.TemplateVersionRichParameters(cmd.Context(), templateVersion.ID) templateVersionParameters, err := client.TemplateVersionRichParameters(inv.Context(), templateVersion.ID)
if err != nil { if err != nil {
return nil, xerrors.Errorf("get template version rich parameters: %w", err) return nil, xerrors.Errorf("get template version rich parameters: %w", err)
} }
@ -289,7 +313,7 @@ PromptParamLoop:
useParamFile = false useParamFile = false
if args.RichParameterFile != "" { if args.RichParameterFile != "" {
useParamFile = true useParamFile = true
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Attempting to read the variables from the rich parameter file.")+"\r\n") _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("Attempting to read the variables from the rich parameter file.")+"\r\n")
parameterMapFromFile, err = createParameterMapFromFile(args.RichParameterFile) parameterMapFromFile, err = createParameterMapFromFile(args.RichParameterFile)
if err != nil { if err != nil {
return nil, err return nil, err
@ -300,7 +324,7 @@ PromptParamLoop:
PromptRichParamLoop: PromptRichParamLoop:
for _, templateVersionParameter := range templateVersionParameters { for _, templateVersionParameter := range templateVersionParameters {
if !disclaimerPrinted { if !disclaimerPrinted {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("This template has customizable parameters. Values can be changed after create, but may have unintended side effects (like data loss).")+"\r\n") _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("This template has customizable parameters. Values can be changed after create, but may have unintended side effects (like data loss).")+"\r\n")
disclaimerPrinted = true disclaimerPrinted = true
} }
@ -316,11 +340,11 @@ PromptRichParamLoop:
} }
if args.UpdateWorkspace && !templateVersionParameter.Mutable { if args.UpdateWorkspace && !templateVersionParameter.Mutable {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Warn.Render(fmt.Sprintf(`Parameter %q is not mutable, so can't be customized after workspace creation.`, templateVersionParameter.Name))) _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Warn.Render(fmt.Sprintf(`Parameter %q is not mutable, so can't be customized after workspace creation.`, templateVersionParameter.Name)))
continue continue
} }
parameterValue, err := getWorkspaceBuildParameterValueFromMapOrInput(cmd, parameterMapFromFile, templateVersionParameter) parameterValue, err := getWorkspaceBuildParameterValueFromMapOrInput(inv, parameterMapFromFile, templateVersionParameter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -329,10 +353,10 @@ PromptRichParamLoop:
} }
if disclaimerPrinted { if disclaimerPrinted {
_, _ = fmt.Fprintln(cmd.OutOrStdout()) _, _ = fmt.Fprintln(inv.Stdout)
} }
err = cliui.GitAuth(ctx, cmd.OutOrStdout(), cliui.GitAuthOptions{ err = cliui.GitAuth(ctx, inv.Stdout, cliui.GitAuthOptions{
Fetch: func(ctx context.Context) ([]codersdk.TemplateVersionGitAuth, error) { Fetch: func(ctx context.Context) ([]codersdk.TemplateVersionGitAuth, error) {
return client.TemplateVersionGitAuth(ctx, templateVersion.ID) return client.TemplateVersionGitAuth(ctx, templateVersion.ID)
}, },
@ -342,7 +366,7 @@ PromptRichParamLoop:
} }
// Run a dry-run with the given parameters to check correctness // Run a dry-run with the given parameters to check correctness
dryRun, err := client.CreateTemplateVersionDryRun(cmd.Context(), templateVersion.ID, codersdk.CreateTemplateVersionDryRunRequest{ dryRun, err := client.CreateTemplateVersionDryRun(inv.Context(), templateVersion.ID, codersdk.CreateTemplateVersionDryRunRequest{
WorkspaceName: args.NewWorkspaceName, WorkspaceName: args.NewWorkspaceName,
ParameterValues: legacyParameters, ParameterValues: legacyParameters,
RichParameterValues: richParameters, RichParameterValues: richParameters,
@ -350,16 +374,16 @@ PromptRichParamLoop:
if err != nil { if err != nil {
return nil, xerrors.Errorf("begin workspace dry-run: %w", err) return nil, xerrors.Errorf("begin workspace dry-run: %w", err)
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Planning workspace...") _, _ = fmt.Fprintln(inv.Stdout, "Planning workspace...")
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
return client.TemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) return client.TemplateVersionDryRun(inv.Context(), templateVersion.ID, dryRun.ID)
}, },
Cancel: func() error { Cancel: func() error {
return client.CancelTemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) return client.CancelTemplateVersionDryRun(inv.Context(), templateVersion.ID, dryRun.ID)
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
return client.TemplateVersionDryRunLogsAfter(cmd.Context(), templateVersion.ID, dryRun.ID, 0) return client.TemplateVersionDryRunLogsAfter(inv.Context(), templateVersion.ID, dryRun.ID, 0)
}, },
// Don't show log output for the dry-run unless there's an error. // Don't show log output for the dry-run unless there's an error.
Silent: true, Silent: true,
@ -370,19 +394,19 @@ PromptRichParamLoop:
return nil, xerrors.Errorf("dry-run workspace: %w", err) return nil, xerrors.Errorf("dry-run workspace: %w", err)
} }
resources, err := client.TemplateVersionDryRunResources(cmd.Context(), templateVersion.ID, dryRun.ID) resources, err := client.TemplateVersionDryRunResources(inv.Context(), templateVersion.ID, dryRun.ID)
if err != nil { if err != nil {
return nil, xerrors.Errorf("get workspace dry-run resources: %w", err) return nil, xerrors.Errorf("get workspace dry-run resources: %w", err)
} }
err = cliui.WorkspaceResources(cmd.OutOrStdout(), resources, cliui.WorkspaceResourcesOptions{ err = cliui.WorkspaceResources(inv.Stdout, resources, cliui.WorkspaceResourcesOptions{
WorkspaceName: args.NewWorkspaceName, WorkspaceName: args.NewWorkspaceName,
// Since agents haven't connected yet, hiding this makes more sense. // Since agents haven't connected yet, hiding this makes more sense.
HideAgentState: true, HideAgentState: true,
Title: "Workspace Preview", Title: "Workspace Preview",
}) })
if err != nil { if err != nil {
return nil, err return nil, xerrors.Errorf("get resources: %w", err)
} }
return &buildParameters{ return &buildParameters{

View File

@ -42,15 +42,13 @@ func TestCreate(t *testing.T) {
"--start-at", "9:30AM Mon-Fri US/Central", "--start-at", "9:30AM Mon-Fri US/Central",
"--stop-after", "8h", "--stop-after", "8h",
} }
cmd, root := clitest.New(t, args...) inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
matches := []struct { matches := []struct {
@ -100,17 +98,10 @@ func TestCreate(t *testing.T) {
"my-workspace", "my-workspace",
"--template", template.Name, "--template", template.Name,
} }
cmd, root := clitest.New(t, args...) inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) pty := ptytest.New(t).Attach(inv)
pty := ptytest.New(t) waiter := clitest.StartWithWaiter(t, inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := cmd.Execute()
assert.NoError(t, err)
}()
matches := []struct { matches := []struct {
match string match string
write string write string
@ -125,7 +116,7 @@ func TestCreate(t *testing.T) {
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
} }
<-doneChan waiter.RequireSuccess()
ws, err := client.WorkspaceByOwnerAndName(context.Background(), "testuser", "my-workspace", codersdk.WorkspaceOptions{}) ws, err := client.WorkspaceByOwnerAndName(context.Background(), "testuser", "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err, "expected workspace to be created") require.NoError(t, err, "expected workspace to be created")
@ -140,14 +131,14 @@ func TestCreate(t *testing.T) {
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID) coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
_ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "my-workspace", "-y") inv, root := clitest.New(t, "create", "my-workspace", "-y")
member, _ := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) member, _ := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
clitest.SetupConfig(t, member, root) clitest.SetupConfig(t, member, root)
cmdCtx, done := context.WithTimeout(context.Background(), testutil.WaitLong) cmdCtx, done := context.WithTimeout(context.Background(), testutil.WaitLong)
go func() { go func() {
defer done() defer done()
err := cmd.ExecuteContext(cmdCtx) err := inv.WithContext(cmdCtx).Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
// No pty interaction needed since we use the -y skip prompt flag // No pty interaction needed since we use the -y skip prompt flag
@ -162,15 +153,13 @@ func TestCreate(t *testing.T) {
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID) coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "") inv, root := clitest.New(t, "create", "")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
matches := []string{ matches := []string{
@ -185,7 +174,7 @@ func TestCreate(t *testing.T) {
} }
<-doneChan <-doneChan
ws, err := client.WorkspaceByOwnerAndName(cmd.Context(), "testuser", "my-workspace", codersdk.WorkspaceOptions{}) ws, err := client.WorkspaceByOwnerAndName(inv.Context(), "testuser", "my-workspace", codersdk.WorkspaceOptions{})
if assert.NoError(t, err, "expected workspace to be created") { if assert.NoError(t, err, "expected workspace to be created") {
assert.Equal(t, ws.TemplateName, template.Name) assert.Equal(t, ws.TemplateName, template.Name)
assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil") assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil")
@ -206,15 +195,13 @@ func TestCreate(t *testing.T) {
coderdtest.AwaitTemplateVersionJob(t, client, version.ID) coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
_ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "") inv, root := clitest.New(t, "create", "")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -251,15 +238,13 @@ func TestCreate(t *testing.T) {
removeTmpDirUntilSuccessAfterTest(t, tempDir) removeTmpDirUntilSuccessAfterTest(t, tempDir)
parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml") parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml")
_, _ = parameterFile.WriteString("region: \"bingo\"\nusername: \"boingo\"") _, _ = parameterFile.WriteString("region: \"bingo\"\nusername: \"boingo\"")
cmd, root := clitest.New(t, "create", "", "--parameter-file", parameterFile.Name()) inv, root := clitest.New(t, "create", "", "--parameter-file", parameterFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -296,15 +281,13 @@ func TestCreate(t *testing.T) {
parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml") parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml")
_, _ = parameterFile.WriteString("username: \"boingo\"") _, _ = parameterFile.WriteString("username: \"boingo\"")
cmd, root := clitest.New(t, "create", "", "--parameter-file", parameterFile.Name()) inv, root := clitest.New(t, "create", "", "--parameter-file", parameterFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
matches := []struct { matches := []struct {
@ -364,13 +347,11 @@ func TestCreate(t *testing.T) {
require.Equal(t, codersdk.ProvisionerJobSucceeded, version.Job.Status, "job is not failed") require.Equal(t, codersdk.ProvisionerJobSucceeded, version.Job.Status, "job is not failed")
_ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "test", "--parameter-file", parameterFile.Name()) inv, root := clitest.New(t, "create", "test", "--parameter-file", parameterFile.Name(), "-y")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
err = cmd.Execute() err = inv.Run()
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "dry-run workspace") require.ErrorContains(t, err, "dry-run workspace")
}) })
@ -425,15 +406,13 @@ func TestCreateWithRichParameters(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -469,16 +448,14 @@ func TestCreateWithRichParameters(t *testing.T) {
firstParameterName + ": " + firstParameterValue + "\n" + firstParameterName + ": " + firstParameterValue + "\n" +
secondParameterName + ": " + secondParameterValue + "\n" + secondParameterName + ": " + secondParameterValue + "\n" +
immutableParameterName + ": " + immutableParameterValue) immutableParameterName + ": " + immutableParameterValue)
cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name()) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -559,15 +536,13 @@ func TestCreateValidateRichParameters(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -596,15 +571,13 @@ func TestCreateValidateRichParameters(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -636,15 +609,13 @@ func TestCreateValidateRichParameters(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -672,17 +643,10 @@ func TestCreateValidateRichParameters(t *testing.T) {
coderdtest.AwaitTemplateVersionJob(t, client, version.ID) coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) pty := ptytest.New(t).Attach(inv)
pty := ptytest.New(t) clitest.Start(t, inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := cmd.Execute()
assert.NoError(t, err)
}()
matches := []string{ matches := []string{
listOfStringsParameterName, "", listOfStringsParameterName, "",
@ -697,7 +661,6 @@ func TestCreateValidateRichParameters(t *testing.T) {
pty.WriteLine(value) pty.WriteLine(value)
} }
} }
<-doneChan
}) })
t.Run("ValidateListOfStrings_YAMLFile", func(t *testing.T) { t.Run("ValidateListOfStrings_YAMLFile", func(t *testing.T) {
@ -716,17 +679,11 @@ func TestCreateValidateRichParameters(t *testing.T) {
- ddd - ddd
- eee - eee
- fff`) - fff`)
cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name()) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) pty := ptytest.New(t).Attach(inv)
pty := ptytest.New(t)
cmd.SetIn(pty.Input()) clitest.Start(t, inv)
cmd.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := cmd.Execute()
assert.NoError(t, err)
}()
matches := []string{ matches := []string{
"Confirm create?", "yes", "Confirm create?", "yes",
@ -739,7 +696,6 @@ func TestCreateValidateRichParameters(t *testing.T) {
pty.WriteLine(value) pty.WriteLine(value)
} }
} }
<-doneChan
}) })
} }
@ -777,17 +733,10 @@ func TestCreateWithGitAuth(t *testing.T) {
coderdtest.AwaitTemplateVersionJob(t, client, version.ID) coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) pty := ptytest.New(t).Attach(inv)
pty := ptytest.New(t) clitest.Start(t, inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := cmd.Execute()
assert.NoError(t, err)
}()
pty.ExpectMatch("You must authenticate with GitHub to create a workspace") pty.ExpectMatch("You must authenticate with GitHub to create a workspace")
resp := coderdtest.RequestGitAuthCallback(t, "github", client) resp := coderdtest.RequestGitAuthCallback(t, "github", client)
@ -795,7 +744,6 @@ func TestCreateWithGitAuth(t *testing.T) {
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
pty.ExpectMatch("Confirm create?") pty.ExpectMatch("Confirm create?")
pty.WriteLine("yes") pty.WriteLine("yes")
<-doneChan
} }
func createTestParseResponseWithDefault(defaultValue string) []*proto.Parse_Response { func createTestParseResponseWithDefault(defaultValue string) []*proto.Parse_Response {

View File

@ -4,23 +4,25 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
// nolint // nolint
func deleteWorkspace() *cobra.Command { func (r *RootCmd) deleteWorkspace() *clibase.Cmd {
var orphan bool var orphan bool
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "delete <workspace>", Use: "delete <workspace>",
Short: "Delete a workspace", Short: "Delete a workspace",
Aliases: []string{"rm"}, Middleware: clibase.Chain(
Args: cobra.ExactArgs(1), clibase.RequireNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { r.InitClient(client),
_, err := cliui.Prompt(cmd, cliui.PromptOptions{ ),
Handler: func(inv *clibase.Invocation) error {
_, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm delete workspace?", Text: "Confirm delete workspace?",
IsConfirm: true, IsConfirm: true,
Default: cliui.ConfirmNo, Default: cliui.ConfirmNo,
@ -29,11 +31,7 @@ func deleteWorkspace() *cobra.Command {
return err return err
} }
client, err := CreateClient(cmd) workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil {
return err
}
workspace, err := namedWorkspace(cmd, client, args[0])
if err != nil { if err != nil {
return err return err
} }
@ -42,12 +40,12 @@ func deleteWorkspace() *cobra.Command {
if orphan { if orphan {
cliui.Warn( cliui.Warn(
cmd.ErrOrStderr(), inv.Stderr,
"Orphaning workspace requires template edit permission", "Orphaning workspace requires template edit permission",
) )
} }
build, err := client.CreateWorkspaceBuild(cmd.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ build, err := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionDelete, Transition: codersdk.WorkspaceTransitionDelete,
ProvisionerState: state, ProvisionerState: state,
Orphan: orphan, Orphan: orphan,
@ -56,19 +54,23 @@ func deleteWorkspace() *cobra.Command {
return err return err
} }
err = cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStdout(), client, build.ID) err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, build.ID)
if err != nil { if err != nil {
return err return err
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThe %s workspace has been deleted at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) _, _ = fmt.Fprintf(inv.Stdout, "\nThe %s workspace has been deleted at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp)))
return nil return nil
}, },
} }
cmd.Flags().BoolVar(&orphan, "orphan", false, cmd.Options = clibase.OptionSet{
`Delete a workspace without deleting its resources. This can delete a {
workspace in a broken state, but may also lead to unaccounted cloud resources.`, Flag: "orphan",
) Description: "Delete a workspace without deleting its resources. This can delete a workspace in a broken state, but may also lead to unaccounted cloud resources.",
cliui.AllowSkipPrompt(cmd)
Value: clibase.BoolOf(&orphan),
},
cliui.SkipPromptOption(),
}
return cmd return cmd
} }

View File

@ -25,15 +25,13 @@ func TestDelete(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, root := clitest.New(t, "delete", workspace.Name, "-y") inv, root := clitest.New(t, "delete", workspace.Name, "-y")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
// When running with the race detector on, we sometimes get an EOF. // When running with the race detector on, we sometimes get an EOF.
if err != nil { if err != nil {
assert.ErrorIs(t, err, io.EOF) assert.ErrorIs(t, err, io.EOF)
@ -52,17 +50,15 @@ func TestDelete(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, root := clitest.New(t, "delete", workspace.Name, "-y", "--orphan") inv, root := clitest.New(t, "delete", workspace.Name, "-y", "--orphan")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input()) inv.Stderr = pty.Output()
cmd.SetOut(pty.Output())
cmd.SetErr(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
// When running with the race detector on, we sometimes get an EOF. // When running with the race detector on, we sometimes get an EOF.
if err != nil { if err != nil {
assert.ErrorIs(t, err, io.EOF) assert.ErrorIs(t, err, io.EOF)
@ -87,15 +83,13 @@ func TestDelete(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, root := clitest.New(t, "delete", user.Username+"/"+workspace.Name, "-y") inv, root := clitest.New(t, "delete", user.Username+"/"+workspace.Name, "-y")
clitest.SetupConfig(t, adminClient, root) clitest.SetupConfig(t, adminClient, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
// When running with the race detector on, we sometimes get an EOF. // When running with the race detector on, we sometimes get an EOF.
if err != nil { if err != nil {
assert.ErrorIs(t, err, io.EOF) assert.ErrorIs(t, err, io.EOF)
@ -112,12 +106,12 @@ func TestDelete(t *testing.T) {
t.Run("InvalidWorkspaceIdentifier", func(t *testing.T) { t.Run("InvalidWorkspaceIdentifier", func(t *testing.T) {
t.Parallel() t.Parallel()
client := coderdtest.New(t, nil) client := coderdtest.New(t, nil)
cmd, root := clitest.New(t, "delete", "a/b/c", "-y") inv, root := clitest.New(t, "delete", "a/b/c", "-y")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.ErrorContains(t, err, "invalid workspace name: \"a/b/c\"") assert.ErrorContains(t, err, "invalid workspace name: \"a/b/c\"")
}() }()
<-doneChan <-doneChan

View File

@ -10,30 +10,29 @@ import (
"strings" "strings"
"time" "time"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
) )
func dotfiles() *cobra.Command { func (r *RootCmd) dotfiles() *clibase.Cmd {
var symlinkDir string var symlinkDir string
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "dotfiles [git_repo_url]", Use: "dotfiles <git_repo_url>",
Args: cobra.ExactArgs(1), Middleware: clibase.RequireNArgs(1),
Short: "Checkout and install a dotfiles repository from a Git URL", Short: "Personalize your workspace by applying a canonical dotfiles repository",
Example: formatExamples( Long: formatExamples(
example{ example{
Description: "Check out and install a dotfiles repository without prompts", Description: "Check out and install a dotfiles repository without prompts",
Command: "coder dotfiles --yes git@github.com:example/dotfiles.git", Command: "coder dotfiles --yes git@github.com:example/dotfiles.git",
}, },
), ),
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
var ( var (
dotfilesRepoDir = "dotfiles" dotfilesRepoDir = "dotfiles"
gitRepo = args[0] gitRepo = inv.Args[0]
cfg = createConfig(cmd) cfg = r.createConfig()
cfgDir = string(cfg) cfgDir = string(cfg)
dotfilesDir = filepath.Join(cfgDir, dotfilesRepoDir) dotfilesDir = filepath.Join(cfgDir, dotfilesRepoDir)
// This follows the same pattern outlined by others in the market: // This follows the same pattern outlined by others in the market:
@ -50,7 +49,11 @@ func dotfiles() *cobra.Command {
} }
) )
_, _ = fmt.Fprint(cmd.OutOrStdout(), "Checking if dotfiles repository already exists...\n") if cfg == "" {
return xerrors.Errorf("no config directory")
}
_, _ = fmt.Fprint(inv.Stdout, "Checking if dotfiles repository already exists...\n")
dotfilesExists, err := dirExists(dotfilesDir) dotfilesExists, err := dirExists(dotfilesDir)
if err != nil { if err != nil {
return xerrors.Errorf("checking dir %s: %w", dotfilesDir, err) return xerrors.Errorf("checking dir %s: %w", dotfilesDir, err)
@ -65,7 +68,7 @@ func dotfiles() *cobra.Command {
// if the git url has changed we create a backup and clone fresh // if the git url has changed we create a backup and clone fresh
if gitRepo != du { if gitRepo != du {
backupDir := fmt.Sprintf("%s_backup_%s", dotfilesDir, time.Now().Format(time.RFC3339)) backupDir := fmt.Sprintf("%s_backup_%s", dotfilesDir, time.Now().Format(time.RFC3339))
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("The dotfiles URL has changed from %q to %q.\n Coder will backup the existing repo to %s.\n\n Continue?", du, gitRepo, backupDir), Text: fmt.Sprintf("The dotfiles URL has changed from %q to %q.\n Coder will backup the existing repo to %s.\n\n Continue?", du, gitRepo, backupDir),
IsConfirm: true, IsConfirm: true,
}) })
@ -77,7 +80,7 @@ func dotfiles() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("renaming dir %s: %w", dotfilesDir, err) return xerrors.Errorf("renaming dir %s: %w", dotfilesDir, err)
} }
_, _ = fmt.Fprint(cmd.OutOrStdout(), "Done backup up dotfiles.\n") _, _ = fmt.Fprint(inv.Stdout, "Done backup up dotfiles.\n")
dotfilesExists = false dotfilesExists = false
moved = true moved = true
} }
@ -89,20 +92,20 @@ func dotfiles() *cobra.Command {
promptText string promptText string
) )
if dotfilesExists { if dotfilesExists {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Found dotfiles repository at %s\n", dotfilesDir) _, _ = fmt.Fprintf(inv.Stdout, "Found dotfiles repository at %s\n", dotfilesDir)
gitCmdDir = dotfilesDir gitCmdDir = dotfilesDir
subcommands = []string{"pull", "--ff-only"} subcommands = []string{"pull", "--ff-only"}
promptText = fmt.Sprintf("Pulling latest from %s into directory %s.\n Continue?", gitRepo, dotfilesDir) promptText = fmt.Sprintf("Pulling latest from %s into directory %s.\n Continue?", gitRepo, dotfilesDir)
} else { } else {
if !moved { if !moved {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Did not find dotfiles repository at %s\n", dotfilesDir) _, _ = fmt.Fprintf(inv.Stdout, "Did not find dotfiles repository at %s\n", dotfilesDir)
} }
gitCmdDir = cfgDir gitCmdDir = cfgDir
subcommands = []string{"clone", args[0], dotfilesRepoDir} subcommands = []string{"clone", inv.Args[0], dotfilesRepoDir}
promptText = fmt.Sprintf("Cloning %s into directory %s.\n\n Continue?", gitRepo, dotfilesDir) promptText = fmt.Sprintf("Cloning %s into directory %s.\n\n Continue?", gitRepo, dotfilesDir)
} }
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: promptText, Text: promptText,
IsConfirm: true, IsConfirm: true,
}) })
@ -113,7 +116,7 @@ func dotfiles() *cobra.Command {
// ensure command dir exists // ensure command dir exists
err = os.MkdirAll(gitCmdDir, 0o750) err = os.MkdirAll(gitCmdDir, 0o750)
if err != nil { if err != nil {
return xerrors.Errorf("ensuring dir at %s: %w", gitCmdDir, err) return xerrors.Errorf("ensuring dir at %q: %w", gitCmdDir, err)
} }
// check if git ssh command already exists so we can just wrap it // check if git ssh command already exists so we can just wrap it
@ -123,18 +126,18 @@ func dotfiles() *cobra.Command {
} }
// clone or pull repo // clone or pull repo
c := exec.CommandContext(cmd.Context(), "git", subcommands...) c := exec.CommandContext(inv.Context(), "git", subcommands...)
c.Dir = gitCmdDir c.Dir = gitCmdDir
c.Env = append(os.Environ(), fmt.Sprintf(`GIT_SSH_COMMAND=%s -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no`, gitsshCmd)) c.Env = append(inv.Environ.ToOS(), fmt.Sprintf(`GIT_SSH_COMMAND=%s -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no`, gitsshCmd))
c.Stdout = cmd.OutOrStdout() c.Stdout = inv.Stdout
c.Stderr = cmd.ErrOrStderr() c.Stderr = inv.Stderr
err = c.Run() err = c.Run()
if err != nil { if err != nil {
if !dotfilesExists { if !dotfilesExists {
return err return err
} }
// if the repo exists we soft fail the update operation and try to continue // if the repo exists we soft fail the update operation and try to continue
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Error.Render("Failed to update repo, continuing...")) _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Error.Render("Failed to update repo, continuing..."))
} }
// save git repo url so we can detect changes next time // save git repo url so we can detect changes next time
@ -158,7 +161,7 @@ func dotfiles() *cobra.Command {
script := findScript(installScriptSet, files) script := findScript(installScriptSet, files)
if script != "" { if script != "" {
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Running install script %s.\n\n Continue?", script), Text: fmt.Sprintf("Running install script %s.\n\n Continue?", script),
IsConfirm: true, IsConfirm: true,
}) })
@ -166,29 +169,29 @@ func dotfiles() *cobra.Command {
return err return err
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Running %s...\n", script) _, _ = fmt.Fprintf(inv.Stdout, "Running %s...\n", script)
// it is safe to use a variable command here because it's from // it is safe to use a variable command here because it's from
// a filtered list of pre-approved install scripts // a filtered list of pre-approved install scripts
// nolint:gosec // nolint:gosec
scriptCmd := exec.CommandContext(cmd.Context(), filepath.Join(dotfilesDir, script)) scriptCmd := exec.CommandContext(inv.Context(), filepath.Join(dotfilesDir, script))
scriptCmd.Dir = dotfilesDir scriptCmd.Dir = dotfilesDir
scriptCmd.Stdout = cmd.OutOrStdout() scriptCmd.Stdout = inv.Stdout
scriptCmd.Stderr = cmd.ErrOrStderr() scriptCmd.Stderr = inv.Stderr
err = scriptCmd.Run() err = scriptCmd.Run()
if err != nil { if err != nil {
return xerrors.Errorf("running %s: %w", script, err) return xerrors.Errorf("running %s: %w", script, err)
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Dotfiles installation complete.") _, _ = fmt.Fprintln(inv.Stdout, "Dotfiles installation complete.")
return nil return nil
} }
if len(dotfiles) == 0 { if len(dotfiles) == 0 {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "No install scripts or dotfiles found, nothing to do.") _, _ = fmt.Fprintln(inv.Stdout, "No install scripts or dotfiles found, nothing to do.")
return nil return nil
} }
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "No install scripts found, symlinking dotfiles to home directory.\n\n Continue?", Text: "No install scripts found, symlinking dotfiles to home directory.\n\n Continue?",
IsConfirm: true, IsConfirm: true,
}) })
@ -206,7 +209,7 @@ func dotfiles() *cobra.Command {
for _, df := range dotfiles { for _, df := range dotfiles {
from := filepath.Join(dotfilesDir, df) from := filepath.Join(dotfilesDir, df)
to := filepath.Join(symlinkDir, df) to := filepath.Join(symlinkDir, df)
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Symlinking %s to %s...\n", from, to) _, _ = fmt.Fprintf(inv.Stdout, "Symlinking %s to %s...\n", from, to)
isRegular, err := isRegular(to) isRegular, err := isRegular(to)
if err != nil { if err != nil {
@ -215,7 +218,7 @@ func dotfiles() *cobra.Command {
// move conflicting non-symlink files to file.ext.bak // move conflicting non-symlink files to file.ext.bak
if isRegular { if isRegular {
backup := fmt.Sprintf("%s.bak", to) backup := fmt.Sprintf("%s.bak", to)
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Moving %s to %s...\n", to, backup) _, _ = fmt.Fprintf(inv.Stdout, "Moving %s to %s...\n", to, backup)
err = os.Rename(to, backup) err = os.Rename(to, backup)
if err != nil { if err != nil {
return xerrors.Errorf("renaming dir %s: %w", to, err) return xerrors.Errorf("renaming dir %s: %w", to, err)
@ -228,13 +231,19 @@ func dotfiles() *cobra.Command {
} }
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Dotfiles installation complete.") _, _ = fmt.Fprintln(inv.Stdout, "Dotfiles installation complete.")
return nil return nil
}, },
} }
cliui.AllowSkipPrompt(cmd) cmd.Options = clibase.OptionSet{
cliflag.StringVarP(cmd.Flags(), &symlinkDir, "symlink-dir", "", "CODER_SYMLINK_DIR", "", "Specifies the directory for the dotfiles symlink destinations. If empty will use $HOME.") {
Flag: "symlink-dir",
Env: "CODER_SYMLINK_DIR",
Description: "Specifies the directory for the dotfiles symlink destinations. If empty, will use $HOME.",
Value: clibase.StringOf(&symlinkDir),
},
cliui.SkipPromptOption(),
}
return cmd return cmd
} }

View File

@ -15,14 +15,16 @@ import (
"github.com/coder/coder/cryptorand" "github.com/coder/coder/cryptorand"
) )
// nolint:paralleltest
func TestDotfiles(t *testing.T) { func TestDotfiles(t *testing.T) {
t.Parallel()
t.Run("MissingArg", func(t *testing.T) { t.Run("MissingArg", func(t *testing.T) {
cmd, _ := clitest.New(t, "dotfiles") t.Parallel()
err := cmd.Execute() inv, _ := clitest.New(t, "dotfiles")
err := inv.Run()
require.Error(t, err) require.Error(t, err)
}) })
t.Run("NoInstallScript", func(t *testing.T) { t.Run("NoInstallScript", func(t *testing.T) {
t.Parallel()
_, root := clitest.New(t) _, root := clitest.New(t)
testRepo := testGitRepo(t, root) testRepo := testGitRepo(t, root)
@ -40,8 +42,8 @@ func TestDotfiles(t *testing.T) {
out, err := c.CombinedOutput() out, err := c.CombinedOutput()
require.NoError(t, err, string(out)) require.NoError(t, err, string(out))
cmd, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) inv, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo)
err = cmd.Execute() err = inv.Run()
require.NoError(t, err) require.NoError(t, err)
b, err := os.ReadFile(filepath.Join(string(root), ".bashrc")) b, err := os.ReadFile(filepath.Join(string(root), ".bashrc"))
@ -49,6 +51,7 @@ func TestDotfiles(t *testing.T) {
require.Equal(t, string(b), "wow") require.Equal(t, string(b), "wow")
}) })
t.Run("InstallScript", func(t *testing.T) { t.Run("InstallScript", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("install scripts on windows require sh and aren't very practical") t.Skip("install scripts on windows require sh and aren't very practical")
} }
@ -69,8 +72,8 @@ func TestDotfiles(t *testing.T) {
err = c.Run() err = c.Run()
require.NoError(t, err) require.NoError(t, err)
cmd, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) inv, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo)
err = cmd.Execute() err = inv.Run()
require.NoError(t, err) require.NoError(t, err)
b, err := os.ReadFile(filepath.Join(string(root), ".bashrc")) b, err := os.ReadFile(filepath.Join(string(root), ".bashrc"))
@ -78,6 +81,7 @@ func TestDotfiles(t *testing.T) {
require.Equal(t, string(b), "wow\n") require.Equal(t, string(b), "wow\n")
}) })
t.Run("SymlinkBackup", func(t *testing.T) { t.Run("SymlinkBackup", func(t *testing.T) {
t.Parallel()
_, root := clitest.New(t) _, root := clitest.New(t)
testRepo := testGitRepo(t, root) testRepo := testGitRepo(t, root)
@ -100,8 +104,8 @@ func TestDotfiles(t *testing.T) {
out, err := c.CombinedOutput() out, err := c.CombinedOutput()
require.NoError(t, err, string(out)) require.NoError(t, err, string(out))
cmd, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) inv, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo)
err = cmd.Execute() err = inv.Run()
require.NoError(t, err) require.NoError(t, err)
b, err := os.ReadFile(filepath.Join(string(root), ".bashrc")) b, err := os.ReadFile(filepath.Join(string(root), ".bashrc"))

View File

@ -7,9 +7,9 @@ import (
"os/signal" "os/signal"
"time" "time"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
@ -18,23 +18,22 @@ import (
// gitAskpass is used by the Coder agent to automatically authenticate // gitAskpass is used by the Coder agent to automatically authenticate
// with Git providers based on a hostname. // with Git providers based on a hostname.
func gitAskpass() *cobra.Command { func (r *RootCmd) gitAskpass() *clibase.Cmd {
return &cobra.Command{ return &clibase.Cmd{
Use: "gitaskpass", Use: "gitaskpass",
Hidden: true, Hidden: true,
Args: cobra.ExactArgs(1), Handler: func(inv *clibase.Invocation) error {
RunE: func(cmd *cobra.Command, args []string) error { ctx := inv.Context()
ctx := cmd.Context()
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...) ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
defer stop() defer stop()
user, host, err := gitauth.ParseAskpass(args[0]) user, host, err := gitauth.ParseAskpass(inv.Args[0])
if err != nil { if err != nil {
return xerrors.Errorf("parse host: %w", err) return xerrors.Errorf("parse host: %w", err)
} }
client, err := createAgentClient(cmd) client, err := r.createAgentClient()
if err != nil { if err != nil {
return xerrors.Errorf("create agent client: %w", err) return xerrors.Errorf("create agent client: %w", err)
} }
@ -45,16 +44,16 @@ func gitAskpass() *cobra.Command {
if errors.As(err, &apiError) && apiError.StatusCode() == http.StatusNotFound { if errors.As(err, &apiError) && apiError.StatusCode() == http.StatusNotFound {
// This prevents the "Run 'coder --help' for usage" // This prevents the "Run 'coder --help' for usage"
// message from occurring. // message from occurring.
cmd.Printf("%s\n", apiError.Message) cliui.Errorf(inv.Stderr, "%s\n", apiError.Message)
return cliui.Canceled return cliui.Canceled
} }
return xerrors.Errorf("get git token: %w", err) return xerrors.Errorf("get git token: %w", err)
} }
if token.URL != "" { if token.URL != "" {
if err := openURL(cmd, token.URL); err == nil { if err := openURL(inv, token.URL); err == nil {
cmd.Printf("Your browser has been opened to authenticate with Git:\n\n\t%s\n\n", token.URL) cliui.Infof(inv.Stdout, "Your browser has been opened to authenticate with Git:\n\n\t%s\n\n", token.URL)
} else { } else {
cmd.Printf("Open the following URL to authenticate with Git:\n\n\t%s\n\n", token.URL) cliui.Infof(inv.Stdout, "Open the following URL to authenticate with Git:\n\n\t%s\n\n", token.URL)
} }
for r := retry.New(250*time.Millisecond, 10*time.Second); r.Wait(ctx); { for r := retry.New(250*time.Millisecond, 10*time.Second); r.Wait(ctx); {
@ -62,19 +61,19 @@ func gitAskpass() *cobra.Command {
if err != nil { if err != nil {
continue continue
} }
cmd.Printf("You've been authenticated with Git!\n") cliui.Infof(inv.Stdout, "You've been authenticated with Git!\n")
break break
} }
} }
if token.Password != "" { if token.Password != "" {
if user == "" { if user == "" {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), token.Username) _, _ = fmt.Fprintln(inv.Stdout, token.Username)
} else { } else {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), token.Password) _, _ = fmt.Fprintln(inv.Stdout, token.Password)
} }
} else { } else {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), token.Username) _, _ = fmt.Fprintln(inv.Stdout, token.Username)
} }
return nil return nil

View File

@ -18,10 +18,10 @@ import (
"github.com/coder/coder/pty/ptytest" "github.com/coder/coder/pty/ptytest"
) )
// nolint:paralleltest
func TestGitAskpass(t *testing.T) { func TestGitAskpass(t *testing.T) {
t.Setenv("GIT_PREFIX", "/") t.Parallel()
t.Run("UsernameAndPassword", func(t *testing.T) { t.Run("UsernameAndPassword", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.GitAuthResponse{ httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.GitAuthResponse{
Username: "something", Username: "something",
@ -30,22 +30,23 @@ func TestGitAskpass(t *testing.T) {
})) }))
t.Cleanup(srv.Close) t.Cleanup(srv.Close)
url := srv.URL url := srv.URL
cmd, _ := clitest.New(t, "--agent-url", url, "Username for 'https://github.com':") inv, _ := clitest.New(t, "--agent-url", url, "Username for 'https://github.com':")
inv.Environ.Set("GIT_PREFIX", "/")
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetOutput(pty.Output()) inv.Stdout = pty.Output()
err := cmd.Execute() clitest.Start(t, inv)
require.NoError(t, err)
pty.ExpectMatch("something") pty.ExpectMatch("something")
cmd, _ = clitest.New(t, "--agent-url", url, "Password for 'https://potato@github.com':") inv, _ = clitest.New(t, "--agent-url", url, "Password for 'https://potato@github.com':")
inv.Environ.Set("GIT_PREFIX", "/")
pty = ptytest.New(t) pty = ptytest.New(t)
cmd.SetOutput(pty.Output()) inv.Stdout = pty.Output()
err = cmd.Execute() clitest.Start(t, inv)
require.NoError(t, err)
pty.ExpectMatch("bananas") pty.ExpectMatch("bananas")
}) })
t.Run("NoHost", func(t *testing.T) { t.Run("NoHost", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(context.Background(), w, http.StatusNotFound, codersdk.Response{ httpapi.Write(context.Background(), w, http.StatusNotFound, codersdk.Response{
Message: "Nope!", Message: "Nope!",
@ -53,15 +54,17 @@ func TestGitAskpass(t *testing.T) {
})) }))
t.Cleanup(srv.Close) t.Cleanup(srv.Close)
url := srv.URL url := srv.URL
cmd, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':")
inv.Environ.Set("GIT_PREFIX", "/")
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetOutput(pty.Output()) inv.Stderr = pty.Output()
err := cmd.Execute() err := inv.Run()
require.ErrorIs(t, err, cliui.Canceled) require.ErrorIs(t, err, cliui.Canceled)
pty.ExpectMatch("Nope!") pty.ExpectMatch("Nope!")
}) })
t.Run("Poll", func(t *testing.T) { t.Run("Poll", func(t *testing.T) {
t.Parallel()
resp := atomic.Pointer[agentsdk.GitAuthResponse]{} resp := atomic.Pointer[agentsdk.GitAuthResponse]{}
resp.Store(&agentsdk.GitAuthResponse{ resp.Store(&agentsdk.GitAuthResponse{
URL: "https://something.org", URL: "https://something.org",
@ -81,11 +84,12 @@ func TestGitAskpass(t *testing.T) {
t.Cleanup(srv.Close) t.Cleanup(srv.Close)
url := srv.URL url := srv.URL
cmd, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':")
inv.Environ.Set("GIT_PREFIX", "/")
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetOutput(pty.Output()) inv.Stdout = pty.Output()
go func() { go func() {
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
<-poll <-poll

View File

@ -12,19 +12,19 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
) )
func gitssh() *cobra.Command { func (r *RootCmd) gitssh() *clibase.Cmd {
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "gitssh", Use: "gitssh",
Hidden: true, Hidden: true,
Short: `Wraps the "ssh" command and uses the coder gitssh key for authentication`, Short: `Wraps the "ssh" command and uses the coder gitssh key for authentication`,
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
ctx := cmd.Context() ctx := inv.Context()
env := os.Environ() env := os.Environ()
// Catch interrupt signals to ensure the temporary private // Catch interrupt signals to ensure the temporary private
@ -33,12 +33,12 @@ func gitssh() *cobra.Command {
defer stop() defer stop()
// Early check so errors are reported immediately. // Early check so errors are reported immediately.
identityFiles, err := parseIdentityFilesForHost(ctx, args, env) identityFiles, err := parseIdentityFilesForHost(ctx, inv.Args, env)
if err != nil { if err != nil {
return err return err
} }
client, err := createAgentClient(cmd) client, err := r.createAgentClient()
if err != nil { if err != nil {
return xerrors.Errorf("create agent client: %w", err) return xerrors.Errorf("create agent client: %w", err)
} }
@ -78,24 +78,25 @@ func gitssh() *cobra.Command {
identityArgs = append(identityArgs, "-i", id) identityArgs = append(identityArgs, "-i", id)
} }
args := inv.Args
args = append(identityArgs, args...) args = append(identityArgs, args...)
c := exec.CommandContext(ctx, "ssh", args...) c := exec.CommandContext(ctx, "ssh", args...)
c.Env = append(c.Env, env...) c.Env = append(c.Env, env...)
c.Stderr = cmd.ErrOrStderr() c.Stderr = inv.Stderr
c.Stdout = cmd.OutOrStdout() c.Stdout = inv.Stdout
c.Stdin = cmd.InOrStdin() c.Stdin = inv.Stdin
err = c.Run() err = c.Run()
if err != nil { if err != nil {
exitErr := &exec.ExitError{} exitErr := &exec.ExitError{}
if xerrors.As(err, &exitErr) && exitErr.ExitCode() == 255 { if xerrors.As(err, &exitErr) && exitErr.ExitCode() == 255 {
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), _, _ = fmt.Fprintln(inv.Stderr,
"\n"+cliui.Styles.Wrap.Render("Coder authenticates with "+cliui.Styles.Field.Render("git")+ "\n"+cliui.Styles.Wrap.Render("Coder authenticates with "+cliui.Styles.Field.Render("git")+
" using the public key below. All clones with SSH are authenticated automatically 🪄.")+"\n") " using the public key below. All clones with SSH are authenticated automatically 🪄.")+"\n")
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Code.Render(strings.TrimSpace(key.PublicKey))+"\n") _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Code.Render(strings.TrimSpace(key.PublicKey))+"\n")
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Add to GitHub and GitLab:") _, _ = fmt.Fprintln(inv.Stderr, "Add to GitHub and GitLab:")
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Prompt.String()+"https://github.com/settings/ssh/new") _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Prompt.String()+"https://github.com/settings/ssh/new")
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Prompt.String()+"https://gitlab.com/-/profile/keys") _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Prompt.String()+"https://gitlab.com/-/profile/keys")
_, _ = fmt.Fprintln(cmd.ErrOrStderr()) _, _ = fmt.Fprintln(inv.Stderr)
return err return err
} }
return xerrors.Errorf("run ssh command: %w", err) return xerrors.Errorf("run ssh command: %w", err)

View File

@ -57,15 +57,12 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
// start workspace agent // start workspace agent
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) inv, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
agentClient := client agentClient := client
clitest.SetupConfig(t, agentClient, root) clitest.SetupConfig(t, agentClient, root)
errC := make(chan error, 1) clitest.Start(t, inv)
go func() {
errC <- cmd.ExecuteContext(ctx)
}()
t.Cleanup(func() { require.NoError(t, <-errC) })
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
return agentClient, agentToken, pubkey return agentClient, agentToken, pubkey
} }
@ -141,7 +138,7 @@ func TestGitSSH(t *testing.T) {
}, pubkey) }, pubkey)
// set to agent config dir // set to agent config dir
cmd, _ := clitest.New(t, inv, _ := clitest.New(t,
"gitssh", "gitssh",
"--agent-url", client.URL.String(), "--agent-url", client.URL.String(),
"--agent-token", token, "--agent-token", token,
@ -151,7 +148,7 @@ func TestGitSSH(t *testing.T) {
"-o", "IdentitiesOnly=yes", "-o", "IdentitiesOnly=yes",
"127.0.0.1", "127.0.0.1",
) )
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
require.EqualValues(t, 1, inc) require.EqualValues(t, 1, inc)
@ -213,10 +210,10 @@ func TestGitSSH(t *testing.T) {
"mytest", "mytest",
} }
// Test authentication via local private key. // Test authentication via local private key.
cmd, _ := clitest.New(t, cmdArgs...) inv, _ := clitest.New(t, cmdArgs...)
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
select { select {
case key := <-authkey: case key := <-authkey:
@ -230,10 +227,10 @@ func TestGitSSH(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// With the local file deleted, the coder key should be used. // With the local file deleted, the coder key should be used.
cmd, _ = clitest.New(t, cmdArgs...) inv, _ = clitest.New(t, cmdArgs...)
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
select { select {
case key := <-authkey: case key := <-authkey:

292
cli/help.go Normal file
View File

@ -0,0 +1,292 @@
package cli
import (
"bufio"
"bytes"
_ "embed"
"fmt"
"io"
"regexp"
"sort"
"strings"
"text/tabwriter"
"text/template"
"unicode"
"github.com/mitchellh/go-wordwrap"
"golang.org/x/crypto/ssh/terminal"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui"
)
//go:embed help.tpl
var helpTemplateRaw string
type optionGroup struct {
Name string
Description string
Options clibase.OptionSet
}
func ttyWidth() int {
width, _, err := terminal.GetSize(0)
if err != nil {
return 80
}
return width
}
// wrapTTY wraps a string to the width of the terminal, or 80 no terminal
// is detected.
func wrapTTY(s string) string {
return wordwrap.WrapString(s, uint(ttyWidth()))
}
var usageTemplate = template.Must(
template.New("usage").Funcs(
template.FuncMap{
"wrapTTY": func(s string) string {
return wrapTTY(s)
},
"trimNewline": func(s string) string {
return strings.TrimSuffix(s, "\n")
},
"typeHelper": func(opt *clibase.Option) string {
switch v := opt.Value.(type) {
case *clibase.Enum:
return strings.Join(v.Choices, "|")
default:
return v.Type()
}
},
"joinStrings": func(s []string) string {
return strings.Join(s, ", ")
},
"indent": func(body string, spaces int) string {
twidth := ttyWidth()
spacing := strings.Repeat(" ", spaces)
body = wordwrap.WrapString(body, uint(twidth-len(spacing)))
var sb strings.Builder
for _, line := range strings.Split(body, "\n") {
// Remove existing indent, if any.
line = strings.TrimSpace(line)
// Use spaces so we can easily calculate wrapping.
_, _ = sb.WriteString(spacing)
_, _ = sb.WriteString(line)
_, _ = sb.WriteString("\n")
}
return sb.String()
},
"formatSubcommand": func(cmd *clibase.Cmd) string {
// Minimize padding by finding the longest neighboring name.
maxNameLength := len(cmd.Name())
if parent := cmd.Parent; parent != nil {
for _, c := range parent.Children {
if len(c.Name()) > maxNameLength {
maxNameLength = len(c.Name())
}
}
}
var sb strings.Builder
_, _ = fmt.Fprintf(
&sb, "%s%s%s",
strings.Repeat(" ", 4), cmd.Name(), strings.Repeat(" ", maxNameLength-len(cmd.Name())+4),
)
// This is the point at which indentation begins if there's a
// next line.
descStart := sb.Len()
twidth := ttyWidth()
for i, line := range strings.Split(
wordwrap.WrapString(cmd.Short, uint(twidth-descStart)), "\n",
) {
if i > 0 {
_, _ = sb.WriteString(strings.Repeat(" ", descStart))
}
_, _ = sb.WriteString(line)
_, _ = sb.WriteString("\n")
}
return sb.String()
},
"envName": func(opt clibase.Option) string {
if opt.Env == "" {
return ""
}
return opt.Env
},
"flagName": func(opt clibase.Option) string {
return opt.Flag
},
"prettyHeader": func(s string) string {
return cliui.Styles.Bold.Render(s)
},
"isEnterprise": func(opt clibase.Option) bool {
return opt.Annotations.IsSet("enterprise")
},
"isDeprecated": func(opt clibase.Option) bool {
return len(opt.UseInstead) > 0
},
"formatLong": func(long string) string {
// We intentionally don't wrap here because it would misformat
// examples, where the new line would start without the prior
// line's indentation.
return strings.TrimSpace(long)
},
"formatGroupDescription": func(s string) string {
s = strings.ReplaceAll(s, "\n", "")
s = s + "\n"
s = wrapTTY(s)
return s
},
"visibleChildren": func(cmd *clibase.Cmd) []*clibase.Cmd {
return filterSlice(cmd.Children, func(c *clibase.Cmd) bool {
return !c.Hidden
})
},
"optionGroups": func(cmd *clibase.Cmd) []optionGroup {
groups := []optionGroup{{
// Default group.
Name: "",
Description: "",
}}
enterpriseGroup := optionGroup{
Name: "Enterprise",
Description: `These options are only available in the Enterprise Edition.`,
}
// Sort options lexicographically.
sort.Slice(cmd.Options, func(i, j int) bool {
return cmd.Options[i].Name < cmd.Options[j].Name
})
optionLoop:
for _, opt := range cmd.Options {
if opt.Hidden {
continue
}
// Enterprise options are always grouped separately.
if opt.Annotations.IsSet("enterprise") {
enterpriseGroup.Options = append(enterpriseGroup.Options, opt)
continue
}
if len(opt.Group.Ancestry()) == 0 {
// Just add option to default group.
groups[0].Options = append(groups[0].Options, opt)
continue
}
groupName := opt.Group.FullName()
for i, foundGroup := range groups {
if foundGroup.Name != groupName {
continue
}
groups[i].Options = append(groups[i].Options, opt)
continue optionLoop
}
groups = append(groups, optionGroup{
Name: groupName,
Description: opt.Group.Description,
Options: clibase.OptionSet{opt},
})
}
sort.Slice(groups, func(i, j int) bool {
// Sort groups lexicographically.
return groups[i].Name < groups[j].Name
})
// Always show enterprise group last.
groups = append(groups, enterpriseGroup)
return filterSlice(groups, func(g optionGroup) bool {
return len(g.Options) > 0
})
},
},
).Parse(helpTemplateRaw),
)
func filterSlice[T any](s []T, f func(T) bool) []T {
var r []T
for _, v := range s {
if f(v) {
r = append(r, v)
}
}
return r
}
// newLineLimiter makes working with Go templates more bearable. Without this,
// modifying the template is a slow toil of counting newlines and constantly
// checking that a change to one command's help doesn't clobber break another.
type newlineLimiter struct {
w io.Writer
limit int
newLineCounter int
}
func (lm *newlineLimiter) Write(p []byte) (int, error) {
rd := bytes.NewReader(p)
for r, n, _ := rd.ReadRune(); n > 0; r, n, _ = rd.ReadRune() {
switch {
case r == '\r':
// Carriage returns can sneak into `help.tpl` when `git clone`
// is configured to automatically convert line endings.
continue
case r == '\n':
lm.newLineCounter++
if lm.newLineCounter > lm.limit {
continue
}
case !unicode.IsSpace(r):
lm.newLineCounter = 0
}
_, err := lm.w.Write([]byte(string(r)))
if err != nil {
return 0, err
}
}
return len(p), nil
}
var usageWantsArgRe = regexp.MustCompile(`<.*>`)
// helpFn returns a function that generates usage (help)
// output for a given command.
func helpFn() clibase.HandlerFunc {
return func(inv *clibase.Invocation) error {
// We buffer writes to stderr because the newlineLimiter writes one
// rune at a time.
stderrBuf := bufio.NewWriter(inv.Stderr)
out := newlineLimiter{w: stderrBuf, limit: 2}
tabwriter := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0)
err := usageTemplate.Execute(tabwriter, inv.Command)
if err != nil {
return xerrors.Errorf("execute template: %w", err)
}
err = tabwriter.Flush()
if err != nil {
return err
}
err = stderrBuf.Flush()
if err != nil {
return err
}
if len(inv.Args) > 0 && !usageWantsArgRe.MatchString(inv.Command.Use) {
_, _ = fmt.Fprintf(inv.Stderr, "---\nerror: unknown subcommand %q\n", inv.Args[0])
}
return nil
}
}

55
cli/help.tpl Normal file
View File

@ -0,0 +1,55 @@
{{- /* Heavily inspired by the Go toolchain formatting. */ -}}
Usage: {{.FullUsage}}
{{ with .Short }}
{{- wrapTTY . }}
{{"\n"}}
{{- end}}
{{ with .Aliases }}
{{ "\n" }}
{{ "Aliases:"}} {{ joinStrings .}}
{{ "\n" }}
{{- end }}
{{- with .Long}}
{{- formatLong . }}
{{ "\n" }}
{{- end }}
{{ with visibleChildren . }}
{{- range $index, $child := . }}
{{- if eq $index 0 }}
{{ prettyHeader "Subcommands"}}
{{- end }}
{{- "\n" }}
{{- formatSubcommand . | trimNewline }}
{{- end }}
{{- "\n" }}
{{- end }}
{{- range $index, $group := optionGroups . }}
{{ with $group.Name }} {{- print $group.Name " Options" | prettyHeader }} {{ else -}} {{ prettyHeader "Options"}}{{- end -}}
{{- with $group.Description }}
{{ formatGroupDescription . }}
{{- else }}
{{- end }}
{{- range $index, $option := $group.Options }}
{{- if not (eq $option.FlagShorthand "") }}{{- print "\n -" $option.FlagShorthand ", " -}}
{{- else }}{{- print "\n " -}}
{{- end }}
{{- with flagName $option }}--{{ . }}{{ end }} {{- with typeHelper $option }} {{ . }}{{ end }}
{{- with envName $option }}, ${{ . }}{{ end }}
{{- with $option.Default }} (default: {{ . }}){{ end }}
{{- with $option.Description }}
{{- $desc := $option.Description }}
{{ indent $desc 10 }}
{{- if isDeprecated $option }} DEPRECATED {{ end }}
{{- end -}}
{{- end }}
{{- end }}
---
{{- if .Parent }}
Run `coder --help` for a list of global options.
{{- else }}
Report bugs and request features at https://github.com/coder/coder/issues/new
{{- end }}

View File

@ -5,8 +5,8 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/spf13/cobra"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/schedule" "github.com/coder/coder/coderd/schedule"
"github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/coderd/util/ptr"
@ -64,7 +64,7 @@ func workspaceListRowFromWorkspace(now time.Time, usersByID map[uuid.UUID]coders
} }
} }
func list() *cobra.Command { func (r *RootCmd) list() *clibase.Cmd {
var ( var (
all bool all bool
defaultQuery = "owner:me" defaultQuery = "owner:me"
@ -75,18 +75,17 @@ func list() *cobra.Command {
cliui.JSONFormat(), cliui.JSONFormat(),
) )
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "list", Use: "list",
Short: "List workspaces", Short: "List workspaces",
Aliases: []string{"ls"}, Aliases: []string{"ls"},
Args: cobra.ExactArgs(0), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(0),
client, err := CreateClient(cmd) r.InitClient(client),
if err != nil { ),
return err Handler: func(inv *clibase.Invocation) error {
}
filter := codersdk.WorkspaceFilter{ filter := codersdk.WorkspaceFilter{
FilterQuery: searchQuery, FilterQuery: searchQuery,
} }
@ -94,19 +93,19 @@ func list() *cobra.Command {
filter.FilterQuery = "" filter.FilterQuery = ""
} }
res, err := client.Workspaces(cmd.Context(), filter) res, err := client.Workspaces(inv.Context(), filter)
if err != nil { if err != nil {
return err return err
} }
if len(res.Workspaces) == 0 { if len(res.Workspaces) == 0 {
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Prompt.String()+"No workspaces found! Create one:") _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Prompt.String()+"No workspaces found! Create one:")
_, _ = fmt.Fprintln(cmd.ErrOrStderr()) _, _ = fmt.Fprintln(inv.Stderr)
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), " "+cliui.Styles.Code.Render("coder create <name>")) _, _ = fmt.Fprintln(inv.Stderr, " "+cliui.Styles.Code.Render("coder create <name>"))
_, _ = fmt.Fprintln(cmd.ErrOrStderr()) _, _ = fmt.Fprintln(inv.Stderr)
return nil return nil
} }
userRes, err := client.Users(cmd.Context(), codersdk.UsersRequest{}) userRes, err := client.Users(inv.Context(), codersdk.UsersRequest{})
if err != nil { if err != nil {
return err return err
} }
@ -122,20 +121,31 @@ func list() *cobra.Command {
displayWorkspaces[i] = workspaceListRowFromWorkspace(now, usersByID, workspace) displayWorkspaces[i] = workspaceListRowFromWorkspace(now, usersByID, workspace)
} }
out, err := formatter.Format(cmd.Context(), displayWorkspaces) out, err := formatter.Format(inv.Context(), displayWorkspaces)
if err != nil { if err != nil {
return err return err
} }
_, err = fmt.Fprintln(cmd.OutOrStdout(), out) _, err = fmt.Fprintln(inv.Stdout, out)
return err return err
}, },
} }
cmd.Options = clibase.OptionSet{
{
Flag: "all",
FlagShorthand: "a",
Description: "Specifies whether all workspaces will be listed or not.",
cmd.Flags().BoolVarP(&all, "all", "a", false, Value: clibase.BoolOf(&all),
"Specifies whether all workspaces will be listed or not.") },
cmd.Flags().StringVar(&searchQuery, "search", defaultQuery, "Search for a workspace with a query.") {
Flag: "search",
Description: "Search for a workspace with a query.",
Default: defaultQuery,
Value: clibase.StringOf(&searchQuery),
},
}
formatter.AttachFlags(cmd) formatter.AttachOptions(&cmd.Options)
return cmd return cmd
} }

View File

@ -27,17 +27,15 @@ func TestList(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, root := clitest.New(t, "ls") inv, root := clitest.New(t, "ls")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancelFunc() defer cancelFunc()
done := make(chan any) done := make(chan any)
go func() { go func() {
errC := cmd.ExecuteContext(ctx) errC := inv.WithContext(ctx).Run()
assert.NoError(t, errC) assert.NoError(t, errC)
close(done) close(done)
}() }()
@ -57,15 +55,15 @@ func TestList(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, root := clitest.New(t, "list", "--output=json") inv, root := clitest.New(t, "list", "--output=json")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancelFunc() defer cancelFunc()
out := bytes.NewBuffer(nil) out := bytes.NewBuffer(nil)
cmd.SetOut(out) inv.Stdout = out
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
var templates []codersdk.Workspace var templates []codersdk.Workspace

View File

@ -14,10 +14,9 @@ import (
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/pkg/browser" "github.com/pkg/browser"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/userpassword" "github.com/coder/coder/coderd/userpassword"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
@ -38,7 +37,7 @@ func init() {
browser.Stdout = io.Discard browser.Stdout = io.Discard
} }
func login() *cobra.Command { func (r *RootCmd) login() *clibase.Cmd {
const firstUserTrialEnv = "CODER_FIRST_USER_TRIAL" const firstUserTrialEnv = "CODER_FIRST_USER_TRIAL"
var ( var (
@ -47,20 +46,16 @@ func login() *cobra.Command {
password string password string
trial bool trial bool
) )
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "login <url>", Use: "login <url>",
Short: "Authenticate with Coder deployment", Short: "Authenticate with Coder deployment",
Args: cobra.MaximumNArgs(1), Middleware: clibase.RequireRangeArgs(0, 1),
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
rawURL := "" rawURL := ""
if len(args) == 0 { if len(inv.Args) == 0 {
var err error rawURL = r.clientURL.String()
rawURL, err = cmd.Flags().GetString(varURL)
if err != nil {
return xerrors.Errorf("get global url flag")
}
} else { } else {
rawURL = args[0] rawURL = inv.Args[0]
} }
if !strings.HasPrefix(rawURL, "http://") && !strings.HasPrefix(rawURL, "https://") { if !strings.HasPrefix(rawURL, "http://") && !strings.HasPrefix(rawURL, "https://") {
@ -79,7 +74,7 @@ func login() *cobra.Command {
serverURL.Scheme = "https" serverURL.Scheme = "https"
} }
client, err := createUnauthenticatedClient(cmd, serverURL) client, err := r.createUnauthenticatedClient(serverURL)
if err != nil { if err != nil {
return err return err
} }
@ -87,25 +82,25 @@ func login() *cobra.Command {
// Try to check the version of the server prior to logging in. // Try to check the version of the server prior to logging in.
// It may be useful to warn the user if they are trying to login // It may be useful to warn the user if they are trying to login
// on a very old client. // on a very old client.
err = checkVersions(cmd, client) err = r.checkVersions(inv, client)
if err != nil { if err != nil {
// Checking versions isn't a fatal error so we print a warning // Checking versions isn't a fatal error so we print a warning
// and proceed. // and proceed.
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(err.Error())) _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Warn.Render(err.Error()))
} }
hasInitialUser, err := client.HasFirstUser(cmd.Context()) hasInitialUser, err := client.HasFirstUser(inv.Context())
if err != nil { if err != nil {
return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err) return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err)
} }
if !hasInitialUser { if !hasInitialUser {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), Caret+"Your Coder deployment hasn't been set up!\n") _, _ = fmt.Fprintf(inv.Stdout, Caret+"Your Coder deployment hasn't been set up!\n")
if username == "" { if username == "" {
if !isTTY(cmd) { if !isTTY(inv) {
return xerrors.New("the initial user cannot be created in non-interactive mode. use the API") return xerrors.New("the initial user cannot be created in non-interactive mode. use the API")
} }
_, err := cliui.Prompt(cmd, cliui.PromptOptions{ _, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Would you like to create the first user?", Text: "Would you like to create the first user?",
Default: cliui.ConfirmYes, Default: cliui.ConfirmYes,
IsConfirm: true, IsConfirm: true,
@ -120,7 +115,7 @@ func login() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("get current user: %w", err) return xerrors.Errorf("get current user: %w", err)
} }
username, err = cliui.Prompt(cmd, cliui.PromptOptions{ username, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "What " + cliui.Styles.Field.Render("username") + " would you like?", Text: "What " + cliui.Styles.Field.Render("username") + " would you like?",
Default: currentUser.Username, Default: currentUser.Username,
}) })
@ -133,7 +128,7 @@ func login() *cobra.Command {
} }
if email == "" { if email == "" {
email, err = cliui.Prompt(cmd, cliui.PromptOptions{ email, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "What's your " + cliui.Styles.Field.Render("email") + "?", Text: "What's your " + cliui.Styles.Field.Render("email") + "?",
Validate: func(s string) error { Validate: func(s string) error {
err := validator.New().Var(s, "email") err := validator.New().Var(s, "email")
@ -152,7 +147,7 @@ func login() *cobra.Command {
var matching bool var matching bool
for !matching { for !matching {
password, err = cliui.Prompt(cmd, cliui.PromptOptions{ password, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Enter a " + cliui.Styles.Field.Render("password") + ":", Text: "Enter a " + cliui.Styles.Field.Render("password") + ":",
Secret: true, Secret: true,
Validate: func(s string) error { Validate: func(s string) error {
@ -162,7 +157,7 @@ func login() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("specify password prompt: %w", err) return xerrors.Errorf("specify password prompt: %w", err)
} }
confirm, err := cliui.Prompt(cmd, cliui.PromptOptions{ confirm, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm " + cliui.Styles.Field.Render("password") + ":", Text: "Confirm " + cliui.Styles.Field.Render("password") + ":",
Secret: true, Secret: true,
Validate: cliui.ValidateNotEmpty, Validate: cliui.ValidateNotEmpty,
@ -173,13 +168,13 @@ func login() *cobra.Command {
matching = confirm == password matching = confirm == password
if !matching { if !matching {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Error.Render("Passwords do not match")) _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Error.Render("Passwords do not match"))
} }
} }
} }
if !cmd.Flags().Changed("first-user-trial") && os.Getenv(firstUserTrialEnv) == "" { if !inv.ParsedFlags().Changed("first-user-trial") && os.Getenv(firstUserTrialEnv) == "" {
v, _ := cliui.Prompt(cmd, cliui.PromptOptions{ v, _ := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Start a 30-day trial of Enterprise?", Text: "Start a 30-day trial of Enterprise?",
IsConfirm: true, IsConfirm: true,
Default: "yes", Default: "yes",
@ -187,7 +182,7 @@ func login() *cobra.Command {
trial = v == "yes" || v == "y" trial = v == "yes" || v == "y"
} }
_, err = client.CreateFirstUser(cmd.Context(), codersdk.CreateFirstUserRequest{ _, err = client.CreateFirstUser(inv.Context(), codersdk.CreateFirstUserRequest{
Email: email, Email: email,
Username: username, Username: username,
Password: password, Password: password,
@ -196,7 +191,7 @@ func login() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("create initial user: %w", err) return xerrors.Errorf("create initial user: %w", err)
} }
resp, err := client.LoginWithPassword(cmd.Context(), codersdk.LoginWithPasswordRequest{ resp, err := client.LoginWithPassword(inv.Context(), codersdk.LoginWithPasswordRequest{
Email: email, Email: email,
Password: password, Password: password,
}) })
@ -205,7 +200,7 @@ func login() *cobra.Command {
} }
sessionToken := resp.SessionToken sessionToken := resp.SessionToken
config := createConfig(cmd) config := r.createConfig()
err = config.Session().Write(sessionToken) err = config.Session().Write(sessionToken)
if err != nil { if err != nil {
return xerrors.Errorf("write session token: %w", err) return xerrors.Errorf("write session token: %w", err)
@ -215,32 +210,32 @@ func login() *cobra.Command {
return xerrors.Errorf("write server url: %w", err) return xerrors.Errorf("write server url: %w", err)
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), _, _ = fmt.Fprintf(inv.Stdout,
cliui.Styles.Paragraph.Render(fmt.Sprintf("Welcome to Coder, %s! You're authenticated.", cliui.Styles.Keyword.Render(username)))+"\n") cliui.Styles.Paragraph.Render(fmt.Sprintf("Welcome to Coder, %s! You're authenticated.", cliui.Styles.Keyword.Render(username)))+"\n")
_, _ = fmt.Fprintf(cmd.OutOrStdout(), _, _ = fmt.Fprintf(inv.Stdout,
cliui.Styles.Paragraph.Render("Get started by creating a template: "+cliui.Styles.Code.Render("coder templates init"))+"\n") cliui.Styles.Paragraph.Render("Get started by creating a template: "+cliui.Styles.Code.Render("coder templates init"))+"\n")
return nil return nil
} }
sessionToken, _ := cmd.Flags().GetString(varToken) sessionToken, _ := inv.ParsedFlags().GetString(varToken)
if sessionToken == "" { if sessionToken == "" {
authURL := *serverURL authURL := *serverURL
// Don't use filepath.Join, we don't want to use the os separator // Don't use filepath.Join, we don't want to use the os separator
// for a url. // for a url.
authURL.Path = path.Join(serverURL.Path, "/cli-auth") authURL.Path = path.Join(serverURL.Path, "/cli-auth")
if err := openURL(cmd, authURL.String()); err != nil { if err := openURL(inv, authURL.String()); err != nil {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Open the following in your browser:\n\n\t%s\n\n", authURL.String()) _, _ = fmt.Fprintf(inv.Stdout, "Open the following in your browser:\n\n\t%s\n\n", authURL.String())
} else { } else {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Your browser has been opened to visit:\n\n\t%s\n\n", authURL.String()) _, _ = fmt.Fprintf(inv.Stdout, "Your browser has been opened to visit:\n\n\t%s\n\n", authURL.String())
} }
sessionToken, err = cliui.Prompt(cmd, cliui.PromptOptions{ sessionToken, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Paste your token here:", Text: "Paste your token here:",
Secret: true, Secret: true,
Validate: func(token string) error { Validate: func(token string) error {
client.SetSessionToken(token) client.SetSessionToken(token)
_, err := client.User(cmd.Context(), codersdk.Me) _, err := client.User(inv.Context(), codersdk.Me)
if err != nil { if err != nil {
return xerrors.New("That's not a valid token!") return xerrors.New("That's not a valid token!")
} }
@ -254,12 +249,12 @@ func login() *cobra.Command {
// Login to get user data - verify it is OK before persisting // Login to get user data - verify it is OK before persisting
client.SetSessionToken(sessionToken) client.SetSessionToken(sessionToken)
resp, err := client.User(cmd.Context(), codersdk.Me) resp, err := client.User(inv.Context(), codersdk.Me)
if err != nil { if err != nil {
return xerrors.Errorf("get user: %w", err) return xerrors.Errorf("get user: %w", err)
} }
config := createConfig(cmd) config := r.createConfig()
err = config.Session().Write(sessionToken) err = config.Session().Write(sessionToken)
if err != nil { if err != nil {
return xerrors.Errorf("write session token: %w", err) return xerrors.Errorf("write session token: %w", err)
@ -269,14 +264,36 @@ func login() *cobra.Command {
return xerrors.Errorf("write server url: %w", err) return xerrors.Errorf("write server url: %w", err)
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), Caret+"Welcome to Coder, %s! You're authenticated.\n", cliui.Styles.Keyword.Render(resp.Username)) _, _ = fmt.Fprintf(inv.Stdout, Caret+"Welcome to Coder, %s! You're authenticated.\n", cliui.Styles.Keyword.Render(resp.Username))
return nil return nil
}, },
} }
cliflag.StringVarP(cmd.Flags(), &email, "first-user-email", "", "CODER_FIRST_USER_EMAIL", "", "Specifies an email address to use if creating the first user for the deployment.") cmd.Options = clibase.OptionSet{
cliflag.StringVarP(cmd.Flags(), &username, "first-user-username", "", "CODER_FIRST_USER_USERNAME", "", "Specifies a username to use if creating the first user for the deployment.") {
cliflag.StringVarP(cmd.Flags(), &password, "first-user-password", "", "CODER_FIRST_USER_PASSWORD", "", "Specifies a password to use if creating the first user for the deployment.") Flag: "first-user-email",
cliflag.BoolVarP(cmd.Flags(), &trial, "first-user-trial", "", firstUserTrialEnv, false, "Specifies whether a trial license should be provisioned for the Coder deployment or not.") Env: "CODER_FIRST_USER_EMAIL",
Description: "Specifies an email address to use if creating the first user for the deployment.",
Value: clibase.StringOf(&email),
},
{
Flag: "first-user-username",
Env: "CODER_FIRST_USER_USERNAME",
Description: "Specifies a username to use if creating the first user for the deployment.",
Value: clibase.StringOf(&username),
},
{
Flag: "first-user-password",
Env: "CODER_FIRST_USER_PASSWORD",
Description: "Specifies a password to use if creating the first user for the deployment.",
Value: clibase.StringOf(&password),
},
{
Flag: "first-user-trial",
Env: firstUserTrialEnv,
Description: "Specifies whether a trial license should be provisioned for the Coder deployment or not.",
Value: clibase.BoolOf(&trial),
},
}
return cmd return cmd
} }
@ -293,8 +310,8 @@ func isWSL() (bool, error) {
} }
// openURL opens the provided URL via user's default browser // openURL opens the provided URL via user's default browser
func openURL(cmd *cobra.Command, urlToOpen string) error { func openURL(inv *clibase.Invocation, urlToOpen string) error {
noOpen, err := cmd.Flags().GetBool(varNoOpen) noOpen, err := inv.ParsedFlags().GetBool(varNoOpen)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -314,7 +331,7 @@ func openURL(cmd *cobra.Command, urlToOpen string) error {
browserEnv := os.Getenv("BROWSER") browserEnv := os.Getenv("BROWSER")
if browserEnv != "" { if browserEnv != "" {
browserSh := fmt.Sprintf("%s '%s'", browserEnv, urlToOpen) browserSh := fmt.Sprintf("%s '%s'", browserEnv, urlToOpen)
cmd := exec.CommandContext(cmd.Context(), "sh", "-c", browserSh) cmd := exec.CommandContext(inv.Context(), "sh", "-c", browserSh)
out, err := cmd.CombinedOutput() out, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return xerrors.Errorf("failed to run %v (out: %q): %w", cmd.Args, out, err) return xerrors.Errorf("failed to run %v (out: %q): %w", cmd.Args, out, err)

View File

@ -20,7 +20,7 @@ func TestLogin(t *testing.T) {
t.Parallel() t.Parallel()
client := coderdtest.New(t, nil) client := coderdtest.New(t, nil)
root, _ := clitest.New(t, "login", client.URL.String()) root, _ := clitest.New(t, "login", client.URL.String())
err := root.Execute() err := root.Run()
require.Error(t, err) require.Error(t, err)
}) })
@ -28,7 +28,7 @@ func TestLogin(t *testing.T) {
t.Parallel() t.Parallel()
badLoginURL := "https://fcca2077f06e68aaf9" badLoginURL := "https://fcca2077f06e68aaf9"
root, _ := clitest.New(t, "login", badLoginURL) root, _ := clitest.New(t, "login", badLoginURL)
err := root.Execute() err := root.Run()
errMsg := fmt.Sprintf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser?", badLoginURL) errMsg := fmt.Sprintf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser?", badLoginURL)
require.ErrorContains(t, err, errMsg) require.ErrorContains(t, err, errMsg)
}) })
@ -41,12 +41,10 @@ func TestLogin(t *testing.T) {
// https://github.com/mattn/go-isatty/issues/59 // https://github.com/mattn/go-isatty/issues/59
doneChan := make(chan struct{}) doneChan := make(chan struct{})
root, _ := clitest.New(t, "login", "--force-tty", client.URL.String()) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String())
pty := ptytest.New(t) pty := ptytest.New(t).Attach(root)
root.SetIn(pty.Input())
root.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := root.Execute() err := root.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -74,16 +72,10 @@ func TestLogin(t *testing.T) {
// The --force-tty flag is required on Windows, because the `isatty` library does not // The --force-tty flag is required on Windows, because the `isatty` library does not
// accurately detect Windows ptys when they are not attached to a process: // accurately detect Windows ptys when they are not attached to a process:
// https://github.com/mattn/go-isatty/issues/59 // https://github.com/mattn/go-isatty/issues/59
doneChan := make(chan struct{}) inv, _ := clitest.New(t, "--url", client.URL.String(), "login", "--force-tty")
root, _ := clitest.New(t, "--url", client.URL.String(), "login", "--force-tty") pty := ptytest.New(t).Attach(inv)
pty := ptytest.New(t)
root.SetIn(pty.Input()) clitest.Start(t, inv)
root.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := root.Execute()
assert.NoError(t, err)
}()
matches := []string{ matches := []string{
"first user?", "yes", "first user?", "yes",
@ -100,7 +92,6 @@ func TestLogin(t *testing.T) {
pty.WriteLine(value) pty.WriteLine(value)
} }
pty.ExpectMatch("Welcome to Coder") pty.ExpectMatch("Welcome to Coder")
<-doneChan
}) })
t.Run("InitialUserFlags", func(t *testing.T) { t.Run("InitialUserFlags", func(t *testing.T) {
@ -108,12 +99,10 @@ func TestLogin(t *testing.T) {
client := coderdtest.New(t, nil) client := coderdtest.New(t, nil)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
root, _ := clitest.New(t, "login", client.URL.String(), "--first-user-username", "testuser", "--first-user-email", "user@coder.com", "--first-user-password", "SomeSecurePassword!", "--first-user-trial") root, _ := clitest.New(t, "login", client.URL.String(), "--first-user-username", "testuser", "--first-user-email", "user@coder.com", "--first-user-password", "SomeSecurePassword!", "--first-user-trial")
pty := ptytest.New(t) pty := ptytest.New(t).Attach(root)
root.SetIn(pty.Input())
root.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := root.Execute() err := root.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
pty.ExpectMatch("Welcome to Coder") pty.ExpectMatch("Welcome to Coder")
@ -130,12 +119,10 @@ func TestLogin(t *testing.T) {
// https://github.com/mattn/go-isatty/issues/59 // https://github.com/mattn/go-isatty/issues/59
doneChan := make(chan struct{}) doneChan := make(chan struct{})
root, _ := clitest.New(t, "login", "--force-tty", client.URL.String()) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String())
pty := ptytest.New(t) pty := ptytest.New(t).Attach(root)
root.SetIn(pty.Input())
root.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := root.ExecuteContext(ctx) err := root.WithContext(ctx).Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -173,12 +160,10 @@ func TestLogin(t *testing.T) {
doneChan := make(chan struct{}) doneChan := make(chan struct{})
root, _ := clitest.New(t, "login", "--force-tty", client.URL.String(), "--no-open") root, _ := clitest.New(t, "login", "--force-tty", client.URL.String(), "--no-open")
pty := ptytest.New(t) pty := ptytest.New(t).Attach(root)
root.SetIn(pty.Input())
root.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := root.Execute() err := root.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -197,12 +182,10 @@ func TestLogin(t *testing.T) {
defer cancelFunc() defer cancelFunc()
doneChan := make(chan struct{}) doneChan := make(chan struct{})
root, _ := clitest.New(t, "login", client.URL.String(), "--no-open") root, _ := clitest.New(t, "login", client.URL.String(), "--no-open")
pty := ptytest.New(t) pty := ptytest.New(t).Attach(root)
root.SetIn(pty.Input())
root.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := root.ExecuteContext(ctx) err := root.WithContext(ctx).Run()
// An error is expected in this case, since the login wasn't successful: // An error is expected in this case, since the login wasn't successful:
assert.Error(t, err) assert.Error(t, err)
}() }()
@ -219,7 +202,7 @@ func TestLogin(t *testing.T) {
client := coderdtest.New(t, nil) client := coderdtest.New(t, nil)
coderdtest.CreateFirstUser(t, client) coderdtest.CreateFirstUser(t, client)
root, cfg := clitest.New(t, "login", client.URL.String(), "--token", client.SessionToken()) root, cfg := clitest.New(t, "login", client.URL.String(), "--token", client.SessionToken())
err := root.Execute() err := root.Run()
require.NoError(t, err) require.NoError(t, err)
sessionFile, err := cfg.Session().Read() sessionFile, err := cfg.Session().Read()
require.NoError(t, err) require.NoError(t, err)

View File

@ -5,27 +5,28 @@ import (
"os" "os"
"strings" "strings"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
) )
func logout() *cobra.Command { func (r *RootCmd) logout() *clibase.Cmd {
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "logout", Use: "logout",
Short: "Unauthenticate your local session", Short: "Unauthenticate your local session",
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.Chain(
client, err := CreateClient(cmd) r.InitClient(client),
if err != nil { ),
return err Handler: func(inv *clibase.Invocation) error {
}
var errors []error var errors []error
config := createConfig(cmd) config := r.createConfig()
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ var err error
_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Are you sure you want to log out?", Text: "Are you sure you want to log out?",
IsConfirm: true, IsConfirm: true,
Default: cliui.ConfirmYes, Default: cliui.ConfirmYes,
@ -34,7 +35,7 @@ func logout() *cobra.Command {
return err return err
} }
err = client.Logout(cmd.Context()) err = client.Logout(inv.Context())
if err != nil { if err != nil {
errors = append(errors, xerrors.Errorf("logout api: %w", err)) errors = append(errors, xerrors.Errorf("logout api: %w", err))
} }
@ -67,11 +68,10 @@ func logout() *cobra.Command {
errorString := strings.TrimRight(errorStringBuilder.String(), "\n") errorString := strings.TrimRight(errorStringBuilder.String(), "\n")
return xerrors.New("Failed to log out.\n" + errorString) return xerrors.New("Failed to log out.\n" + errorString)
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), Caret+"You are no longer logged in. You can log in using 'coder login <url>'.\n") _, _ = fmt.Fprintf(inv.Stdout, Caret+"You are no longer logged in. You can log in using 'coder login <url>'.\n")
return nil return nil
}, },
} }
cmd.Options = append(cmd.Options, cliui.SkipPromptOption())
cliui.AllowSkipPrompt(cmd)
return cmd return cmd
} }

View File

@ -1,9 +1,7 @@
package cli_test package cli_test
import ( import (
"fmt"
"os" "os"
"regexp"
"runtime" "runtime"
"testing" "testing"
@ -30,12 +28,12 @@ func TestLogout(t *testing.T) {
logoutChan := make(chan struct{}) logoutChan := make(chan struct{})
logout, _ := clitest.New(t, "logout", "--global-config", string(config)) logout, _ := clitest.New(t, "logout", "--global-config", string(config))
logout.SetIn(pty.Input()) logout.Stdin = pty.Input()
logout.SetOut(pty.Output()) logout.Stdout = pty.Output()
go func() { go func() {
defer close(logoutChan) defer close(logoutChan)
err := logout.Execute() err := logout.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.NoFileExists(t, string(config.URL())) assert.NoFileExists(t, string(config.URL()))
assert.NoFileExists(t, string(config.Session())) assert.NoFileExists(t, string(config.Session()))
@ -58,12 +56,12 @@ func TestLogout(t *testing.T) {
logoutChan := make(chan struct{}) logoutChan := make(chan struct{})
logout, _ := clitest.New(t, "logout", "--global-config", string(config), "-y") logout, _ := clitest.New(t, "logout", "--global-config", string(config), "-y")
logout.SetIn(pty.Input()) logout.Stdin = pty.Input()
logout.SetOut(pty.Output()) logout.Stdout = pty.Output()
go func() { go func() {
defer close(logoutChan) defer close(logoutChan)
err := logout.Execute() err := logout.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.NoFileExists(t, string(config.URL())) assert.NoFileExists(t, string(config.URL()))
assert.NoFileExists(t, string(config.Session())) assert.NoFileExists(t, string(config.Session()))
@ -88,13 +86,13 @@ func TestLogout(t *testing.T) {
logoutChan := make(chan struct{}) logoutChan := make(chan struct{})
logout, _ := clitest.New(t, "logout", "--global-config", string(config)) logout, _ := clitest.New(t, "logout", "--global-config", string(config))
logout.SetIn(pty.Input()) logout.Stdin = pty.Input()
logout.SetOut(pty.Output()) logout.Stdout = pty.Output()
go func() { go func() {
defer close(logoutChan) defer close(logoutChan)
err := logout.Execute() err := logout.Run()
assert.EqualError(t, err, "You are not logged in. Try logging in using 'coder login <url>'.") assert.ErrorContains(t, err, "You are not logged in. Try logging in using 'coder login <url>'.")
}() }()
<-logoutChan <-logoutChan
@ -115,13 +113,13 @@ func TestLogout(t *testing.T) {
logoutChan := make(chan struct{}) logoutChan := make(chan struct{})
logout, _ := clitest.New(t, "logout", "--global-config", string(config)) logout, _ := clitest.New(t, "logout", "--global-config", string(config))
logout.SetIn(pty.Input()) logout.Stdin = pty.Input()
logout.SetOut(pty.Output()) logout.Stdout = pty.Output()
go func() { go func() {
defer close(logoutChan) defer close(logoutChan)
err = logout.Execute() err = logout.Run()
assert.EqualError(t, err, "You are not logged in. Try logging in using 'coder login <url>'.") assert.ErrorContains(t, err, "You are not logged in. Try logging in using 'coder login <url>'.")
}() }()
<-logoutChan <-logoutChan
@ -166,29 +164,27 @@ func TestLogout(t *testing.T) {
} }
}() }()
logoutChan := make(chan struct{})
logout, _ := clitest.New(t, "logout", "--global-config", string(config)) logout, _ := clitest.New(t, "logout", "--global-config", string(config))
logout.SetIn(pty.Input()) logout.Stdin = pty.Input()
logout.SetOut(pty.Output()) logout.Stdout = pty.Output()
go func() { go func() {
defer close(logoutChan)
err := logout.Execute()
assert.NotNil(t, err)
var errorMessage string
if runtime.GOOS == "windows" {
errorMessage = "The process cannot access the file because it is being used by another process."
} else {
errorMessage = "permission denied"
}
errRegex := regexp.MustCompile(fmt.Sprintf("Failed to log out.\n\tremove URL file: .+: %s\n\tremove session file: .+: %s", errorMessage, errorMessage))
assert.Regexp(t, errRegex, err.Error())
}()
pty.ExpectMatch("Are you sure you want to log out?") pty.ExpectMatch("Are you sure you want to log out?")
pty.WriteLine("yes") pty.WriteLine("yes")
<-logoutChan }()
err = logout.Run()
require.Error(t, err)
t.Logf("err: %v", err)
var wantError string
if runtime.GOOS == "windows" {
wantError = "The process cannot access the file because it is being used by another process."
} else {
wantError = "permission denied"
}
require.ErrorContains(t, err, wantError)
}) })
} }
@ -200,11 +196,11 @@ func login(t *testing.T, pty *ptytest.PTY) config.Root {
doneChan := make(chan struct{}) doneChan := make(chan struct{})
root, cfg := clitest.New(t, "login", "--force-tty", client.URL.String(), "--no-open") root, cfg := clitest.New(t, "login", "--force-tty", client.URL.String(), "--no-open")
root.SetIn(pty.Input()) root.Stdin = pty.Input()
root.SetOut(pty.Output()) root.Stdout = pty.Output()
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := root.Execute() err := root.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()

View File

@ -5,10 +5,10 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
@ -51,20 +51,20 @@ func createParameterMapFromFile(parameterFile string) (map[string]string, error)
// Returns a parameter value from a given map, if the map does not exist or does not contain the item, it takes input from the user. // Returns a parameter value from a given map, if the map does not exist or does not contain the item, it takes input from the user.
// Throws an error if there are any errors with the users input. // Throws an error if there are any errors with the users input.
func getParameterValueFromMapOrInput(cmd *cobra.Command, parameterMap map[string]string, parameterSchema codersdk.ParameterSchema) (string, error) { func getParameterValueFromMapOrInput(inv *clibase.Invocation, parameterMap map[string]string, parameterSchema codersdk.ParameterSchema) (string, error) {
var parameterValue string var parameterValue string
var err error var err error
if parameterMap != nil { if parameterMap != nil {
var ok bool var ok bool
parameterValue, ok = parameterMap[parameterSchema.Name] parameterValue, ok = parameterMap[parameterSchema.Name]
if !ok { if !ok {
parameterValue, err = cliui.ParameterSchema(cmd, parameterSchema) parameterValue, err = cliui.ParameterSchema(inv, parameterSchema)
if err != nil { if err != nil {
return "", err return "", err
} }
} }
} else { } else {
parameterValue, err = cliui.ParameterSchema(cmd, parameterSchema) parameterValue, err = cliui.ParameterSchema(inv, parameterSchema)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -72,20 +72,20 @@ func getParameterValueFromMapOrInput(cmd *cobra.Command, parameterMap map[string
return parameterValue, nil return parameterValue, nil
} }
func getWorkspaceBuildParameterValueFromMapOrInput(cmd *cobra.Command, parameterMap map[string]string, templateVersionParameter codersdk.TemplateVersionParameter) (*codersdk.WorkspaceBuildParameter, error) { func getWorkspaceBuildParameterValueFromMapOrInput(inv *clibase.Invocation, parameterMap map[string]string, templateVersionParameter codersdk.TemplateVersionParameter) (*codersdk.WorkspaceBuildParameter, error) {
var parameterValue string var parameterValue string
var err error var err error
if parameterMap != nil { if parameterMap != nil {
var ok bool var ok bool
parameterValue, ok = parameterMap[templateVersionParameter.Name] parameterValue, ok = parameterMap[templateVersionParameter.Name]
if !ok { if !ok {
parameterValue, err = cliui.RichParameter(cmd, templateVersionParameter) parameterValue, err = cliui.RichParameter(inv, templateVersionParameter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
} else { } else {
parameterValue, err = cliui.RichParameter(cmd, templateVersionParameter) parameterValue, err = cliui.RichParameter(inv, templateVersionParameter)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,13 +1,13 @@
package cli package cli
import ( import (
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
) )
func parameters() *cobra.Command { func (r *RootCmd) parameters() *clibase.Cmd {
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Short: "List parameters for a given scope", Short: "List parameters for a given scope",
Example: formatExamples( Long: formatExamples(
example{ example{
Command: "coder parameters list workspace my-workspace", Command: "coder parameters list workspace my-workspace",
}, },
@ -20,12 +20,9 @@ func parameters() *cobra.Command {
// constructing curl requests. // constructing curl requests.
Hidden: true, Hidden: true,
Aliases: []string{"params"}, Aliases: []string{"params"},
RunE: func(cmd *cobra.Command, args []string) error { Children: []*clibase.Cmd{
return cmd.Help() r.parameterList(),
}, },
} }
cmd.AddCommand(
parameterList(),
)
return cmd return cmd
} }

View File

@ -4,32 +4,32 @@ import (
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func parameterList() *cobra.Command { func (r *RootCmd) parameterList() *clibase.Cmd {
formatter := cliui.NewOutputFormatter( formatter := cliui.NewOutputFormatter(
cliui.TableFormat([]codersdk.Parameter{}, []string{"name", "scope", "destination scheme"}), cliui.TableFormat([]codersdk.Parameter{}, []string{"name", "scope", "destination scheme"}),
cliui.JSONFormat(), cliui.JSONFormat(),
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "list", Use: "list",
Aliases: []string{"ls"}, Aliases: []string{"ls"},
Args: cobra.ExactArgs(2), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(2),
scope, name := args[0], args[1] r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
scope, name := inv.Args[0], inv.Args[1]
client, err := CreateClient(cmd) organization, err := CurrentOrganization(inv, client)
if err != nil {
return err
}
organization, err := CurrentOrganization(cmd, client)
if err != nil { if err != nil {
return xerrors.Errorf("get current organization: %w", err) return xerrors.Errorf("get current organization: %w", err)
} }
@ -37,13 +37,13 @@ func parameterList() *cobra.Command {
var scopeID uuid.UUID var scopeID uuid.UUID
switch codersdk.ParameterScope(scope) { switch codersdk.ParameterScope(scope) {
case codersdk.ParameterWorkspace: case codersdk.ParameterWorkspace:
workspace, err := namedWorkspace(cmd, client, name) workspace, err := namedWorkspace(inv.Context(), client, name)
if err != nil { if err != nil {
return err return err
} }
scopeID = workspace.ID scopeID = workspace.ID
case codersdk.ParameterTemplate: case codersdk.ParameterTemplate:
template, err := client.TemplateByName(cmd.Context(), organization.ID, name) template, err := client.TemplateByName(inv.Context(), organization.ID, name)
if err != nil { if err != nil {
return xerrors.Errorf("get workspace template: %w", err) return xerrors.Errorf("get workspace template: %w", err)
} }
@ -57,7 +57,7 @@ func parameterList() *cobra.Command {
// Could be a template_version id or a job id. Check for the // Could be a template_version id or a job id. Check for the
// version id. // version id.
tv, err := client.TemplateVersion(cmd.Context(), scopeID) tv, err := client.TemplateVersion(inv.Context(), scopeID)
if err == nil { if err == nil {
scopeID = tv.Job.ID scopeID = tv.Job.ID
} }
@ -68,21 +68,21 @@ func parameterList() *cobra.Command {
}) })
} }
params, err := client.Parameters(cmd.Context(), codersdk.ParameterScope(scope), scopeID) params, err := client.Parameters(inv.Context(), codersdk.ParameterScope(scope), scopeID)
if err != nil { if err != nil {
return xerrors.Errorf("fetch params: %w", err) return xerrors.Errorf("fetch params: %w", err)
} }
out, err := formatter.Format(cmd.Context(), params) out, err := formatter.Format(inv.Context(), params)
if err != nil { if err != nil {
return xerrors.Errorf("render output: %w", err) return xerrors.Errorf("render output: %w", err)
} }
_, err = fmt.Fprintln(cmd.OutOrStdout(), out) _, err = fmt.Fprintln(inv.Stdout, out)
return err return err
}, },
} }
formatter.AttachFlags(cmd) formatter.AttachOptions(&cmd.Options)
return cmd return cmd
} }

View File

@ -5,46 +5,48 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman" "cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func ping() *cobra.Command { func (r *RootCmd) ping() *clibase.Cmd {
var ( var (
pingNum int pingNum int64
pingTimeout time.Duration pingTimeout time.Duration
pingWait time.Duration pingWait time.Duration
verbose bool
) )
cmd := &cobra.Command{
client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "ping <workspace>", Use: "ping <workspace>",
Short: "Ping a workspace", Short: "Ping a workspace",
Args: cobra.ExactArgs(1), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(1),
ctx, cancel := context.WithCancel(cmd.Context()) r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel() defer cancel()
client, err := CreateClient(cmd) workspaceName := inv.Args[0]
if err != nil { _, workspaceAgent, err := getWorkspaceAndAgent(
return err ctx, inv, client,
} codersdk.Me, workspaceName,
)
workspaceName := args[0]
_, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, workspaceName, false)
if err != nil { if err != nil {
return err return err
} }
var logger slog.Logger var logger slog.Logger
if verbose { if r.verbose {
logger = slog.Make(sloghuman.Sink(cmd.OutOrStdout())).Leveled(slog.LevelDebug) logger = slog.Make(sloghuman.Sink(inv.Stdout)).Leveled(slog.LevelDebug)
} }
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{Logger: logger}) conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{Logger: logger})
@ -70,8 +72,8 @@ func ping() *cobra.Command {
cancel() cancel()
if err != nil { if err != nil {
if xerrors.Is(err, context.DeadlineExceeded) { if xerrors.Is(err, context.DeadlineExceeded) {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "ping to %q timed out \n", workspaceName) _, _ = fmt.Fprintf(inv.Stdout, "ping to %q timed out \n", workspaceName)
if n == pingNum { if n == int(pingNum) {
return nil return nil
} }
continue continue
@ -84,8 +86,8 @@ func ping() *cobra.Command {
continue continue
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "ping to %q failed %s\n", workspaceName, err.Error()) _, _ = fmt.Fprintf(inv.Stdout, "ping to %q failed %s\n", workspaceName, err.Error())
if n == pingNum { if n == int(pingNum) {
return nil return nil
} }
continue continue
@ -95,7 +97,7 @@ func ping() *cobra.Command {
var via string var via string
if p2p { if p2p {
if !didP2p { if !didP2p {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "p2p connection established in", _, _ = fmt.Fprintln(inv.Stdout, "p2p connection established in",
cliui.Styles.DateTimeStamp.Render(time.Since(start).Round(time.Millisecond).String()), cliui.Styles.DateTimeStamp.Render(time.Since(start).Round(time.Millisecond).String()),
) )
} }
@ -117,22 +119,40 @@ func ping() *cobra.Command {
) )
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "pong from %s %s in %s\n", _, _ = fmt.Fprintf(inv.Stdout, "pong from %s %s in %s\n",
cliui.Styles.Keyword.Render(workspaceName), cliui.Styles.Keyword.Render(workspaceName),
via, via,
cliui.Styles.DateTimeStamp.Render(dur.String()), cliui.Styles.DateTimeStamp.Render(dur.String()),
) )
if n == pingNum { if n == int(pingNum) {
return nil return nil
} }
} }
}, },
} }
cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enables verbose logging.") cmd.Options = clibase.OptionSet{
cmd.Flags().DurationVarP(&pingWait, "wait", "", time.Second, "Specifies how long to wait between pings.") {
cmd.Flags().DurationVarP(&pingTimeout, "timeout", "t", 5*time.Second, "Specifies how long to wait for a ping to complete.") Flag: "wait",
cmd.Flags().IntVarP(&pingNum, "num", "n", 10, "Specifies the number of pings to perform.") Description: "Specifies how long to wait between pings.",
Default: "1s",
Value: clibase.DurationOf(&pingWait),
},
{
Flag: "timeout",
FlagShorthand: "t",
Default: "5s",
Description: "Specifies how long to wait for a ping to complete.",
Value: clibase.DurationOf(&pingTimeout),
},
{
Flag: "num",
FlagShorthand: "n",
Default: "10",
Description: "Specifies the number of pings to perform.",
Value: clibase.Int64Of(&pingNum),
},
}
return cmd return cmd
} }

View File

@ -22,12 +22,12 @@ func TestPing(t *testing.T) {
t.Parallel() t.Parallel()
client, workspace, agentToken := setupWorkspaceForAgent(t, nil) client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
cmd, root := clitest.New(t, "ping", workspace.Name) inv, root := clitest.New(t, "ping", workspace.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
agentClient := agentsdk.New(client.URL) agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(agentToken) agentClient.SetSessionToken(agentToken)
@ -43,7 +43,7 @@ func TestPing(t *testing.T) {
defer cancel() defer cancel()
cmdDone := tGo(t, func() { cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.NoError(t, err) assert.NoError(t, err)
}) })

View File

@ -12,26 +12,25 @@ import (
"syscall" "syscall"
"github.com/pion/udp" "github.com/pion/udp"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/agent" "github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func portForward() *cobra.Command { func (r *RootCmd) portForward() *clibase.Cmd {
var ( var (
tcpForwards []string // <port>:<port> tcpForwards []string // <port>:<port>
udpForwards []string // <port>:<port> udpForwards []string // <port>:<port>
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "port-forward <workspace>", Use: "port-forward <workspace>",
Short: "Forward ports from machine to a workspace", Short: "Forward ports from machine to a workspace",
Aliases: []string{"tunnel"}, Aliases: []string{"tunnel"},
Args: cobra.ExactArgs(1), Long: formatExamples(
Example: formatExamples(
example{ example{
Description: "Port forward a single TCP port from 1234 in the workspace to port 5678 on your local machine", Description: "Port forward a single TCP port from 1234 in the workspace to port 5678 on your local machine",
Command: "coder port-forward <workspace> --tcp 5678:1234", Command: "coder port-forward <workspace> --tcp 5678:1234",
@ -49,8 +48,12 @@ func portForward() *cobra.Command {
Command: "coder port-forward <workspace> --tcp 8080,9000:3000,9090-9092,10000-10002:10010-10012", Command: "coder port-forward <workspace> --tcp 8080,9000:3000,9090-9092,10000-10002:10010-10012",
}, },
), ),
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.Chain(
ctx, cancel := context.WithCancel(cmd.Context()) clibase.RequireNArgs(1),
r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel() defer cancel()
specs, err := parsePortForwards(tcpForwards, udpForwards) specs, err := parsePortForwards(tcpForwards, udpForwards)
@ -58,19 +61,14 @@ func portForward() *cobra.Command {
return xerrors.Errorf("parse port-forward specs: %w", err) return xerrors.Errorf("parse port-forward specs: %w", err)
} }
if len(specs) == 0 { if len(specs) == 0 {
err = cmd.Help() err = inv.Command.HelpHandler(inv)
if err != nil { if err != nil {
return xerrors.Errorf("generate help output: %w", err) return xerrors.Errorf("generate help output: %w", err)
} }
return xerrors.New("no port-forwards requested") return xerrors.New("no port-forwards requested")
} }
client, err := CreateClient(cmd) workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, codersdk.Me, inv.Args[0])
if err != nil {
return err
}
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
if err != nil { if err != nil {
return err return err
} }
@ -78,13 +76,13 @@ func portForward() *cobra.Command {
return xerrors.New("workspace must be in start transition to port-forward") return xerrors.New("workspace must be in start transition to port-forward")
} }
if workspace.LatestBuild.Job.CompletedAt == nil { if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID) err = cliui.WorkspaceBuild(ctx, inv.Stderr, client, workspace.LatestBuild.ID)
if err != nil { if err != nil {
return err return err
} }
} }
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID) return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@ -116,7 +114,7 @@ func portForward() *cobra.Command {
defer closeAllListeners() defer closeAllListeners()
for i, spec := range specs { for i, spec := range specs {
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec) l, err := listenAndPortForward(ctx, inv, conn, wg, spec)
if err != nil { if err != nil {
return err return err
} }
@ -137,7 +135,7 @@ func portForward() *cobra.Command {
case <-ctx.Done(): case <-ctx.Done():
closeErr = ctx.Err() closeErr = ctx.Err()
case <-sigs: case <-sigs:
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "\nReceived signal, closing all listeners and active connections") _, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections")
} }
cancel() cancel()
@ -145,19 +143,33 @@ func portForward() *cobra.Command {
}() }()
conn.AwaitReachable(ctx) conn.AwaitReachable(ctx)
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!") _, _ = fmt.Fprintln(inv.Stderr, "Ready!")
wg.Wait() wg.Wait()
return closeErr return closeErr
}, },
} }
cliflag.StringArrayVarP(cmd.Flags(), &tcpForwards, "tcp", "p", "CODER_PORT_FORWARD_TCP", nil, "Forward TCP port(s) from the workspace to the local machine") cmd.Options = clibase.OptionSet{
cliflag.StringArrayVarP(cmd.Flags(), &udpForwards, "udp", "", "CODER_PORT_FORWARD_UDP", nil, "Forward UDP port(s) from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols") {
Flag: "tcp",
FlagShorthand: "p",
Env: "CODER_PORT_FORWARD_TCP",
Description: "Forward TCP port(s) from the workspace to the local machine.",
Value: clibase.StringArrayOf(&tcpForwards),
},
{
Flag: "udp",
Env: "CODER_PORT_FORWARD_UDP",
Description: "Forward UDP port(s) from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols.",
Value: clibase.StringArrayOf(&udpForwards),
},
}
return cmd return cmd
} }
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *codersdk.WorkspaceAgentConn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { func listenAndPortForward(ctx context.Context, inv *clibase.Invocation, conn *codersdk.WorkspaceAgentConn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
var ( var (
l net.Listener l net.Listener
@ -200,8 +212,8 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *codersd
if xerrors.Is(err, net.ErrClosed) { if xerrors.Is(err, net.ErrClosed) {
return return
} }
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err) _, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") _, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
return return
} }
@ -209,7 +221,7 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *codersd
defer netConn.Close() defer netConn.Close()
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress) remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
if err != nil { if err != nil {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) _, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
return return
} }
defer remoteConn.Close() defer remoteConn.Close()

View File

@ -31,14 +31,12 @@ func TestPortForward(t *testing.T) {
client := coderdtest.New(t, nil) client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client) _ = coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.New(t, "port-forward", "blah") inv, root := clitest.New(t, "port-forward", "blah")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input()) inv.Stderr = pty.Output()
cmd.SetOut(pty.Output())
cmd.SetErr(pty.Output())
err := cmd.Execute() err := inv.Run()
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "no port-forwards") require.ErrorContains(t, err, "no port-forwards")
@ -133,17 +131,17 @@ func TestPortForward(t *testing.T) {
// Launch port-forward in a goroutine so we can start dialing // Launch port-forward in a goroutine so we can start dialing
// the "local" listener. // the "local" listener.
cmd, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag) inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
errC := make(chan error) errC := make(chan error)
go func() { go func() {
errC <- cmd.ExecuteContext(ctx) errC <- inv.WithContext(ctx).Run()
}() }()
pty.ExpectMatch("Ready!") pty.ExpectMatch("Ready!")
@ -181,17 +179,17 @@ func TestPortForward(t *testing.T) {
// Launch port-forward in a goroutine so we can start dialing // Launch port-forward in a goroutine so we can start dialing
// the "local" listeners. // the "local" listeners.
cmd, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag1, flag2) inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag1, flag2)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
errC := make(chan error) errC := make(chan error)
go func() { go func() {
errC <- cmd.ExecuteContext(ctx) errC <- inv.WithContext(ctx).Run()
}() }()
pty.ExpectMatch("Ready!") pty.ExpectMatch("Ready!")
@ -238,17 +236,15 @@ func TestPortForward(t *testing.T) {
// Launch port-forward in a goroutine so we can start dialing // Launch port-forward in a goroutine so we can start dialing
// the "local" listeners. // the "local" listeners.
cmd, root := clitest.New(t, append([]string{"-v", "port-forward", workspace.Name}, flags...)...) inv, root := clitest.New(t, append([]string{"-v", "port-forward", workspace.Name}, flags...)...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input()) inv.Stderr = pty.Output()
cmd.SetOut(pty.Output())
cmd.SetErr(pty.Output())
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
errC := make(chan error) errC := make(chan error)
go func() { go func() {
errC <- cmd.ExecuteContext(ctx) errC <- inv.WithContext(ctx).Run()
}() }()
pty.ExpectMatch("Ready!") pty.ExpectMatch("Ready!")
@ -304,12 +300,12 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) codersdk.
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
// Start workspace agent in a goroutine // Start workspace agent in a goroutine
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) inv, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
errC := make(chan error) errC := make(chan error)
agentCtx, agentCancel := context.WithCancel(ctx) agentCtx, agentCancel := context.WithCancel(ctx)
t.Cleanup(func() { t.Cleanup(func() {
@ -318,7 +314,7 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) codersdk.
require.NoError(t, err) require.NoError(t, err)
}) })
go func() { go func() {
errC <- cmd.ExecuteContext(agentCtx) errC <- inv.WithContext(agentCtx).Run()
}() }()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)

View File

@ -3,30 +3,26 @@ package cli
import ( import (
"strings" "strings"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func publickey() *cobra.Command { func (r *RootCmd) publickey() *clibase.Cmd {
var reset bool var reset bool
client := new(codersdk.Client)
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "publickey", Use: "publickey",
Aliases: []string{"pubkey"}, Aliases: []string{"pubkey"},
Short: "Output your Coder public key used for Git operations", Short: "Output your Coder public key used for Git operations",
RunE: func(cmd *cobra.Command, args []string) error { Middleware: r.InitClient(client),
client, err := CreateClient(cmd) Handler: func(inv *clibase.Invocation) error {
if err != nil {
return xerrors.Errorf("create codersdk client: %w", err)
}
if reset { if reset {
// Confirm prompt if using --reset. We don't want to accidentally // Confirm prompt if using --reset. We don't want to accidentally
// reset our public key. // reset our public key.
_, err := cliui.Prompt(cmd, cliui.PromptOptions{ _, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm regenerate a new sshkey for your workspaces? This will require updating the key " + Text: "Confirm regenerate a new sshkey for your workspaces? This will require updating the key " +
"on any services it is registered with. This action cannot be reverted.", "on any services it is registered with. This action cannot be reverted.",
IsConfirm: true, IsConfirm: true,
@ -36,33 +32,38 @@ func publickey() *cobra.Command {
} }
// Reset the public key, let the retrieve re-read it. // Reset the public key, let the retrieve re-read it.
_, err = client.RegenerateGitSSHKey(cmd.Context(), codersdk.Me) _, err = client.RegenerateGitSSHKey(inv.Context(), codersdk.Me)
if err != nil { if err != nil {
return err return err
} }
} }
key, err := client.GitSSHKey(cmd.Context(), codersdk.Me) key, err := client.GitSSHKey(inv.Context(), codersdk.Me)
if err != nil { if err != nil {
return xerrors.Errorf("create codersdk client: %w", err) return xerrors.Errorf("create codersdk client: %w", err)
} }
cmd.Println(cliui.Styles.Wrap.Render( cliui.Infof(inv.Stdout,
"This is your public key for using "+cliui.Styles.Field.Render("git")+" in "+ "This is your public key for using "+cliui.Styles.Field.Render("git")+" in "+
"Coder. All clones with SSH will be authenticated automatically 🪄.", "Coder. All clones with SSH will be authenticated automatically 🪄.\n\n",
)) )
cmd.Println() cliui.Infof(inv.Stdout, cliui.Styles.Code.Render(strings.TrimSpace(key.PublicKey))+"\n\n")
cmd.Println(cliui.Styles.Code.Render(strings.TrimSpace(key.PublicKey))) cliui.Infof(inv.Stdout, "Add to GitHub and GitLab:"+"\n")
cmd.Println() cliui.Infof(inv.Stdout, cliui.Styles.Prompt.String()+"https://github.com/settings/ssh/new"+"\n")
cmd.Println("Add to GitHub and GitLab:") cliui.Infof(inv.Stdout, cliui.Styles.Prompt.String()+"https://gitlab.com/-/profile/keys"+"\n")
cmd.Println(cliui.Styles.Prompt.String() + "https://github.com/settings/ssh/new")
cmd.Println(cliui.Styles.Prompt.String() + "https://gitlab.com/-/profile/keys")
return nil return nil
}, },
} }
cmd.Flags().BoolVar(&reset, "reset", false, "Regenerate your public key. This will require updating the key on any services it's registered with.")
cliui.AllowSkipPrompt(cmd) cmd.Options = clibase.OptionSet{
{
Flag: "reset",
Description: "Regenerate your public key. This will require updating the key on any services it's registered with.",
Value: clibase.BoolOf(&reset),
},
cliui.SkipPromptOption(),
}
return cmd return cmd
} }

View File

@ -16,11 +16,11 @@ func TestPublicKey(t *testing.T) {
t.Parallel() t.Parallel()
client := coderdtest.New(t, nil) client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client) _ = coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.New(t, "publickey") inv, root := clitest.New(t, "publickey")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
cmd.SetOut(buf) inv.Stdout = buf
err := cmd.Execute() err := inv.Run()
require.NoError(t, err) require.NoError(t, err)
publicKey := buf.String() publicKey := buf.String()
require.NotEmpty(t, publicKey) require.NotEmpty(t, publicKey)

View File

@ -3,34 +3,34 @@ package cli
import ( import (
"fmt" "fmt"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func rename() *cobra.Command { func (r *RootCmd) rename() *clibase.Cmd {
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "rename <workspace> <new name>", Use: "rename <workspace> <new name>",
Short: "Rename a workspace", Short: "Rename a workspace",
Args: cobra.ExactArgs(2), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(2),
client, err := CreateClient(cmd) r.InitClient(client),
if err != nil { ),
return err Handler: func(inv *clibase.Invocation) error {
} workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
workspace, err := namedWorkspace(cmd, client, args[0])
if err != nil { if err != nil {
return xerrors.Errorf("get workspace: %w", err) return xerrors.Errorf("get workspace: %w", err)
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n\n", _, _ = fmt.Fprintf(inv.Stdout, "%s\n\n",
cliui.Styles.Wrap.Render("WARNING: A rename can result in data loss if a resource references the workspace name in the template (e.g volumes). Please backup any data before proceeding."), cliui.Styles.Wrap.Render("WARNING: A rename can result in data loss if a resource references the workspace name in the template (e.g volumes). Please backup any data before proceeding."),
) )
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "See: %s\n\n", "https://coder.com/docs/coder-oss/latest/templates/resource-persistence#%EF%B8%8F-persistence-pitfalls") _, _ = fmt.Fprintf(inv.Stdout, "See: %s\n\n", "https://coder.com/docs/coder-oss/latest/templates/resource-persistence#%EF%B8%8F-persistence-pitfalls")
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Type %q to confirm rename:", workspace.Name), Text: fmt.Sprintf("Type %q to confirm rename:", workspace.Name),
Validate: func(s string) error { Validate: func(s string) error {
if s == workspace.Name { if s == workspace.Name {
@ -43,17 +43,18 @@ func rename() *cobra.Command {
return err return err
} }
err = client.UpdateWorkspace(cmd.Context(), workspace.ID, codersdk.UpdateWorkspaceRequest{ err = client.UpdateWorkspace(inv.Context(), workspace.ID, codersdk.UpdateWorkspaceRequest{
Name: args[1], Name: inv.Args[1],
}) })
if err != nil { if err != nil {
return xerrors.Errorf("rename workspace: %w", err) return xerrors.Errorf("rename workspace: %w", err)
} }
_, _ = fmt.Fprintf(inv.Stdout, "Workspace %q renamed to %q\n", workspace.Name, inv.Args[1])
return nil return nil
}, },
} }
cliui.AllowSkipPrompt(cmd) cmd.Options = append(cmd.Options, cliui.SkipPromptOption())
return cmd return cmd
} }

View File

@ -5,7 +5,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clitest" "github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/coderdtest"
@ -30,21 +29,15 @@ func TestRename(t *testing.T) {
// Only append one letter because it's easy to exceed maximum length: // Only append one letter because it's easy to exceed maximum length:
// E.g. "compassionate-chandrasekhar82" + "t". // E.g. "compassionate-chandrasekhar82" + "t".
want := workspace.Name + "t" want := workspace.Name + "t"
cmd, root := clitest.New(t, "rename", workspace.Name, want, "--yes") inv, root := clitest.New(t, "rename", workspace.Name, want, "--yes")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) pty.Attach(inv)
cmd.SetOut(pty.Output()) clitest.Start(t, inv)
errC := make(chan error, 1)
go func() {
errC <- cmd.ExecuteContext(ctx)
}()
pty.ExpectMatch("confirm rename:") pty.ExpectMatch("confirm rename:")
pty.WriteLine(workspace.Name) pty.WriteLine(workspace.Name)
pty.ExpectMatch("renamed to")
require.NoError(t, <-errC)
ws, err := client.Workspace(ctx, workspace.ID) ws, err := client.Workspace(ctx, workspace.ID)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -4,25 +4,24 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/migrations" "github.com/coder/coder/coderd/database/migrations"
"github.com/coder/coder/coderd/userpassword" "github.com/coder/coder/coderd/userpassword"
) )
func resetPassword() *cobra.Command { func (*RootCmd) resetPassword() *clibase.Cmd {
var postgresURL string var postgresURL string
root := &cobra.Command{ root := &clibase.Cmd{
Use: "reset-password <username>", Use: "reset-password <username>",
Short: "Directly connect to the database to reset a user's password", Short: "Directly connect to the database to reset a user's password",
Args: cobra.ExactArgs(1), Middleware: clibase.RequireNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
username := args[0] username := inv.Args[0]
sqlDB, err := sql.Open("postgres", postgresURL) sqlDB, err := sql.Open("postgres", postgresURL)
if err != nil { if err != nil {
@ -40,14 +39,14 @@ func resetPassword() *cobra.Command {
} }
db := database.New(sqlDB) db := database.New(sqlDB)
user, err := db.GetUserByEmailOrUsername(cmd.Context(), database.GetUserByEmailOrUsernameParams{ user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{
Username: username, Username: username,
}) })
if err != nil { if err != nil {
return xerrors.Errorf("retrieving user: %w", err) return xerrors.Errorf("retrieving user: %w", err)
} }
password, err := cliui.Prompt(cmd, cliui.PromptOptions{ password, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Enter new " + cliui.Styles.Field.Render("password") + ":", Text: "Enter new " + cliui.Styles.Field.Render("password") + ":",
Secret: true, Secret: true,
Validate: func(s string) error { Validate: func(s string) error {
@ -57,7 +56,7 @@ func resetPassword() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("password prompt: %w", err) return xerrors.Errorf("password prompt: %w", err)
} }
confirmedPassword, err := cliui.Prompt(cmd, cliui.PromptOptions{ confirmedPassword, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm " + cliui.Styles.Field.Render("password") + ":", Text: "Confirm " + cliui.Styles.Field.Render("password") + ":",
Secret: true, Secret: true,
Validate: cliui.ValidateNotEmpty, Validate: cliui.ValidateNotEmpty,
@ -74,7 +73,7 @@ func resetPassword() *cobra.Command {
return xerrors.Errorf("hash password: %w", err) return xerrors.Errorf("hash password: %w", err)
} }
err = db.UpdateUserHashedPassword(cmd.Context(), database.UpdateUserHashedPasswordParams{ err = db.UpdateUserHashedPassword(inv.Context(), database.UpdateUserHashedPasswordParams{
ID: user.ID, ID: user.ID,
HashedPassword: []byte(hashedPassword), HashedPassword: []byte(hashedPassword),
}) })
@ -82,12 +81,19 @@ func resetPassword() *cobra.Command {
return xerrors.Errorf("updating password: %w", err) return xerrors.Errorf("updating password: %w", err)
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nPassword has been reset for user %s!\n", cliui.Styles.Keyword.Render(user.Username)) _, _ = fmt.Fprintf(inv.Stdout, "\nPassword has been reset for user %s!\n", cliui.Styles.Keyword.Render(user.Username))
return nil return nil
}, },
} }
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to") root.Options = clibase.OptionSet{
{
Flag: "postgres-url",
Description: "URL of a PostgreSQL database to connect to.",
Env: "CODER_PG_CONNECTION_URL",
Value: clibase.StringOf(&postgresURL),
},
}
return root return root
} }

View File

@ -37,7 +37,7 @@ func TestResetPassword(t *testing.T) {
defer closeFunc() defer closeFunc()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
serverDone := make(chan struct{}) serverDone := make(chan struct{})
serverCmd, cfg := clitest.New(t, serverinv, cfg := clitest.New(t,
"server", "server",
"--http-address", ":0", "--http-address", ":0",
"--access-url", "http://example.com", "--access-url", "http://example.com",
@ -46,7 +46,7 @@ func TestResetPassword(t *testing.T) {
) )
go func() { go func() {
defer close(serverDone) defer close(serverDone)
err = serverCmd.ExecuteContext(ctx) err = serverinv.WithContext(ctx).Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
var rawURL string var rawURL string
@ -67,15 +67,15 @@ func TestResetPassword(t *testing.T) {
// reset the password // reset the password
resetCmd, cmdCfg := clitest.New(t, "reset-password", "--postgres-url", connectionURL, username) resetinv, cmdCfg := clitest.New(t, "reset-password", "--postgres-url", connectionURL, username)
clitest.SetupConfig(t, client, cmdCfg) clitest.SetupConfig(t, client, cmdCfg)
cmdDone := make(chan struct{}) cmdDone := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t)
resetCmd.SetIn(pty.Input()) resetinv.Stdin = pty.Input()
resetCmd.SetOut(pty.Output()) resetinv.Stdout = pty.Output()
go func() { go func() {
defer close(cmdDone) defer close(cmdDone)
err = resetCmd.Execute() err = resetinv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()

View File

@ -4,23 +4,29 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func restart() *cobra.Command { func (r *RootCmd) restart() *clibase.Cmd {
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "restart <workspace>", Use: "restart <workspace>",
Short: "Restart a workspace", Short: "Restart a workspace",
Args: cobra.ExactArgs(1), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(1),
ctx := cmd.Context() r.InitClient(client),
out := cmd.OutOrStdout() ),
Options: clibase.OptionSet{
cliui.SkipPromptOption(),
},
Handler: func(inv *clibase.Invocation) error {
ctx := inv.Context()
out := inv.Stdout
_, err := cliui.Prompt(cmd, cliui.PromptOptions{ _, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm restart workspace?", Text: "Confirm restart workspace?",
IsConfirm: true, IsConfirm: true,
}) })
@ -28,11 +34,7 @@ func restart() *cobra.Command {
return err return err
} }
client, err := CreateClient(cmd) workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil {
return err
}
workspace, err := namedWorkspace(cmd, client, args[0])
if err != nil { if err != nil {
return err return err
} }
@ -63,6 +65,5 @@ func restart() *cobra.Command {
return nil return nil
}, },
} }
cliui.AllowSkipPrompt(cmd)
return cmd return cmd
} }

View File

@ -25,18 +25,16 @@ func TestRestart(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
cmd, root := clitest.New(t, "restart", workspace.Name, "--yes") inv, root := clitest.New(t, "restart", workspace.Name, "--yes")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
done <- cmd.ExecuteContext(ctx) done <- inv.WithContext(ctx).Run()
}() }()
pty.ExpectMatch("Stopping workspace") pty.ExpectMatch("Stopping workspace")
pty.ExpectMatch("Starting workspace") pty.ExpectMatch("Starting workspace")

View File

@ -1,32 +1,37 @@
package cli package cli
import ( import (
"bufio"
"context" "context"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
"text/template"
"time" "time"
"unicode/utf8"
"golang.org/x/crypto/ssh/terminal"
"golang.org/x/exp/slices"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/mattn/go-isatty" "github.com/mattn/go-isatty"
"github.com/spf13/cobra"
"github.com/coder/coder/buildinfo" "github.com/coder/coder/buildinfo"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/cli/config" "github.com/coder/coder/cli/config"
"github.com/coder/coder/coderd" "github.com/coder/coder/coderd"
@ -66,84 +71,82 @@ const (
var errUnauthenticated = xerrors.New(notLoggedInMessage) var errUnauthenticated = xerrors.New(notLoggedInMessage)
func init() { func (r *RootCmd) Core() []*clibase.Cmd {
// Set cobra template functions in init to avoid conflicts in tests.
cobra.AddTemplateFuncs(templateFunctions)
}
func Core() []*cobra.Command {
// Please re-sort this list alphabetically if you change it! // Please re-sort this list alphabetically if you change it!
return []*cobra.Command{ return []*clibase.Cmd{
configSSH(), r.dotfiles(),
create(), r.login(),
deleteWorkspace(), r.logout(),
dotfiles(), r.portForward(),
gitssh(), r.publickey(),
list(), r.resetPassword(),
login(), r.state(),
logout(), r.templates(),
parameters(), r.users(),
ping(), r.tokens(),
portForward(), r.version(),
publickey(),
rename(), // Workspace Commands
resetPassword(), r.configSSH(),
restart(), r.rename(),
scaletest(), r.ping(),
schedules(), r.create(),
show(), r.deleteWorkspace(),
speedtest(), r.list(),
ssh(), r.schedules(),
start(), r.show(),
state(), r.speedtest(),
stop(), r.ssh(),
templates(), r.start(),
tokens(), r.stop(),
update(), r.update(),
users(), r.restart(),
versionCmd(), r.parameters(),
vscodeSSH(),
workspaceAgent(), // Hidden
r.workspaceAgent(),
r.scaletest(),
r.gitssh(),
r.vscodeSSH(),
} }
} }
func AGPL() []*cobra.Command { func (r *RootCmd) AGPL() []*clibase.Cmd {
all := append(Core(), Server(func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) { all := append(r.Core(), r.Server(func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) {
api := coderd.New(o) api := coderd.New(o)
return api, api, nil return api, api, nil
})) }))
return all return all
} }
func Root(subcommands []*cobra.Command) *cobra.Command { // Main is the entrypoint for the Coder CLI.
// The GIT_ASKPASS environment variable must point at func (r *RootCmd) RunMain(subcommands []*clibase.Cmd) {
// a binary with no arguments. To prevent writing rand.Seed(time.Now().UnixMicro())
// cross-platform scripts to invoke the Coder binary
// with a `gitaskpass` subcommand, we override the entrypoint
// to check if the command was invoked.
isGitAskpass := false
cmd, err := r.Command(subcommands)
if err != nil {
panic(err)
}
err = cmd.Invoke().WithOS().Run()
if err != nil {
if errors.Is(err, cliui.Canceled) {
//nolint:revive
os.Exit(1)
}
f := prettyErrorFormatter{w: os.Stderr}
f.format(err)
//nolint:revive
os.Exit(1)
}
}
func (r *RootCmd) Command(subcommands []*clibase.Cmd) (*clibase.Cmd, error) {
fmtLong := `Coder %s A tool for provisioning self-hosted development environments with Terraform. fmtLong := `Coder %s A tool for provisioning self-hosted development environments with Terraform.
` `
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "coder", Use: "coder [global-flags] <subcommand>",
SilenceErrors: true, Long: fmt.Sprintf(fmtLong, buildinfo.Version()) + formatExamples(
SilenceUsage: true,
Long: fmt.Sprintf(fmtLong, buildinfo.Version()),
Args: func(cmd *cobra.Command, args []string) error {
if gitauth.CheckCommand(args, os.Environ()) {
isGitAskpass = true
return nil
}
return cobra.NoArgs(cmd, args)
},
RunE: func(cmd *cobra.Command, args []string) error {
if isGitAskpass {
return gitAskpass().RunE(cmd, args)
}
return cmd.Help()
},
Example: formatExamples(
example{ example{
Description: "Start a Coder server", Description: "Start a Coder server",
Command: "coder server", Command: "coder server",
@ -153,30 +156,204 @@ func Root(subcommands []*cobra.Command) *cobra.Command {
Command: "coder templates init", Command: "coder templates init",
}, },
), ),
Handler: func(i *clibase.Invocation) error {
// fmt.Fprintf(i.Stderr, "env debug: %+v", i.Environ)
// The GIT_ASKPASS environment variable must point at
// a binary with no arguments. To prevent writing
// cross-platform scripts to invoke the Coder binary
// with a `gitaskpass` subcommand, we override the entrypoint
// to check if the command was invoked.
if gitauth.CheckCommand(i.Args, i.Environ.ToOS()) {
return r.gitAskpass().Handler(i)
}
return i.Command.HelpHandler(i)
},
} }
cmd.AddCommand(subcommands...) cmd.AddSubcommands(subcommands...)
fixUnknownSubcommandError(cmd.Commands())
cmd.SetUsageTemplate(usageTemplateCobra()) // Set default help handler for all commands.
cmd.Walk(func(c *clibase.Cmd) {
if c.HelpHandler == nil {
c.HelpHandler = helpFn()
}
})
cliflag.String(cmd.PersistentFlags(), varURL, "", envURL, "", "URL to a deployment.") var merr error
cliflag.Bool(cmd.PersistentFlags(), varNoVersionCheck, "", envNoVersionCheck, false, "Suppress warning when client and server versions do not match.") // Add [flags] to usage when appropriate.
cliflag.Bool(cmd.PersistentFlags(), varNoFeatureWarning, "", envNoFeatureWarning, false, "Suppress warnings about unlicensed features.") cmd.Walk(func(cmd *clibase.Cmd) {
cliflag.String(cmd.PersistentFlags(), varToken, "", envSessionToken, "", fmt.Sprintf("Specify an authentication token. For security reasons setting %s is preferred.", envSessionToken)) const flags = "[flags]"
cliflag.String(cmd.PersistentFlags(), varAgentToken, "", "CODER_AGENT_TOKEN", "", "An agent authentication token.") if strings.Contains(cmd.Use, flags) {
_ = cmd.PersistentFlags().MarkHidden(varAgentToken) merr = errors.Join(
cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "URL for an agent to access your deployment.") merr,
_ = cmd.PersistentFlags().MarkHidden(varAgentURL) xerrors.Errorf(
cliflag.String(cmd.PersistentFlags(), config.FlagName, "", "CODER_CONFIG_DIR", config.DefaultDir(), "Path to the global `coder` config directory.") "command %q shouldn't have %q in usage since it's automatically populated",
cliflag.StringArray(cmd.PersistentFlags(), varHeader, "", "CODER_HEADER", []string{}, "HTTP headers added to all requests. Provide as \"Key=Value\"") cmd.FullUsage(),
cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY.") flags,
_ = cmd.PersistentFlags().MarkHidden(varForceTty) ),
cmd.PersistentFlags().Bool(varNoOpen, false, "Block automatically opening URLs in the browser.") )
_ = cmd.PersistentFlags().MarkHidden(varNoOpen) return
cliflag.Bool(cmd.PersistentFlags(), varVerbose, "v", "CODER_VERBOSE", false, "Enable verbose output.") }
return cmd var hasFlag bool
for _, opt := range cmd.Options {
if opt.Flag != "" {
hasFlag = true
break
}
}
if !hasFlag {
return
}
// We insert [flags] between the command's name and its arguments.
tokens := strings.SplitN(cmd.Use, " ", 2)
if len(tokens) == 1 {
cmd.Use = fmt.Sprintf("%s %s", tokens[0], flags)
return
}
cmd.Use = fmt.Sprintf("%s %s %s", tokens[0], flags, tokens[1])
})
// Add alises when appropriate.
cmd.Walk(func(cmd *clibase.Cmd) {
// TODO: we should really be consistent about naming.
if cmd.Name() == "delete" || cmd.Name() == "remove" {
if slices.Contains(cmd.Aliases, "rm") {
merr = errors.Join(
merr,
xerrors.Errorf("command %q shouldn't have alias %q since it's added automatically", cmd.FullName(), "rm"),
)
return
}
cmd.Aliases = append(cmd.Aliases, "rm")
}
})
// Sanity-check command options.
cmd.Walk(func(cmd *clibase.Cmd) {
for _, opt := range cmd.Options {
// Verify that every option is configurable.
if opt.Flag == "" && opt.Env == "" {
if cmd.Name() == "server" {
// The server command is funky and has YAML-only options, e.g.
// support links.
return
}
merr = errors.Join(
merr,
xerrors.Errorf("option %q in %q should have a flag or env", opt.Name, cmd.FullName()),
)
}
}
})
if merr != nil {
return nil, merr
}
if r.agentURL == nil {
r.agentURL = new(url.URL)
}
if r.clientURL == nil {
r.clientURL = new(url.URL)
}
globalGroup := &clibase.Group{
Name: "Global",
Description: `Global options are applied to all commands. They can be set using environment variables or flags.`,
}
cmd.Options = clibase.OptionSet{
{
Flag: varURL,
Env: envURL,
Description: "URL to a deployment.",
Value: clibase.URLOf(r.clientURL),
Group: globalGroup,
},
{
Flag: varToken,
Env: envSessionToken,
Description: fmt.Sprintf("Specify an authentication token. For security reasons setting %s is preferred.", envSessionToken),
Value: clibase.StringOf(&r.token),
Group: globalGroup,
},
{
Flag: varAgentToken,
Description: "An agent authentication token.",
Value: clibase.StringOf(&r.agentToken),
Hidden: true,
Group: globalGroup,
},
{
Flag: varAgentURL,
Env: "CODER_AGENT_URL",
Description: "URL for an agent to access your deployment.",
Value: clibase.URLOf(r.agentURL),
Hidden: true,
Group: globalGroup,
},
{
Flag: varNoVersionCheck,
Env: envNoVersionCheck,
Description: "Suppress warning when client and server versions do not match.",
Value: clibase.BoolOf(&r.noVersionCheck),
Group: globalGroup,
},
{
Flag: varNoFeatureWarning,
Env: envNoFeatureWarning,
Description: "Suppress warnings about unlicensed features.",
Value: clibase.BoolOf(&r.noFeatureWarning),
Group: globalGroup,
},
{
Flag: varHeader,
Env: "CODER_HEADER",
Description: "Additional HTTP headers added to all requests. Provide as " + `key=value` + ". Can be specified multiple times.",
Value: clibase.StringArrayOf(&r.header),
Group: globalGroup,
},
{
Flag: varNoOpen,
Env: "CODER_NO_OPEN",
Description: "Suppress opening the browser after logging in.",
Value: clibase.BoolOf(&r.noOpen),
Hidden: true,
Group: globalGroup,
},
{
Flag: varForceTty,
Env: "CODER_FORCE_TTY",
Hidden: true,
Description: "Force the use of a TTY.",
Value: clibase.BoolOf(&r.forceTTY),
Group: globalGroup,
},
{
Flag: varVerbose,
FlagShorthand: "v",
Env: "CODER_VERBOSE",
Description: "Enable verbose output.",
Value: clibase.BoolOf(&r.verbose),
Group: globalGroup,
},
{
Flag: config.FlagName,
Env: "CODER_CONFIG_DIR",
Description: "Path to the global `coder` config directory.",
Default: config.DefaultDir(),
Value: clibase.StringOf(&r.globalConfig),
Group: globalGroup,
},
}
err := cmd.PrepareAll()
if err != nil {
return nil, err
}
return cmd, nil
} }
type contextKey int type contextKey int
@ -194,41 +371,12 @@ func LoggerFromContext(ctx context.Context) (slog.Logger, bool) {
return l, ok return l, ok
} }
// fixUnknownSubcommandError modifies the provided commands so that the // version prints the coder version
// ones with subcommands output the correct error message when an func (*RootCmd) version() *clibase.Cmd {
// unknown subcommand is invoked. return &clibase.Cmd{
//
// Example:
//
// unknown command "bad" for "coder templates"
func fixUnknownSubcommandError(commands []*cobra.Command) {
for _, sc := range commands {
if sc.HasSubCommands() {
if sc.Run == nil && sc.RunE == nil {
if sc.Args != nil {
// In case the developer does not know about this
// behavior in Cobra they must verify correct
// behavior. For instance, settings Args to
// `cobra.ExactArgs(0)` will not give the same
// message as `cobra.NoArgs`. Likewise, omitting the
// run function will not give the wanted error.
panic("developer error: subcommand has subcommands and Args but no Run or RunE")
}
sc.Args = cobra.NoArgs
sc.Run = func(*cobra.Command, []string) {}
}
fixUnknownSubcommandError(sc.Commands())
}
}
}
// versionCmd prints the coder version
func versionCmd() *cobra.Command {
return &cobra.Command{
Use: "version", Use: "version",
Short: "Show coder version", Short: "Show coder version",
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
var str strings.Builder var str strings.Builder
_, _ = str.WriteString("Coder ") _, _ = str.WriteString("Coder ")
if buildinfo.IsAGPL() { if buildinfo.IsAGPL() {
@ -247,7 +395,7 @@ func versionCmd() *cobra.Command {
_, _ = str.WriteString(fmt.Sprintf("Full build of Coder, supports the %s subcommand.\n", cliui.Styles.Code.Render("server"))) _, _ = str.WriteString(fmt.Sprintf("Full build of Coder, supports the %s subcommand.\n", cliui.Styles.Code.Render("server")))
} }
_, _ = fmt.Fprint(cmd.OutOrStdout(), str.String()) _, _ = fmt.Fprint(inv.Stdout, str.String())
return nil return nil
}, },
} }
@ -257,41 +405,68 @@ func isTest() bool {
return flag.Lookup("test.v") != nil return flag.Lookup("test.v") != nil
} }
// CreateClient returns a new client from the command context. // RootCmd contains parameters and helpers useful to all commands.
type RootCmd struct {
clientURL *url.URL
token string
globalConfig string
header []string
agentToken string
agentURL *url.URL
forceTTY bool
noOpen bool
verbose bool
noVersionCheck bool
noFeatureWarning bool
}
// InitClient sets client to a new client.
// It reads from global configuration files if flags are not set. // It reads from global configuration files if flags are not set.
func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) { func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc {
root := createConfig(cmd) if client == nil {
rawURL, err := cmd.Flags().GetString(varURL) panic("client is nil")
if err != nil || rawURL == "" { }
rawURL, err = root.URL().Read() if r == nil {
if err != nil { panic("root is nil")
}
return func(next clibase.HandlerFunc) clibase.HandlerFunc {
return func(i *clibase.Invocation) error {
conf := r.createConfig()
var err error
if r.clientURL == nil || r.clientURL.String() == "" {
rawURL, err := conf.URL().Read()
// If the configuration files are absent, the user is logged out // If the configuration files are absent, the user is logged out
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil, errUnauthenticated return (errUnauthenticated)
} }
return nil, err
}
}
serverURL, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil { if err != nil {
return nil, err return err
} }
token, err := cmd.Flags().GetString(varToken)
if err != nil || token == "" { r.clientURL, err = url.Parse(strings.TrimSpace(rawURL))
token, err = root.Session().Read()
if err != nil { if err != nil {
return err
}
}
if r.token == "" {
r.token, err = conf.Session().Read()
// If the configuration files are absent, the user is logged out // If the configuration files are absent, the user is logged out
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil, errUnauthenticated return (errUnauthenticated)
} }
return nil, err
}
}
client, err := createUnauthenticatedClient(cmd, serverURL)
if err != nil { if err != nil {
return nil, err return err
} }
client.SetSessionToken(token) }
err = r.setClient(client, r.clientURL)
if err != nil {
return err
}
client.SetSessionToken(r.token)
// We send these requests in parallel to minimize latency. // We send these requests in parallel to minimize latency.
var ( var (
@ -299,77 +474,71 @@ func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) {
warningErr = make(chan error) warningErr = make(chan error)
) )
go func() { go func() {
versionErr <- checkVersions(cmd, client) versionErr <- r.checkVersions(i, client)
close(versionErr) close(versionErr)
}() }()
go func() { go func() {
warningErr <- checkWarnings(cmd, client) warningErr <- r.checkWarnings(i, client)
close(warningErr) close(warningErr)
}() }()
if err = <-versionErr; err != nil { if err = <-versionErr; err != nil {
// Just log the error here. We never want to fail a command // Just log the error here. We never want to fail a command
// due to a pre-run. // due to a pre-run.
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), _, _ = fmt.Fprintf(i.Stderr,
cliui.Styles.Warn.Render("check versions error: %s"), err) cliui.Styles.Warn.Render("check versions error: %s"), err)
_, _ = fmt.Fprintln(cmd.ErrOrStderr()) _, _ = fmt.Fprintln(i.Stderr)
} }
if err = <-warningErr; err != nil { if err = <-warningErr; err != nil {
// Same as above // Same as above
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), _, _ = fmt.Fprintf(i.Stderr,
cliui.Styles.Warn.Render("check entitlement warnings error: %s"), err) cliui.Styles.Warn.Render("check entitlement warnings error: %s"), err)
_, _ = fmt.Fprintln(cmd.ErrOrStderr()) _, _ = fmt.Fprintln(i.Stderr)
} }
return client, nil return next(i)
}
}
} }
func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*codersdk.Client, error) { func (r *RootCmd) setClient(client *codersdk.Client, serverURL *url.URL) error {
client := codersdk.New(serverURL)
headers, err := cmd.Flags().GetStringArray(varHeader)
if err != nil {
return nil, err
}
transport := &headerTransport{ transport := &headerTransport{
transport: http.DefaultTransport, transport: http.DefaultTransport,
header: http.Header{}, header: http.Header{},
} }
for _, header := range headers { for _, header := range r.header {
parts := strings.SplitN(header, "=", 2) parts := strings.SplitN(header, "=", 2)
if len(parts) < 2 { if len(parts) < 2 {
return nil, xerrors.Errorf("split header %q had less than two parts", header) return xerrors.Errorf("split header %q had less than two parts", header)
} }
transport.header.Add(parts[0], parts[1]) transport.header.Add(parts[0], parts[1])
} }
client.HTTPClient.Transport = transport client.URL = serverURL
return client, nil client.HTTPClient = &http.Client{
Transport: transport,
}
return nil
}
func (r *RootCmd) createUnauthenticatedClient(serverURL *url.URL) (*codersdk.Client, error) {
var client codersdk.Client
err := r.setClient(&client, serverURL)
return &client, err
} }
// createAgentClient returns a new client from the command context. // createAgentClient returns a new client from the command context.
// It works just like CreateClient, but uses the agent token and URL instead. // It works just like CreateClient, but uses the agent token and URL instead.
func createAgentClient(cmd *cobra.Command) (*agentsdk.Client, error) { func (r *RootCmd) createAgentClient() (*agentsdk.Client, error) {
rawURL, err := cmd.Flags().GetString(varAgentURL) client := agentsdk.New(r.agentURL)
if err != nil { client.SetSessionToken(r.agentToken)
return nil, err
}
serverURL, err := url.Parse(rawURL)
if err != nil {
return nil, err
}
token, err := cmd.Flags().GetString(varAgentToken)
if err != nil {
return nil, err
}
client := agentsdk.New(serverURL)
client.SetSessionToken(token)
return client, nil return client, nil
} }
// CurrentOrganization returns the currently active organization for the authenticated user. // CurrentOrganization returns the currently active organization for the authenticated user.
func CurrentOrganization(cmd *cobra.Command, client *codersdk.Client) (codersdk.Organization, error) { func CurrentOrganization(inv *clibase.Invocation, client *codersdk.Client) (codersdk.Organization, error) {
orgs, err := client.OrganizationsByUser(cmd.Context(), codersdk.Me) orgs, err := client.OrganizationsByUser(inv.Context(), codersdk.Me)
if err != nil { if err != nil {
return codersdk.Organization{}, nil return codersdk.Organization{}, nil
} }
@ -381,7 +550,7 @@ func CurrentOrganization(cmd *cobra.Command, client *codersdk.Client) (codersdk.
// namedWorkspace fetches and returns a workspace by an identifier, which may be either // namedWorkspace fetches and returns a workspace by an identifier, which may be either
// a bare name (for a workspace owned by the current user) or a "user/workspace" combination, // a bare name (for a workspace owned by the current user) or a "user/workspace" combination,
// where user is either a username or UUID. // where user is either a username or UUID.
func namedWorkspace(cmd *cobra.Command, client *codersdk.Client, identifier string) (codersdk.Workspace, error) { func namedWorkspace(ctx context.Context, client *codersdk.Client, identifier string) (codersdk.Workspace, error) {
parts := strings.Split(identifier, "/") parts := strings.Split(identifier, "/")
var owner, name string var owner, name string
@ -396,30 +565,24 @@ func namedWorkspace(cmd *cobra.Command, client *codersdk.Client, identifier stri
return codersdk.Workspace{}, xerrors.Errorf("invalid workspace name: %q", identifier) return codersdk.Workspace{}, xerrors.Errorf("invalid workspace name: %q", identifier)
} }
return client.WorkspaceByOwnerAndName(cmd.Context(), owner, name, codersdk.WorkspaceOptions{}) return client.WorkspaceByOwnerAndName(ctx, owner, name, codersdk.WorkspaceOptions{})
} }
// createConfig consumes the global configuration flag to produce a config root. // createConfig consumes the global configuration flag to produce a config root.
func createConfig(cmd *cobra.Command) config.Root { func (r *RootCmd) createConfig() config.Root {
globalRoot, err := cmd.Flags().GetString(config.FlagName) return config.Root(r.globalConfig)
if err != nil {
panic(err)
}
return config.Root(globalRoot)
} }
// isTTY returns whether the passed reader is a TTY or not. // isTTY returns whether the passed reader is a TTY or not.
// This accepts a reader to work with Cobra's "InOrStdin" func isTTY(inv *clibase.Invocation) bool {
// function for simple testing.
func isTTY(cmd *cobra.Command) bool {
// If the `--force-tty` command is available, and set, // If the `--force-tty` command is available, and set,
// assume we're in a tty. This is primarily for cases on Windows // assume we're in a tty. This is primarily for cases on Windows
// where we may not be able to reliably detect this automatically (ie, tests) // where we may not be able to reliably detect this automatically (ie, tests)
forceTty, err := cmd.Flags().GetBool(varForceTty) forceTty, err := inv.ParsedFlags().GetBool(varForceTty)
if forceTty && err == nil { if forceTty && err == nil {
return true return true
} }
file, ok := cmd.InOrStdin().(*os.File) file, ok := inv.Stdin.(*os.File)
if !ok { if !ok {
return false return false
} }
@ -427,125 +590,30 @@ func isTTY(cmd *cobra.Command) bool {
} }
// isTTYOut returns whether the passed reader is a TTY or not. // isTTYOut returns whether the passed reader is a TTY or not.
// This accepts a reader to work with Cobra's "OutOrStdout" func isTTYOut(inv *clibase.Invocation) bool {
// function for simple testing. return isTTYWriter(inv, inv.Stdout)
func isTTYOut(cmd *cobra.Command) bool {
return isTTYWriter(cmd, cmd.OutOrStdout)
} }
// isTTYErr returns whether the passed reader is a TTY or not. // isTTYErr returns whether the passed reader is a TTY or not.
// This accepts a reader to work with Cobra's "ErrOrStderr" func isTTYErr(inv *clibase.Invocation) bool {
// function for simple testing. return isTTYWriter(inv, inv.Stderr)
func isTTYErr(cmd *cobra.Command) bool {
return isTTYWriter(cmd, cmd.ErrOrStderr)
} }
func isTTYWriter(cmd *cobra.Command, writer func() io.Writer) bool { func isTTYWriter(inv *clibase.Invocation, writer io.Writer) bool {
// If the `--force-tty` command is available, and set, // If the `--force-tty` command is available, and set,
// assume we're in a tty. This is primarily for cases on Windows // assume we're in a tty. This is primarily for cases on Windows
// where we may not be able to reliably detect this automatically (ie, tests) // where we may not be able to reliably detect this automatically (ie, tests)
forceTty, err := cmd.Flags().GetBool(varForceTty) forceTty, err := inv.ParsedFlags().GetBool(varForceTty)
if forceTty && err == nil { if forceTty && err == nil {
return true return true
} }
file, ok := writer().(*os.File) file, ok := writer.(*os.File)
if !ok { if !ok {
return false return false
} }
return isatty.IsTerminal(file.Fd()) return isatty.IsTerminal(file.Fd())
} }
var templateFunctions = template.FuncMap{
"usageHeader": usageHeader,
"isWorkspaceCommand": isWorkspaceCommand,
}
func usageHeader(s string) string {
// Customizes the color of headings to make subcommands more visually
// appealing.
return cliui.Styles.Placeholder.Render(s)
}
func isWorkspaceCommand(cmd *cobra.Command) bool {
if _, ok := cmd.Annotations["workspaces"]; ok {
return true
}
var ws bool
cmd.VisitParents(func(cmd *cobra.Command) {
if _, ok := cmd.Annotations["workspaces"]; ok {
ws = true
}
})
return ws
}
// We will eventually replace this with the clibase template describedc
// in usage.go. We don't want to continue working around
// Cobra's feature-set.
func usageTemplateCobra() string {
// usageHeader is defined in init().
return `{{usageHeader "Usage:"}}
{{- if .Runnable}}
{{.UseLine}}
{{end}}
{{- if .HasAvailableSubCommands}}
{{.CommandPath}} [command]
{{end}}
{{- if gt (len .Aliases) 0}}
{{usageHeader "Aliases:"}}
{{.NameAndAliases}}
{{end}}
{{- if .HasExample}}
{{usageHeader "Get Started:"}}
{{.Example}}
{{end}}
{{- $isRootHelp := (not .HasParent)}}
{{- if .HasAvailableSubCommands}}
{{usageHeader "Commands:"}}
{{- range .Commands}}
{{- $isRootWorkspaceCommand := (and $isRootHelp (isWorkspaceCommand .))}}
{{- if (or (and .IsAvailableCommand (not $isRootWorkspaceCommand)) (eq .Name "help"))}}
{{rpad .Name .NamePadding }} {{.Short}}
{{- end}}
{{- end}}
{{end}}
{{- if (and $isRootHelp .HasAvailableSubCommands)}}
{{usageHeader "Workspace Commands:"}}
{{- range .Commands}}
{{- if (and .IsAvailableCommand (isWorkspaceCommand .))}}
{{rpad .Name .NamePadding }} {{.Short}}
{{- end}}
{{- end}}
{{end}}
{{- if .HasAvailableLocalFlags}}
{{usageHeader "Flags:"}}
{{.LocalFlags.FlagUsagesWrapped 100 | trimTrailingWhitespaces}}
{{end}}
{{- if .HasAvailableInheritedFlags}}
{{usageHeader "Global Flags:"}}
{{.InheritedFlags.FlagUsagesWrapped 100 | trimTrailingWhitespaces}}
{{end}}
{{- if .HasHelpSubCommands}}
{{usageHeader "Additional help topics:"}}
{{- range .Commands}}
{{- if .IsAdditionalHelpTopicCommand}}
{{rpad .CommandPath .CommandPathPadding}} {{.Short}}
{{- end}}
{{- end}}
{{end}}
{{- if .HasAvailableSubCommands}}
Use "{{.CommandPath}} [command] --help" for more information about a command.
{{end}}`
}
// example represents a standard example for command usage, to be used // example represents a standard example for command usage, to be used
// with formatExamples. // with formatExamples.
type example struct { type example struct {
@ -574,36 +642,12 @@ func formatExamples(examples ...example) string {
return sb.String() return sb.String()
} }
// FormatCobraError colorizes and adds "--help" docs to cobra commands. func (r *RootCmd) checkVersions(i *clibase.Invocation, client *codersdk.Client) error {
func FormatCobraError(err error, cmd *cobra.Command) string { if r.noVersionCheck {
helpErrMsg := fmt.Sprintf("Run '%s --help' for usage.", cmd.CommandPath())
var (
httpErr *codersdk.Error
output strings.Builder
)
if xerrors.As(err, &httpErr) {
_, _ = fmt.Fprintln(&output, httpErr.Friendly())
}
// If the httpErr is nil then we just have a regular error in which
// case we want to print out what's happening.
if httpErr == nil || cliflag.IsSetBool(cmd, varVerbose) {
_, _ = fmt.Fprintln(&output, err.Error())
}
_, _ = fmt.Fprint(&output, helpErrMsg)
return cliui.Styles.Error.Render(output.String())
}
func checkVersions(cmd *cobra.Command, client *codersdk.Client) error {
if cliflag.IsSetBool(cmd, varNoVersionCheck) {
return nil return nil
} }
ctx, cancel := context.WithTimeout(cmd.Context(), 10*time.Second) ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second)
defer cancel() defer cancel()
clientVersion := buildinfo.Version() clientVersion := buildinfo.Version()
@ -629,25 +673,25 @@ func checkVersions(cmd *cobra.Command, client *codersdk.Client) error {
if !buildinfo.VersionsMatch(clientVersion, info.Version) { if !buildinfo.VersionsMatch(clientVersion, info.Version) {
warn := cliui.Styles.Warn.Copy().Align(lipgloss.Left) warn := cliui.Styles.Warn.Copy().Align(lipgloss.Left)
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), warn.Render(fmtWarningText), clientVersion, info.Version, strings.TrimPrefix(info.CanonicalVersion(), "v")) _, _ = fmt.Fprintf(i.Stderr, warn.Render(fmtWarningText), clientVersion, info.Version, strings.TrimPrefix(info.CanonicalVersion(), "v"))
_, _ = fmt.Fprintln(cmd.ErrOrStderr()) _, _ = fmt.Fprintln(i.Stderr)
} }
return nil return nil
} }
func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error { func (r *RootCmd) checkWarnings(i *clibase.Invocation, client *codersdk.Client) error {
if cliflag.IsSetBool(cmd, varNoFeatureWarning) { if r.noFeatureWarning {
return nil return nil
} }
ctx, cancel := context.WithTimeout(cmd.Context(), 10*time.Second) ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second)
defer cancel() defer cancel()
entitlements, err := client.Entitlements(ctx) entitlements, err := client.Entitlements(ctx)
if err == nil { if err == nil {
for _, w := range entitlements.Warnings { for _, w := range entitlements.Warnings {
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w)) _, _ = fmt.Fprintln(i.Stderr, cliui.Styles.Warn.Render(w))
} }
} }
return nil return nil
@ -773,3 +817,94 @@ func isConnectionError(err error) bool {
return xerrors.As(err, &dnsErr) || xerrors.As(err, &opErr) return xerrors.As(err, &dnsErr) || xerrors.As(err, &opErr)
} }
type prettyErrorFormatter struct {
level int
w io.Writer
}
func (prettyErrorFormatter) prefixLines(spaces int, s string) string {
twidth, _, err := terminal.GetSize(0)
if err != nil {
twidth = 80
}
s = lipgloss.NewStyle().Width(twidth - spaces).Render(s)
var b strings.Builder
scanner := bufio.NewScanner(strings.NewReader(s))
for i := 0; scanner.Scan(); i++ {
// The first line is already padded.
if i == 0 {
_, _ = fmt.Fprintf(&b, "%s\n", scanner.Text())
continue
}
_, _ = fmt.Fprintf(&b, "%s%s\n", strings.Repeat(" ", spaces), scanner.Text())
}
return strings.TrimSuffix(strings.TrimSuffix(b.String(), "\n"), " ")
}
func (p *prettyErrorFormatter) format(err error) {
underErr := errors.Unwrap(err)
arrowStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("#515151"))
//nolint:errorlint
if _, ok := err.(*clibase.RunCommandError); ok && p.level == 0 && underErr != nil {
// We can do a better job now.
p.format(underErr)
return
}
var (
padding string
arrowWidth int
)
if p.level > 0 {
const arrow = "┗━ "
arrowWidth = utf8.RuneCount([]byte(arrow))
padding = strings.Repeat(" ", arrowWidth*p.level)
_, _ = fmt.Fprintf(p.w, "%v%v", padding, arrowStyle.Render(arrow))
}
if underErr != nil {
header := strings.TrimSuffix(err.Error(), ": "+underErr.Error())
_, _ = fmt.Fprintf(p.w, "%s\n", p.prefixLines(len(padding)+arrowWidth, header))
p.level++
p.format(underErr)
return
}
{
style := lipgloss.NewStyle().Foreground(lipgloss.Color("#D16644")).Background(lipgloss.Color("#000000")).Bold(false)
// This is the last error in a tree.
p.wrappedPrintf(
"%s\n",
p.prefixLines(
len(padding)+arrowWidth,
fmt.Sprintf(
"%s%s%s",
lipgloss.NewStyle().Inherit(style).Underline(true).Render("ERROR"),
lipgloss.NewStyle().Inherit(style).Foreground(arrowStyle.GetForeground()).Render(" ► "),
style.Render(err.Error()),
),
),
)
}
}
func (p *prettyErrorFormatter) wrappedPrintf(format string, a ...interface{}) {
s := lipgloss.NewStyle().Width(ttyWidth()).Render(
fmt.Sprintf(format, a...),
)
// Not sure why, but lipgloss is adding extra spaces we need to remove.
excessSpaceRe := regexp.MustCompile(`[[:blank:]]*\n[[:blank:]]*$`)
s = excessSpaceRe.ReplaceAllString(s, "\n")
_, _ = p.w.Write(
[]byte(
s,
),
)
}

View File

@ -24,11 +24,11 @@ func Test_formatExamples(t *testing.T) {
name: "Output examples", name: "Output examples",
examples: []example{ examples: []example{
{ {
Description: "Hello world", Description: "Hello world.",
Command: "echo hello", Command: "echo hello",
}, },
{ {
Description: "Bye bye", Description: "Bye bye.",
Command: "echo bye", Command: "echo bye",
}, },
}, },
@ -73,5 +73,7 @@ func TestMain(m *testing.M) {
// https://github.com/natefinch/lumberjack/pull/100 // https://github.com/natefinch/lumberjack/pull/100
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"), goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"), goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"),
// The pq library appears to leave around a goroutine after Close().
goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"),
) )
} }

View File

@ -10,18 +10,17 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"runtime"
"strings" "strings"
"testing" "testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/buildinfo" "github.com/coder/coder/buildinfo"
"github.com/coder/coder/cli" "github.com/coder/coder/cli"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/clitest" "github.com/coder/coder/cli/clitest"
"github.com/coder/coder/cli/config"
"github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database/dbtestutil" "github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
@ -34,39 +33,26 @@ var updateGoldenFiles = flag.Bool("update", false, "update .golden files")
var timestampRegex = regexp.MustCompile(`(?i)\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(.\d+)?Z`) var timestampRegex = regexp.MustCompile(`(?i)\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(.\d+)?Z`)
//nolint:tparallel,paralleltest // These test sets env vars.
func TestCommandHelp(t *testing.T) { func TestCommandHelp(t *testing.T) {
commonEnv := map[string]string{ t.Parallel()
"HOME": "~",
"CODER_CONFIG_DIR": "~/.config/coderv2",
}
rootClient, replacements := prepareTestData(t) rootClient, replacements := prepareTestData(t)
type testCase struct { type testCase struct {
name string name string
cmd []string cmd []string
env map[string]string
} }
tests := []testCase{ tests := []testCase{
{ {
name: "coder --help", name: "coder --help",
cmd: []string{"--help"}, cmd: []string{"--help"},
}, },
// Re-enable after clibase migrations. {
// { name: "coder server --help",
// name: "coder server --help", cmd: []string{"server", "--help"},
// cmd: []string{"server", "--help"}, },
// env: map[string]string{
// "CODER_CACHE_DIRECTORY": "~/.cache/coder",
// },
// },
{ {
name: "coder agent --help", name: "coder agent --help",
cmd: []string{"agent", "--help"}, cmd: []string{"agent", "--help"},
env: map[string]string{
"CODER_AGENT_LOG_DIR": "/tmp",
},
}, },
{ {
name: "coder list --output json", name: "coder list --output json",
@ -78,9 +64,12 @@ func TestCommandHelp(t *testing.T) {
}, },
} }
root := cli.Root(cli.AGPL()) rootCmd := new(cli.RootCmd)
root, err := rootCmd.Command(rootCmd.AGPL())
require.NoError(t, err)
ExtractCommandPathsLoop: ExtractCommandPathsLoop:
for _, cp := range extractVisibleCommandPaths(nil, root.Commands()) { for _, cp := range extractVisibleCommandPaths(nil, root.Children) {
name := fmt.Sprintf("coder %s --help", strings.Join(cp, " ")) name := fmt.Sprintf("coder %s --help", strings.Join(cp, " "))
cmd := append(cp, "--help") cmd := append(cp, "--help")
for _, tt := range tests { for _, tt := range tests {
@ -91,100 +80,88 @@ ExtractCommandPathsLoop:
tests = append(tests, testCase{name: name, cmd: cmd}) tests = append(tests, testCase{name: name, cmd: cmd})
} }
wd, err := os.Getwd()
require.NoError(t, err)
if runtime.GOOS == "windows" {
wd = strings.ReplaceAll(wd, "\\", "\\\\")
}
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
env := make(map[string]string) t.Parallel()
for k, v := range commonEnv { ctx := testutil.Context(t, testutil.WaitLong)
env[k] = v
}
for k, v := range tt.env {
env[k] = v
}
// Unset all CODER_ environment variables for a clean slate. var outBuf bytes.Buffer
for _, kv := range os.Environ() { inv, cfg := clitest.New(t, tt.cmd...)
name := strings.Split(kv, "=")[0] inv.Stderr = &outBuf
if _, ok := env[name]; !ok && strings.HasPrefix(name, "CODER_") { inv.Stdout = &outBuf
t.Setenv(name, "") inv.Environ.Set("CODER_URL", rootClient.URL.String())
} inv.Environ.Set("CODER_SESSION_TOKEN", rootClient.SessionToken())
} inv.Environ.Set("CODER_CACHE_DIRECTORY", "~/.cache")
// Override environment variables for a reproducible test.
for k, v := range env {
t.Setenv(k, v)
}
ctx, _ := testutil.Context(t)
tmpwd := "/"
if runtime.GOOS == "windows" {
tmpwd = "C:\\"
}
err := os.Chdir(tmpwd)
var buf bytes.Buffer
cmd, cfg := clitest.New(t, tt.cmd...)
clitest.SetupConfig(t, rootClient, cfg) clitest.SetupConfig(t, rootClient, cfg)
cmd.SetOut(&buf)
assert.NoError(t, err)
err = cmd.ExecuteContext(ctx)
err2 := os.Chdir(wd)
require.NoError(t, err)
require.NoError(t, err2)
got := buf.Bytes() clitest.StartWithWaiter(t, inv.WithContext(ctx)).RequireSuccess()
replace := map[string][]byte{ actual := outBuf.Bytes()
// Remove CRLF newlines (Windows). if len(actual) == 0 {
string([]byte{'\r', '\n'}): []byte("\n"), t.Fatal("no output")
// The `coder templates create --help` command prints the path
// to the working directory (--directory flag default value).
fmt.Sprintf("%q", tmpwd): []byte("\"[current directory]\""),
} }
for k, v := range replacements { for k, v := range replacements {
replace[k] = []byte(v) actual = bytes.ReplaceAll(actual, []byte(k), []byte(v))
}
for k, v := range replace {
got = bytes.ReplaceAll(got, []byte(k), v)
} }
// Replace any timestamps with a placeholder. // Replace any timestamps with a placeholder.
got = timestampRegex.ReplaceAll(got, []byte("[timestamp]")) actual = timestampRegex.ReplaceAll(actual, []byte("[timestamp]"))
gf := filepath.Join("testdata", strings.Replace(tt.name, " ", "_", -1)+".golden") homeDir, err := os.UserHomeDir()
require.NoError(t, err)
configDir := config.DefaultDir()
actual = bytes.ReplaceAll(actual, []byte(configDir), []byte("~/.config/coderv2"))
actual = bytes.ReplaceAll(actual, []byte(codersdk.DefaultCacheDir()), []byte("[cache dir]"))
// The home directory changes depending on the test environment.
actual = bytes.ReplaceAll(actual, []byte(homeDir), []byte("~"))
goldenPath := filepath.Join("testdata", strings.Replace(tt.name, " ", "_", -1)+".golden")
if *updateGoldenFiles { if *updateGoldenFiles {
t.Logf("update golden file for: %q: %s", tt.name, gf) t.Logf("update golden file for: %q: %s", tt.name, goldenPath)
err = os.WriteFile(gf, got, 0o600) err = os.WriteFile(goldenPath, actual, 0o600)
require.NoError(t, err, "update golden file") require.NoError(t, err, "update golden file")
} }
want, err := os.ReadFile(gf) expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "read golden file, run \"make update-golden-files\" and commit the changes") require.NoError(t, err, "read golden file, run \"make update-golden-files\" and commit the changes")
// Remove CRLF newlines (Windows).
want = bytes.ReplaceAll(want, []byte{'\r', '\n'}, []byte{'\n'}) // Normalize files to tolerate different operating systems.
require.Equal(t, string(want), string(got), "golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes", gf) for _, r := range []struct {
old string
new string
}{
{"\r\n", "\n"},
{`~\.cache\coder`, "~/.cache/coder"},
{`C:\Users\RUNNER~1\AppData\Local\Temp`, "/tmp"},
{os.TempDir(), "/tmp"},
} {
expected = bytes.ReplaceAll(expected, []byte(r.old), []byte(r.new))
actual = bytes.ReplaceAll(actual, []byte(r.old), []byte(r.new))
}
require.Equal(
t, string(expected), string(actual),
"golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes",
goldenPath,
)
}) })
} }
} }
func extractVisibleCommandPaths(cmdPath []string, cmds []*cobra.Command) [][]string { func extractVisibleCommandPaths(cmdPath []string, cmds []*clibase.Cmd) [][]string {
var cmdPaths [][]string var cmdPaths [][]string
for _, c := range cmds { for _, c := range cmds {
if c.Hidden { if c.Hidden {
continue continue
} }
// TODO: re-enable after clibase migration.
if c.Name() == "server" {
continue
}
cmdPath := append(cmdPath, c.Name()) cmdPath := append(cmdPath, c.Name())
cmdPaths = append(cmdPaths, cmdPath) cmdPaths = append(cmdPaths, cmdPath)
cmdPaths = append(cmdPaths, extractVisibleCommandPaths(cmdPath, c.Commands())...) cmdPaths = append(cmdPaths, extractVisibleCommandPaths(cmdPath, c.Children)...)
} }
return cmdPaths return cmdPaths
} }
@ -241,113 +218,13 @@ func prepareTestData(t *testing.T) (*codersdk.Client, map[string]string) {
func TestRoot(t *testing.T) { func TestRoot(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("FormatCobraError", func(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
cmd, _ := clitest.New(t, "delete")
cmd, err := cmd.ExecuteC()
errStr := cli.FormatCobraError(err, cmd)
require.Contains(t, errStr, "Run 'coder delete --help' for usage.")
})
t.Run("Verbose", func(t *testing.T) {
t.Parallel()
// Test that the verbose error is masked without verbose flag.
t.Run("NoVerboseAPIError", func(t *testing.T) {
t.Parallel()
cmd, _ := clitest.New(t)
cmd.RunE = func(cmd *cobra.Command, args []string) error {
var err error = &codersdk.Error{
Response: codersdk.Response{
Message: "This is a message.",
},
Helper: "Try this instead.",
}
err = xerrors.Errorf("wrap me: %w", err)
return err
}
cmd, err := cmd.ExecuteC()
errStr := cli.FormatCobraError(err, cmd)
require.Contains(t, errStr, "This is a message. Try this instead.")
require.NotContains(t, errStr, err.Error())
})
// Assert that a regular error is not masked when verbose is not
// specified.
t.Run("NoVerboseRegularError", func(t *testing.T) {
t.Parallel()
cmd, _ := clitest.New(t)
cmd.RunE = func(cmd *cobra.Command, args []string) error {
return xerrors.Errorf("this is a non-codersdk error: %w", xerrors.Errorf("a wrapped error"))
}
cmd, err := cmd.ExecuteC()
errStr := cli.FormatCobraError(err, cmd)
require.Contains(t, errStr, err.Error())
})
// Test that both the friendly error and the verbose error are
// displayed when verbose is passed.
t.Run("APIError", func(t *testing.T) {
t.Parallel()
cmd, _ := clitest.New(t, "--verbose")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
var err error = &codersdk.Error{
Response: codersdk.Response{
Message: "This is a message.",
},
Helper: "Try this instead.",
}
err = xerrors.Errorf("wrap me: %w", err)
return err
}
cmd, err := cmd.ExecuteC()
errStr := cli.FormatCobraError(err, cmd)
require.Contains(t, errStr, "This is a message. Try this instead.")
require.Contains(t, errStr, err.Error())
})
// Assert that a regular error is not masked when verbose specified.
t.Run("RegularError", func(t *testing.T) {
t.Parallel()
cmd, _ := clitest.New(t, "--verbose")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
return xerrors.Errorf("this is a non-codersdk error: %w", xerrors.Errorf("a wrapped error"))
}
cmd, err := cmd.ExecuteC()
errStr := cli.FormatCobraError(err, cmd)
require.Contains(t, errStr, err.Error())
})
})
})
t.Run("Version", func(t *testing.T) { t.Run("Version", func(t *testing.T) {
t.Parallel() t.Parallel()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
cmd, _ := clitest.New(t, "version") inv, _ := clitest.New(t, "version")
cmd.SetOut(buf) inv.Stdout = buf
err := cmd.Execute() err := inv.Run()
require.NoError(t, err) require.NoError(t, err)
output := buf.String() output := buf.String()
@ -370,9 +247,9 @@ func TestRoot(t *testing.T) {
})) }))
defer srv.Close() defer srv.Close()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
cmd, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL) inv, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL)
cmd.SetOut(buf) inv.Stdout = buf
// This won't succeed, because we're using the login cmd to assert requests. // This won't succeed, because we're using the login cmd to assert requests.
_ = cmd.Execute() _ = inv.Run()
}) })
} }

View File

@ -14,11 +14,10 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/tracing" "github.com/coder/coder/coderd/tracing"
@ -33,21 +32,19 @@ import (
const scaletestTracerName = "coder_scaletest" const scaletestTracerName = "coder_scaletest"
func scaletest() *cobra.Command { func (r *RootCmd) scaletest() *clibase.Cmd {
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "scaletest", Use: "scaletest",
Short: "Run a scale test against the Coder API", Short: "Run a scale test against the Coder API",
Long: "Perform scale tests against the Coder server.", Handler: func(inv *clibase.Invocation) error {
RunE: func(cmd *cobra.Command, args []string) error { return inv.Command.HelpHandler(inv)
return cmd.Help() },
Children: []*clibase.Cmd{
r.scaletestCleanup(),
r.scaletestCreateWorkspaces(),
}, },
} }
cmd.AddCommand(
scaletestCleanup(),
scaletestCreateWorkspaces(),
)
return cmd return cmd
} }
@ -58,11 +55,34 @@ type scaletestTracingFlags struct {
tracePropagate bool tracePropagate bool
} }
func (s *scaletestTracingFlags) attach(cmd *cobra.Command) { func (s *scaletestTracingFlags) attach(opts *clibase.OptionSet) {
cliflag.BoolVarP(cmd.Flags(), &s.traceEnable, "trace", "", "CODER_LOADTEST_TRACE", false, "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md") *opts = append(
cliflag.BoolVarP(cmd.Flags(), &s.traceCoder, "trace-coder", "", "CODER_LOADTEST_TRACE_CODER", false, "Whether opentelemetry traces are sent to Coder. We recommend keeping this disabled unless we advise you to enable it.") *opts,
cliflag.StringVarP(cmd.Flags(), &s.traceHoneycombAPIKey, "trace-honeycomb-api-key", "", "CODER_LOADTEST_TRACE_HONEYCOMB_API_KEY", "", "Enables trace exporting to Honeycomb.io using the provided API key.") clibase.Option{
cliflag.BoolVarP(cmd.Flags(), &s.tracePropagate, "trace-propagate", "", "CODER_LOADTEST_TRACE_PROPAGATE", false, "Enables trace propagation to the Coder backend, which will be used to correlate server-side spans with client-side spans. Only enable this if the server is configured with the exact same tracing configuration as the client.") Flag: "trace",
Env: "CODER_SCALETEST_TRACE",
Description: "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md.",
Value: clibase.BoolOf(&s.traceEnable),
},
clibase.Option{
Flag: "trace-coder",
Env: "CODER_SCALETEST_TRACE_CODER",
Description: "Whether opentelemetry traces are sent to Coder. We recommend keeping this disabled unless we advise you to enable it.",
Value: clibase.BoolOf(&s.traceCoder),
},
clibase.Option{
Flag: "trace-honeycomb-api-key",
Env: "CODER_SCALETEST_TRACE_HONEYCOMB_API_KEY",
Description: "Enables trace exporting to Honeycomb.io using the provided API key.",
Value: clibase.StringOf(&s.traceHoneycombAPIKey),
},
clibase.Option{
Flag: "trace-propagate",
Env: "CODER_SCALETEST_TRACE_PROPAGATE",
Description: "Enables trace propagation to the Coder backend, which will be used to correlate server-side spans with client-side spans. Only enable this if the server is configured with the exact same tracing configuration as the client.",
Value: clibase.BoolOf(&s.tracePropagate),
},
)
} }
// provider returns a trace.TracerProvider, a close function and a bool showing // provider returns a trace.TracerProvider, a close function and a bool showing
@ -96,24 +116,45 @@ func (s *scaletestTracingFlags) provider(ctx context.Context) (trace.TracerProvi
type scaletestStrategyFlags struct { type scaletestStrategyFlags struct {
cleanup bool cleanup bool
concurrency int concurrency int64
timeout time.Duration timeout time.Duration
timeoutPerJob time.Duration timeoutPerJob time.Duration
} }
func (s *scaletestStrategyFlags) attach(cmd *cobra.Command) { func (s *scaletestStrategyFlags) attach(opts *clibase.OptionSet) {
concurrencyLong, concurrencyEnv, concurrencyDescription := "concurrency", "CODER_LOADTEST_CONCURRENCY", "Number of concurrent jobs to run. 0 means unlimited." concurrencyLong, concurrencyEnv, concurrencyDescription := "concurrency", "CODER_SCALETEST_CONCURRENCY", "Number of concurrent jobs to run. 0 means unlimited."
timeoutLong, timeoutEnv, timeoutDescription := "timeout", "CODER_LOADTEST_TIMEOUT", "Timeout for the entire test run. 0 means unlimited." timeoutLong, timeoutEnv, timeoutDescription := "timeout", "CODER_SCALETEST_TIMEOUT", "Timeout for the entire test run. 0 means unlimited."
jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription := "job-timeout", "CODER_LOADTEST_JOB_TIMEOUT", "Timeout per job. Jobs may take longer to complete under higher concurrency limits." jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription := "job-timeout", "CODER_SCALETEST_JOB_TIMEOUT", "Timeout per job. Jobs may take longer to complete under higher concurrency limits."
if s.cleanup { if s.cleanup {
concurrencyLong, concurrencyEnv, concurrencyDescription = "cleanup-"+concurrencyLong, "CODER_LOADTEST_CLEANUP_CONCURRENCY", strings.ReplaceAll(concurrencyDescription, "jobs", "cleanup jobs") concurrencyLong, concurrencyEnv, concurrencyDescription = "cleanup-"+concurrencyLong, "CODER_SCALETEST_CLEANUP_CONCURRENCY", strings.ReplaceAll(concurrencyDescription, "jobs", "cleanup jobs")
timeoutLong, timeoutEnv, timeoutDescription = "cleanup-"+timeoutLong, "CODER_LOADTEST_CLEANUP_TIMEOUT", strings.ReplaceAll(timeoutDescription, "test", "cleanup") timeoutLong, timeoutEnv, timeoutDescription = "cleanup-"+timeoutLong, "CODER_SCALETEST_CLEANUP_TIMEOUT", strings.ReplaceAll(timeoutDescription, "test", "cleanup")
jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription = "cleanup-"+jobTimeoutLong, "CODER_LOADTEST_CLEANUP_JOB_TIMEOUT", strings.ReplaceAll(jobTimeoutDescription, "jobs", "cleanup jobs") jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription = "cleanup-"+jobTimeoutLong, "CODER_SCALETEST_CLEANUP_JOB_TIMEOUT", strings.ReplaceAll(jobTimeoutDescription, "jobs", "cleanup jobs")
} }
cliflag.IntVarP(cmd.Flags(), &s.concurrency, concurrencyLong, "", concurrencyEnv, 1, concurrencyDescription) *opts = append(
cliflag.DurationVarP(cmd.Flags(), &s.timeout, timeoutLong, "", timeoutEnv, 30*time.Minute, timeoutDescription) *opts,
cliflag.DurationVarP(cmd.Flags(), &s.timeoutPerJob, jobTimeoutLong, "", jobTimeoutEnv, 5*time.Minute, jobTimeoutDescription) clibase.Option{
Flag: concurrencyLong,
Env: concurrencyEnv,
Description: concurrencyDescription,
Default: "1",
Value: clibase.Int64Of(&s.concurrency),
},
clibase.Option{
Flag: timeoutLong,
Env: timeoutEnv,
Description: timeoutDescription,
Default: "30m",
Value: clibase.DurationOf(&s.timeout),
},
clibase.Option{
Flag: jobTimeoutLong,
Env: jobTimeoutEnv,
Description: jobTimeoutDescription,
Default: "5m",
Value: clibase.DurationOf(&s.timeoutPerJob),
},
)
} }
func (s *scaletestStrategyFlags) toStrategy() harness.ExecutionStrategy { func (s *scaletestStrategyFlags) toStrategy() harness.ExecutionStrategy {
@ -124,7 +165,7 @@ func (s *scaletestStrategyFlags) toStrategy() harness.ExecutionStrategy {
strategy = harness.ConcurrentExecutionStrategy{} strategy = harness.ConcurrentExecutionStrategy{}
} else { } else {
strategy = harness.ParallelExecutionStrategy{ strategy = harness.ParallelExecutionStrategy{
Limit: s.concurrency, Limit: int(s.concurrency),
} }
} }
@ -208,8 +249,14 @@ type scaletestOutputFlags struct {
outputSpecs []string outputSpecs []string
} }
func (s *scaletestOutputFlags) attach(cmd *cobra.Command) { func (s *scaletestOutputFlags) attach(opts *clibase.OptionSet) {
cliflag.StringArrayVarP(cmd.Flags(), &s.outputSpecs, "output", "", "CODER_SCALETEST_OUTPUTS", []string{"text"}, `Output format specs in the format "<format>[:<path>]". Not specifying a path will default to stdout. Available formats: text, json.`) *opts = append(*opts, clibase.Option{
Flag: "output",
Env: "CODER_SCALETEST_OUTPUTS",
Description: `Output format specs in the format "<format>[:<path>]". Not specifying a path will default to stdout. Available formats: text, json.`,
Default: "text",
Value: clibase.StringArrayOf(&s.outputSpecs),
})
} }
func (s *scaletestOutputFlags) parse() ([]scaleTestOutput, error) { func (s *scaletestOutputFlags) parse() ([]scaleTestOutput, error) {
@ -308,21 +355,21 @@ func (r *userCleanupRunner) Run(ctx context.Context, _ string, _ io.Writer) erro
return nil return nil
} }
func scaletestCleanup() *cobra.Command { func (r *RootCmd) scaletestCleanup() *clibase.Cmd {
cleanupStrategy := &scaletestStrategyFlags{cleanup: true} cleanupStrategy := &scaletestStrategyFlags{cleanup: true}
client := new(codersdk.Client)
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "cleanup", Use: "cleanup",
Short: "Cleanup any orphaned scaletest resources", Short: "Cleanup scaletest workspaces, then cleanup scaletest users.",
Long: "Cleanup scaletest workspaces, then cleanup scaletest users. The strategy flags will apply to each stage of the cleanup process.", Long: "The strategy flags will apply to each stage of the cleanup process.",
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.Chain(
ctx := cmd.Context() r.InitClient(client),
client, err := CreateClient(cmd) ),
if err != nil { Handler: func(inv *clibase.Invocation) error {
return err ctx := inv.Context()
}
_, err = requireAdmin(ctx, client) _, err := requireAdmin(ctx, client)
if err != nil { if err != nil {
return err return err
} }
@ -336,7 +383,7 @@ func scaletestCleanup() *cobra.Command {
}, },
} }
cmd.PrintErrln("Fetching scaletest workspaces...") cliui.Infof(inv.Stdout, "Fetching scaletest workspaces...")
var ( var (
pageNumber = 0 pageNumber = 0
limit = 100 limit = 100
@ -366,9 +413,9 @@ func scaletestCleanup() *cobra.Command {
workspaces = append(workspaces, pageWorkspaces...) workspaces = append(workspaces, pageWorkspaces...)
} }
cmd.PrintErrf("Found %d scaletest workspaces\n", len(workspaces)) cliui.Errorf(inv.Stderr, "Found %d scaletest workspaces\n", len(workspaces))
if len(workspaces) != 0 { if len(workspaces) != 0 {
cmd.Println("Deleting scaletest workspaces...") cliui.Infof(inv.Stdout, "Deleting scaletest workspaces..."+"\n")
harness := harness.NewTestHarness(cleanupStrategy.toStrategy(), harness.ConcurrentExecutionStrategy{}) harness := harness.NewTestHarness(cleanupStrategy.toStrategy(), harness.ConcurrentExecutionStrategy{})
for i, w := range workspaces { for i, w := range workspaces {
@ -384,16 +431,16 @@ func scaletestCleanup() *cobra.Command {
return xerrors.Errorf("run test harness to delete workspaces (harness failure, not a test failure): %w", err) return xerrors.Errorf("run test harness to delete workspaces (harness failure, not a test failure): %w", err)
} }
cmd.Println("Done deleting scaletest workspaces:") cliui.Infof(inv.Stdout, "Done deleting scaletest workspaces:"+"\n")
res := harness.Results() res := harness.Results()
res.PrintText(cmd.ErrOrStderr()) res.PrintText(inv.Stderr)
if res.TotalFail > 0 { if res.TotalFail > 0 {
return xerrors.Errorf("failed to delete scaletest workspaces") return xerrors.Errorf("failed to delete scaletest workspaces")
} }
} }
cmd.PrintErrln("Fetching scaletest users...") cliui.Infof(inv.Stdout, "Fetching scaletest users...")
pageNumber = 0 pageNumber = 0
limit = 100 limit = 100
var users []codersdk.User var users []codersdk.User
@ -423,9 +470,9 @@ func scaletestCleanup() *cobra.Command {
users = append(users, pageUsers...) users = append(users, pageUsers...)
} }
cmd.PrintErrf("Found %d scaletest users\n", len(users)) cliui.Errorf(inv.Stderr, "Found %d scaletest users\n", len(users))
if len(workspaces) != 0 { if len(workspaces) != 0 {
cmd.Println("Deleting scaletest users...") cliui.Infof(inv.Stdout, "Deleting scaletest users..."+"\n")
harness := harness.NewTestHarness(cleanupStrategy.toStrategy(), harness.ConcurrentExecutionStrategy{}) harness := harness.NewTestHarness(cleanupStrategy.toStrategy(), harness.ConcurrentExecutionStrategy{})
for i, u := range users { for i, u := range users {
@ -444,9 +491,9 @@ func scaletestCleanup() *cobra.Command {
return xerrors.Errorf("run test harness to delete users (harness failure, not a test failure): %w", err) return xerrors.Errorf("run test harness to delete users (harness failure, not a test failure): %w", err)
} }
cmd.Println("Done deleting scaletest users:") cliui.Infof(inv.Stdout, "Done deleting scaletest users:"+"\n")
res := harness.Results() res := harness.Results()
res.PrintText(cmd.ErrOrStderr()) res.PrintText(inv.Stderr)
if res.TotalFail > 0 { if res.TotalFail > 0 {
return xerrors.Errorf("failed to delete scaletest users") return xerrors.Errorf("failed to delete scaletest users")
@ -457,13 +504,13 @@ func scaletestCleanup() *cobra.Command {
}, },
} }
cleanupStrategy.attach(cmd) cleanupStrategy.attach(&cmd.Options)
return cmd return cmd
} }
func scaletestCreateWorkspaces() *cobra.Command { func (r *RootCmd) scaletestCreateWorkspaces() *clibase.Cmd {
var ( var (
count int count int64
template string template string
parametersFile string parametersFile string
parameters []string // key=value parameters []string // key=value
@ -494,18 +541,15 @@ func scaletestCreateWorkspaces() *cobra.Command {
output = &scaletestOutputFlags{} output = &scaletestOutputFlags{}
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
Use: "create-workspaces",
Short: "Creates many workspaces and waits for them to be ready",
Long: `Creates many users, then creates a workspace for each user and waits for them finish building and fully come online. Optionally runs a command inside each workspace, and connects to the workspace over WireGuard.
It is recommended that all rate limits are disabled on the server before running this scaletest. This test generates many login events which will be rate limited against the (most likely single) IP.`, cmd := &clibase.Cmd{
RunE: func(cmd *cobra.Command, args []string) error { Use: "create-workspaces",
ctx := cmd.Context() Short: "Creates many users, then creates a workspace for each user and waits for them finish building and fully come online. Optionally runs a command inside each workspace, and connects to the workspace over WireGuard.",
client, err := CreateClient(cmd) Long: `It is recommended that all rate limits are disabled on the server before running this scaletest. This test generates many login events which will be rate limited against the (most likely single) IP.`,
if err != nil { Middleware: r.InitClient(client),
return err Handler: func(inv *clibase.Invocation) error {
} ctx := inv.Context()
me, err := requireAdmin(ctx, client) me, err := requireAdmin(ctx, client)
if err != nil { if err != nil {
@ -612,16 +656,16 @@ It is recommended that all rate limits are disabled on the server before running
if err != nil { if err != nil {
return xerrors.Errorf("start dry run workspace creation: %w", err) return xerrors.Errorf("start dry run workspace creation: %w", err)
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Planning workspace...") _, _ = fmt.Fprintln(inv.Stdout, "Planning workspace...")
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
return client.TemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) return client.TemplateVersionDryRun(inv.Context(), templateVersion.ID, dryRun.ID)
}, },
Cancel: func() error { Cancel: func() error {
return client.CancelTemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) return client.CancelTemplateVersionDryRun(inv.Context(), templateVersion.ID, dryRun.ID)
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
return client.TemplateVersionDryRunLogsAfter(cmd.Context(), templateVersion.ID, dryRun.ID, 0) return client.TemplateVersionDryRunLogsAfter(inv.Context(), templateVersion.ID, dryRun.ID, 0)
}, },
// Don't show log output for the dry-run unless there's an error. // Don't show log output for the dry-run unless there's an error.
Silent: true, Silent: true,
@ -645,7 +689,7 @@ It is recommended that all rate limits are disabled on the server before running
tracer := tracerProvider.Tracer(scaletestTracerName) tracer := tracerProvider.Tracer(scaletestTracerName)
th := harness.NewTestHarness(strategy.toStrategy(), cleanupStrategy.toStrategy()) th := harness.NewTestHarness(strategy.toStrategy(), cleanupStrategy.toStrategy())
for i := 0; i < count; i++ { for i := 0; i < int(count); i++ {
const name = "workspacebuild" const name = "workspacebuild"
id := strconv.Itoa(i) id := strconv.Itoa(i)
@ -728,7 +772,7 @@ It is recommended that all rate limits are disabled on the server before running
} }
// TODO: live progress output // TODO: live progress output
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Running load test...") _, _ = fmt.Fprintln(inv.Stderr, "Running load test...")
testCtx, testCancel := strategy.toContext(ctx) testCtx, testCancel := strategy.toContext(ctx)
defer testCancel() defer testCancel()
err = th.Run(testCtx) err = th.Run(testCtx)
@ -738,13 +782,13 @@ It is recommended that all rate limits are disabled on the server before running
res := th.Results() res := th.Results()
for _, o := range outputs { for _, o := range outputs {
err = o.write(res, cmd.OutOrStdout()) err = o.write(res, inv.Stdout)
if err != nil { if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err) return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
} }
} }
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nCleaning up...") _, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx) cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel() defer cleanupCancel()
err = th.Cleanup(cleanupCtx) err = th.Cleanup(cleanupCtx)
@ -754,12 +798,12 @@ It is recommended that all rate limits are disabled on the server before running
// Upload traces. // Upload traces.
if tracingEnabled { if tracingEnabled {
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nUploading traces...") _, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel() defer cancel()
err := closeTracing(ctx) err := closeTracing(ctx)
if err != nil { if err != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "\nError uploading traces: %+v\n", err) _, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
} }
} }
@ -771,32 +815,124 @@ It is recommended that all rate limits are disabled on the server before running
}, },
} }
cliflag.IntVarP(cmd.Flags(), &count, "count", "c", "CODER_LOADTEST_COUNT", 1, "Required: Number of workspaces to create.") cmd.Options = clibase.OptionSet{
cliflag.StringVarP(cmd.Flags(), &template, "template", "t", "CODER_LOADTEST_TEMPLATE", "", "Required: Name or ID of the template to use for workspaces.") {
cliflag.StringVarP(cmd.Flags(), &parametersFile, "parameters-file", "", "CODER_LOADTEST_PARAMETERS_FILE", "", "Path to a YAML file containing the parameters to use for each workspace.") Flag: "count",
cliflag.StringArrayVarP(cmd.Flags(), &parameters, "parameter", "", "CODER_LOADTEST_PARAMETERS", []string{}, "Parameters to use for each workspace. Can be specified multiple times. Overrides any existing parameters with the same name from --parameters-file. Format: key=value") FlagShorthand: "c",
Env: "CODER_SCALETEST_COUNT",
Default: "1",
Description: "Required: Number of workspaces to create.",
Value: clibase.Int64Of(&count),
},
{
Flag: "template",
FlagShorthand: "t",
Env: "CODER_SCALETEST_TEMPLATE",
Description: "Required: Name or ID of the template to use for workspaces.",
Value: clibase.StringOf(&template),
},
{
Flag: "parameters-file",
Env: "CODER_SCALETEST_PARAMETERS_FILE",
Description: "Path to a YAML file containing the parameters to use for each workspace.",
Value: clibase.StringOf(&parametersFile),
},
{
Flag: "parameter",
Env: "CODER_SCALETEST_PARAMETERS",
Description: "Parameters to use for each workspace. Can be specified multiple times. Overrides any existing parameters with the same name from --parameters-file. Format: key=value.",
Value: clibase.StringArrayOf(&parameters),
},
{
Flag: "no-plan",
Env: "CODER_SCALETEST_NO_PLAN",
Description: `Skip the dry-run step to plan the workspace creation. This step ensures that the given parameters are valid for the given template.`,
Value: clibase.BoolOf(&noPlan),
},
{
Flag: "no-cleanup",
Env: "CODER_SCALETEST_NO_CLEANUP",
Description: "Do not clean up resources after the test completes. You can cleanup manually using coder scaletest cleanup.",
Value: clibase.BoolOf(&noCleanup),
},
{
Flag: "no-wait-for-agents",
Env: "CODER_SCALETEST_NO_WAIT_FOR_AGENTS",
Description: `Do not wait for agents to start before marking the test as succeeded. This can be useful if you are running the test against a template that does not start the agent quickly.`,
Value: clibase.BoolOf(&noWaitForAgents),
},
{
Flag: "run-command",
Env: "CODER_SCALETEST_RUN_COMMAND",
Description: "Command to run inside each workspace using reconnecting-pty (i.e. web terminal protocol). " + "If not specified, no command will be run.",
Value: clibase.StringOf(&runCommand),
},
{
Flag: "run-timeout",
Env: "CODER_SCALETEST_RUN_TIMEOUT",
Default: "5s",
Description: "Timeout for the command to complete.",
Value: clibase.DurationOf(&runTimeout),
},
{
Flag: "run-expect-timeout",
Env: "CODER_SCALETEST_RUN_EXPECT_TIMEOUT",
cliflag.BoolVarP(cmd.Flags(), &noPlan, "no-plan", "", "CODER_LOADTEST_NO_PLAN", false, "Skip the dry-run step to plan the workspace creation. This step ensures that the given parameters are valid for the given template.") Description: "Expect the command to timeout." + " If the command does not finish within the given --run-timeout, it will be marked as succeeded." + " If the command finishes before the timeout, it will be marked as failed.",
cliflag.BoolVarP(cmd.Flags(), &noCleanup, "no-cleanup", "", "CODER_LOADTEST_NO_CLEANUP", false, "Do not clean up resources after the test completes. You can cleanup manually using `coder scaletest cleanup`.") Value: clibase.BoolOf(&runExpectTimeout),
// cliflag.BoolVarP(cmd.Flags(), &noCleanupFailures, "no-cleanup-failures", "", "CODER_LOADTEST_NO_CLEANUP_FAILURES", false, "Do not clean up resources from failed jobs to aid in debugging failures. You can cleanup manually using `coder scaletest cleanup`.") },
cliflag.BoolVarP(cmd.Flags(), &noWaitForAgents, "no-wait-for-agents", "", "CODER_LOADTEST_NO_WAIT_FOR_AGENTS", false, "Do not wait for agents to start before marking the test as succeeded. This can be useful if you are running the test against a template that does not start the agent quickly.") {
Flag: "run-expect-output",
Env: "CODER_SCALETEST_RUN_EXPECT_OUTPUT",
Description: "Expect the command to output the given string (on a single line). " + "If the command does not output the given string, it will be marked as failed.",
Value: clibase.StringOf(&runExpectOutput),
},
{
Flag: "run-log-output",
Env: "CODER_SCALETEST_RUN_LOG_OUTPUT",
Description: "Log the output of the command to the test logs. " + "This should be left off unless you expect small amounts of output. " + "Large amounts of output will cause high memory usage.",
Value: clibase.BoolOf(&runLogOutput),
},
{
Flag: "connect-url",
Env: "CODER_SCALETEST_CONNECT_URL",
Description: "URL to connect to inside the the workspace over WireGuard. " + "If not specified, no connections will be made over WireGuard.",
Value: clibase.StringOf(&connectURL),
},
{
Flag: "connect-mode",
Env: "CODER_SCALETEST_CONNECT_MODE",
Default: "derp",
Description: "Mode to use for connecting to the workspace.",
Value: clibase.EnumOf(&connectMode, "derp", "direct"),
},
{
Flag: "connect-hold",
Env: "CODER_SCALETEST_CONNECT_HOLD",
Default: "30s",
Description: "How long to hold the WireGuard connection open for.",
Value: clibase.DurationOf(&connectHold),
},
{
Flag: "connect-interval",
Env: "CODER_SCALETEST_CONNECT_INTERVAL",
Default: "1s",
Value: clibase.DurationOf(&connectInterval),
Description: "How long to wait between making requests to the --connect-url once the connection is established.",
},
{
Flag: "connect-timeout",
Env: "CODER_SCALETEST_CONNECT_TIMEOUT",
Default: "5s",
Description: "Timeout for each request to the --connect-url.",
Value: clibase.DurationOf(&connectTimeout),
},
}
cliflag.StringVarP(cmd.Flags(), &runCommand, "run-command", "", "CODER_LOADTEST_RUN_COMMAND", "", "Command to run inside each workspace using reconnecting-pty (i.e. web terminal protocol). If not specified, no command will be run.") tracingFlags.attach(&cmd.Options)
cliflag.DurationVarP(cmd.Flags(), &runTimeout, "run-timeout", "", "CODER_LOADTEST_RUN_TIMEOUT", 5*time.Second, "Timeout for the command to complete.") strategy.attach(&cmd.Options)
cliflag.BoolVarP(cmd.Flags(), &runExpectTimeout, "run-expect-timeout", "", "CODER_LOADTEST_RUN_EXPECT_TIMEOUT", false, "Expect the command to timeout. If the command does not finish within the given --run-timeout, it will be marked as succeeded. If the command finishes before the timeout, it will be marked as failed.") cleanupStrategy.attach(&cmd.Options)
cliflag.StringVarP(cmd.Flags(), &runExpectOutput, "run-expect-output", "", "CODER_LOADTEST_RUN_EXPECT_OUTPUT", "", "Expect the command to output the given string (on a single line). If the command does not output the given string, it will be marked as failed.") output.attach(&cmd.Options)
cliflag.BoolVarP(cmd.Flags(), &runLogOutput, "run-log-output", "", "CODER_LOADTEST_RUN_LOG_OUTPUT", false, "Log the output of the command to the test logs. This should be left off unless you expect small amounts of output. Large amounts of output will cause high memory usage.")
cliflag.StringVarP(cmd.Flags(), &connectURL, "connect-url", "", "CODER_LOADTEST_CONNECT_URL", "", "URL to connect to inside the the workspace over WireGuard. If not specified, no connections will be made over WireGuard.")
cliflag.StringVarP(cmd.Flags(), &connectMode, "connect-mode", "", "CODER_LOADTEST_CONNECT_MODE", "derp", "Mode to use for connecting to the workspace. Can be 'derp' or 'direct'.")
cliflag.DurationVarP(cmd.Flags(), &connectHold, "connect-hold", "", "CODER_LOADTEST_CONNECT_HOLD", 30*time.Second, "How long to hold the WireGuard connection open for.")
cliflag.DurationVarP(cmd.Flags(), &connectInterval, "connect-interval", "", "CODER_LOADTEST_CONNECT_INTERVAL", time.Second, "How long to wait between making requests to the --connect-url once the connection is established.")
cliflag.DurationVarP(cmd.Flags(), &connectTimeout, "connect-timeout", "", "CODER_LOADTEST_CONNECT_TIMEOUT", 5*time.Second, "Timeout for each request to the --connect-url.")
tracingFlags.attach(cmd)
strategy.attach(cmd)
cleanupStrategy.attach(cmd)
output.attach(cmd)
return cmd return cmd
} }

View File

@ -54,7 +54,7 @@ param3: 1
err = f.Close() err = f.Close()
require.NoError(t, err) require.NoError(t, err)
cmd, root := clitest.New(t, "scaletest", "create-workspaces", inv, root := clitest.New(t, "scaletest", "create-workspaces",
"--count", "2", "--count", "2",
"--template", template.Name, "--template", template.Name,
"--parameters-file", paramsFile, "--parameters-file", paramsFile,
@ -77,12 +77,12 @@ param3: 1
) )
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
done := make(chan any) done := make(chan any)
go func() { go func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.NoError(t, err) assert.NoError(t, err)
close(done) close(done)
}() }()
@ -148,19 +148,19 @@ param3: 1
require.Len(t, users.Users, len(seenUsers)+1) require.Len(t, users.Users, len(seenUsers)+1)
// Cleanup. // Cleanup.
cmd, root = clitest.New(t, "scaletest", "cleanup", inv, root = clitest.New(t, "scaletest", "cleanup",
"--cleanup-concurrency", "1", "--cleanup-concurrency", "1",
"--cleanup-timeout", "30s", "--cleanup-timeout", "30s",
"--cleanup-job-timeout", "15s", "--cleanup-job-timeout", "15s",
) )
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty = ptytest.New(t) pty = ptytest.New(t)
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
done = make(chan any) done = make(chan any)
go func() { go func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.NoError(t, err) assert.NoError(t, err)
close(done) close(done)
}() }()

View File

@ -6,9 +6,9 @@ import (
"time" "time"
"github.com/jedib0t/go-pretty/v6/table" "github.com/jedib0t/go-pretty/v6/table"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/schedule" "github.com/coder/coder/coderd/schedule"
"github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/coderd/util/ptr"
@ -46,82 +46,78 @@ When enabling scheduled stop, enter a duration in one of the following formats:
* 2m (2 minutes) * 2m (2 minutes)
* 2 (2 minutes) * 2 (2 minutes)
` `
scheduleOverrideDescriptionLong = `Override the stop time of a currently running workspace instance. scheduleOverrideDescriptionLong = `
* The new stop time is calculated from *now*. * The new stop time is calculated from *now*.
* The new stop time must be at least 30 minutes in the future. * The new stop time must be at least 30 minutes in the future.
* The workspace template may restrict the maximum workspace runtime. * The workspace template may restrict the maximum workspace runtime.
` `
) )
func schedules() *cobra.Command { func (r *RootCmd) schedules() *clibase.Cmd {
scheduleCmd := &cobra.Command{ scheduleCmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "schedule { show | start | stop | override } <workspace>", Use: "schedule { show | start | stop | override } <workspace>",
Short: "Schedule automated start and stop times for workspaces", Short: "Schedule automated start and stop times for workspaces",
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
return cmd.Help() return inv.Command.HelpHandler(inv)
},
Children: []*clibase.Cmd{
r.scheduleShow(),
r.scheduleStart(),
r.scheduleStop(),
r.scheduleOverride(),
}, },
} }
scheduleCmd.AddCommand(
scheduleShow(),
scheduleStart(),
scheduleStop(),
scheduleOverride(),
)
return scheduleCmd return scheduleCmd
} }
func scheduleShow() *cobra.Command { func (r *RootCmd) scheduleShow() *clibase.Cmd {
showCmd := &cobra.Command{ client := new(codersdk.Client)
showCmd := &clibase.Cmd{
Use: "show <workspace-name>", Use: "show <workspace-name>",
Short: "Show workspace schedule", Short: "Show workspace schedule",
Long: scheduleShowDescriptionLong, Long: scheduleShowDescriptionLong,
Args: cobra.ExactArgs(1), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(1),
client, err := CreateClient(cmd) r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return err return err
} }
workspace, err := namedWorkspace(cmd, client, args[0]) return displaySchedule(workspace, inv.Stdout)
if err != nil {
return err
}
return displaySchedule(workspace, cmd.OutOrStdout())
}, },
} }
return showCmd return showCmd
} }
func scheduleStart() *cobra.Command { func (r *RootCmd) scheduleStart() *clibase.Cmd {
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "start <workspace-name> { <start-time> [day-of-week] [location] | manual }", Use: "start <workspace-name> { <start-time> [day-of-week] [location] | manual }",
Example: formatExamples( Long: scheduleStartDescriptionLong + "\n" + formatExamples(
example{ example{
Description: "Set the workspace to start at 9:30am (in Dublin) from Monday to Friday", Description: "Set the workspace to start at 9:30am (in Dublin) from Monday to Friday",
Command: "coder schedule start my-workspace 9:30AM Mon-Fri Europe/Dublin", Command: "coder schedule start my-workspace 9:30AM Mon-Fri Europe/Dublin",
}, },
), ),
Short: "Edit workspace start schedule", Short: "Edit workspace start schedule",
Long: scheduleStartDescriptionLong, Middleware: clibase.Chain(
Args: cobra.RangeArgs(2, 4), clibase.RequireRangeArgs(2, 4),
RunE: func(cmd *cobra.Command, args []string) error { r.InitClient(client),
client, err := CreateClient(cmd) ),
if err != nil { Handler: func(inv *clibase.Invocation) error {
return err workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
}
workspace, err := namedWorkspace(cmd, client, args[0])
if err != nil { if err != nil {
return err return err
} }
var schedStr *string var schedStr *string
if args[1] != "manual" { if inv.Args[1] != "manual" {
sched, err := parseCLISchedule(args[1:]...) sched, err := parseCLISchedule(inv.Args[1:]...)
if err != nil { if err != nil {
return err return err
} }
@ -129,93 +125,89 @@ func scheduleStart() *cobra.Command {
schedStr = ptr.Ref(sched.String()) schedStr = ptr.Ref(sched.String())
} }
err = client.UpdateWorkspaceAutostart(cmd.Context(), workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{ err = client.UpdateWorkspaceAutostart(inv.Context(), workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{
Schedule: schedStr, Schedule: schedStr,
}) })
if err != nil { if err != nil {
return err return err
} }
updated, err := namedWorkspace(cmd, client, args[0]) updated, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return err return err
} }
return displaySchedule(updated, cmd.OutOrStdout()) return displaySchedule(updated, inv.Stdout)
}, },
} }
return cmd return cmd
} }
func scheduleStop() *cobra.Command { func (r *RootCmd) scheduleStop() *clibase.Cmd {
return &cobra.Command{ client := new(codersdk.Client)
Args: cobra.ExactArgs(2), return &clibase.Cmd{
Use: "stop <workspace-name> { <duration> | manual }", Use: "stop <workspace-name> { <duration> | manual }",
Example: formatExamples( Long: scheduleStopDescriptionLong + "\n" + formatExamples(
example{ example{
Command: "coder schedule stop my-workspace 2h30m", Command: "coder schedule stop my-workspace 2h30m",
}, },
), ),
Short: "Edit workspace stop schedule", Short: "Edit workspace stop schedule",
Long: scheduleStopDescriptionLong, Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(2),
client, err := CreateClient(cmd) r.InitClient(client),
if err != nil { ),
return err Handler: func(inv *clibase.Invocation) error {
} workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
workspace, err := namedWorkspace(cmd, client, args[0])
if err != nil { if err != nil {
return err return err
} }
var durMillis *int64 var durMillis *int64
if args[1] != "manual" { if inv.Args[1] != "manual" {
dur, err := parseDuration(args[1]) dur, err := parseDuration(inv.Args[1])
if err != nil { if err != nil {
return err return err
} }
durMillis = ptr.Ref(dur.Milliseconds()) durMillis = ptr.Ref(dur.Milliseconds())
} }
if err := client.UpdateWorkspaceTTL(cmd.Context(), workspace.ID, codersdk.UpdateWorkspaceTTLRequest{ if err := client.UpdateWorkspaceTTL(inv.Context(), workspace.ID, codersdk.UpdateWorkspaceTTLRequest{
TTLMillis: durMillis, TTLMillis: durMillis,
}); err != nil { }); err != nil {
return err return err
} }
updated, err := namedWorkspace(cmd, client, args[0]) updated, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return err return err
} }
return displaySchedule(updated, cmd.OutOrStdout()) return displaySchedule(updated, inv.Stdout)
}, },
} }
} }
func scheduleOverride() *cobra.Command { func (r *RootCmd) scheduleOverride() *clibase.Cmd {
overrideCmd := &cobra.Command{ client := new(codersdk.Client)
Args: cobra.ExactArgs(2), overrideCmd := &clibase.Cmd{
Use: "override-stop <workspace-name> <duration from now>", Use: "override-stop <workspace-name> <duration from now>",
Example: formatExamples( Short: "Override the stop time of a currently running workspace instance.",
Long: scheduleOverrideDescriptionLong + "\n" + formatExamples(
example{ example{
Command: "coder schedule override-stop my-workspace 90m", Command: "coder schedule override-stop my-workspace 90m",
}, },
), ),
Short: "Edit stop time of active workspace", Middleware: clibase.Chain(
Long: scheduleOverrideDescriptionLong, clibase.RequireNArgs(2),
RunE: func(cmd *cobra.Command, args []string) error { r.InitClient(client),
overrideDuration, err := parseDuration(args[1]) ),
Handler: func(inv *clibase.Invocation) error {
overrideDuration, err := parseDuration(inv.Args[1])
if err != nil { if err != nil {
return err return err
} }
client, err := CreateClient(cmd) workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil {
return xerrors.Errorf("create client: %w", err)
}
workspace, err := namedWorkspace(cmd, client, args[0])
if err != nil { if err != nil {
return xerrors.Errorf("get workspace: %w", err) return xerrors.Errorf("get workspace: %w", err)
} }
@ -227,24 +219,24 @@ func scheduleOverride() *cobra.Command {
if overrideDuration < 29*time.Minute { if overrideDuration < 29*time.Minute {
_, _ = fmt.Fprintf( _, _ = fmt.Fprintf(
cmd.OutOrStdout(), inv.Stdout,
"Please specify a duration of at least 30 minutes.\n", "Please specify a duration of at least 30 minutes.\n",
) )
return nil return nil
} }
newDeadline := time.Now().In(loc).Add(overrideDuration) newDeadline := time.Now().In(loc).Add(overrideDuration)
if err := client.PutExtendWorkspace(cmd.Context(), workspace.ID, codersdk.PutExtendWorkspaceRequest{ if err := client.PutExtendWorkspace(inv.Context(), workspace.ID, codersdk.PutExtendWorkspaceRequest{
Deadline: newDeadline, Deadline: newDeadline,
}); err != nil { }); err != nil {
return err return err
} }
updated, err := namedWorkspace(cmd, client, args[0]) updated, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return err return err
} }
return displaySchedule(updated, cmd.OutOrStdout()) return displaySchedule(updated, inv.Stdout)
}, },
} }
return overrideCmd return overrideCmd

View File

@ -42,11 +42,11 @@ func TestScheduleShow(t *testing.T) {
stdoutBuf = &bytes.Buffer{} stdoutBuf = &bytes.Buffer{}
) )
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
err := cmd.Execute() err := inv.Run()
require.NoError(t, err, "unexpected error") require.NoError(t, err, "unexpected error")
lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n")
if assert.Len(t, lines, 4) { if assert.Len(t, lines, 4) {
@ -79,11 +79,11 @@ func TestScheduleShow(t *testing.T) {
stdoutBuf = &bytes.Buffer{} stdoutBuf = &bytes.Buffer{}
) )
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
err := cmd.Execute() err := inv.Run()
require.NoError(t, err, "unexpected error") require.NoError(t, err, "unexpected error")
lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n")
if assert.Len(t, lines, 4) { if assert.Len(t, lines, 4) {
@ -104,10 +104,10 @@ func TestScheduleShow(t *testing.T) {
_ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
) )
cmd, root := clitest.New(t, "schedule", "show", "doesnotexist") inv, root := clitest.New(t, "schedule", "show", "doesnotexist")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
err := cmd.Execute() err := inv.Run()
require.ErrorContains(t, err, "status code 404", "unexpected error") require.ErrorContains(t, err, "status code 404", "unexpected error")
}) })
} }
@ -132,11 +132,11 @@ func TestScheduleStart(t *testing.T) {
) )
// Set a well-specified autostart schedule // Set a well-specified autostart schedule
cmd, root := clitest.New(t, "schedule", "start", workspace.Name, "9:30AM", "Mon-Fri", tz) inv, root := clitest.New(t, "schedule", "start", workspace.Name, "9:30AM", "Mon-Fri", tz)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err, "unexpected error") assert.NoError(t, err, "unexpected error")
lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n")
if assert.Len(t, lines, 4) { if assert.Len(t, lines, 4) {
@ -157,11 +157,11 @@ func TestScheduleStart(t *testing.T) {
stdoutBuf = &bytes.Buffer{} stdoutBuf = &bytes.Buffer{}
// unset schedule // unset schedule
cmd, root = clitest.New(t, "schedule", "start", workspace.Name, "manual") inv, root = clitest.New(t, "schedule", "start", workspace.Name, "manual")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
err = cmd.Execute() err = inv.Run()
assert.NoError(t, err, "unexpected error") assert.NoError(t, err, "unexpected error")
lines = strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") lines = strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n")
if assert.Len(t, lines, 4) { if assert.Len(t, lines, 4) {
@ -186,11 +186,11 @@ func TestScheduleStop(t *testing.T) {
) )
// Set the workspace TTL // Set the workspace TTL
cmd, root := clitest.New(t, "schedule", "stop", workspace.Name, ttl.String()) inv, root := clitest.New(t, "schedule", "stop", workspace.Name, ttl.String())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err, "unexpected error") assert.NoError(t, err, "unexpected error")
lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n")
if assert.Len(t, lines, 4) { if assert.Len(t, lines, 4) {
@ -203,11 +203,11 @@ func TestScheduleStop(t *testing.T) {
stdoutBuf = &bytes.Buffer{} stdoutBuf = &bytes.Buffer{}
// Unset the workspace TTL // Unset the workspace TTL
cmd, root = clitest.New(t, "schedule", "stop", workspace.Name, "manual") inv, root = clitest.New(t, "schedule", "stop", workspace.Name, "manual")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
err = cmd.Execute() err = inv.Run()
assert.NoError(t, err, "unexpected error") assert.NoError(t, err, "unexpected error")
lines = strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") lines = strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n")
if assert.Len(t, lines, 4) { if assert.Len(t, lines, 4) {
@ -247,12 +247,12 @@ func TestScheduleOverride(t *testing.T) {
initDeadline := time.Now().Add(time.Duration(*workspace.TTLMillis) * time.Millisecond) initDeadline := time.Now().Add(time.Duration(*workspace.TTLMillis) * time.Millisecond)
require.WithinDuration(t, initDeadline, workspace.LatestBuild.Deadline.Time, time.Minute) require.WithinDuration(t, initDeadline, workspace.LatestBuild.Deadline.Time, time.Minute)
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
// When: we execute `coder schedule override workspace <number without units>` // When: we execute `coder schedule override workspace <number without units>`
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
// Then: the deadline of the latest build is updated assuming the units are minutes // Then: the deadline of the latest build is updated assuming the units are minutes
@ -287,12 +287,12 @@ func TestScheduleOverride(t *testing.T) {
initDeadline := time.Now().Add(time.Duration(*workspace.TTLMillis) * time.Millisecond) initDeadline := time.Now().Add(time.Duration(*workspace.TTLMillis) * time.Millisecond)
require.WithinDuration(t, initDeadline, workspace.LatestBuild.Deadline.Time, time.Minute) require.WithinDuration(t, initDeadline, workspace.LatestBuild.Deadline.Time, time.Minute)
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
// When: we execute `coder bump workspace <not a number>` // When: we execute `coder bump workspace <not a number>`
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
// Then: the command fails // Then: the command fails
require.ErrorContains(t, err, "invalid duration") require.ErrorContains(t, err, "invalid duration")
}) })
@ -339,12 +339,12 @@ func TestScheduleOverride(t *testing.T) {
require.Zero(t, workspace.LatestBuild.Deadline) require.Zero(t, workspace.LatestBuild.Deadline)
require.NoError(t, err) require.NoError(t, err)
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
// When: we execute `coder bump workspace`` // When: we execute `coder bump workspace``
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
// Then: nothing happens and the deadline remains unset // Then: nothing happens and the deadline remains unset
@ -370,11 +370,10 @@ func TestScheduleStartDefaults(t *testing.T) {
) )
// Set an underspecified schedule // Set an underspecified schedule
cmd, root := clitest.New(t, "schedule", "start", workspace.Name, "9:30AM") inv, root := clitest.New(t, "schedule", "start", workspace.Name, "9:30AM")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetOut(stdoutBuf) inv.Stdout = stdoutBuf
err := inv.Run()
err := cmd.Execute()
require.NoError(t, err, "unexpected error") require.NoError(t, err, "unexpected error")
lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n")
if assert.Len(t, lines, 4) { if assert.Len(t, lines, 4) {

View File

@ -41,7 +41,6 @@ import (
"github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -97,7 +96,7 @@ func ReadGitAuthProvidersFromEnv(environ []string) ([]codersdk.GitAuthConfig, er
sort.Strings(environ) sort.Strings(environ)
var providers []codersdk.GitAuthConfig var providers []codersdk.GitAuthConfig
for _, v := range clibase.ParseEnviron(environ, envPrefix+"GITAUTH_") { for _, v := range clibase.ParseEnviron(environ, "CODER_GITAUTH_") {
tokens := strings.SplitN(v.Name, "_", 2) tokens := strings.SplitN(v.Name, "_", 2)
if len(tokens) != 2 { if len(tokens) != 2 {
return nil, xerrors.Errorf("invalid env var: %s", v.Name) return nil, xerrors.Errorf("invalid env var: %s", v.Name)
@ -157,92 +156,29 @@ func ReadGitAuthProvidersFromEnv(environ []string) ([]codersdk.GitAuthConfig, er
} }
// nolint:gocyclo // nolint:gocyclo
func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command { func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *clibase.Cmd {
root := &cobra.Command{ var (
cfg = new(codersdk.DeploymentValues)
opts = cfg.Options()
)
serverCmd := &clibase.Cmd{
Use: "server", Use: "server",
Short: "Start a Coder server", Short: "Start a Coder server",
DisableFlagParsing: true, Options: opts,
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.RequireNArgs(0),
Handler: func(inv *clibase.Invocation) error {
// Main command context for managing cancellation of running // Main command context for managing cancellation of running
// services. // services.
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(inv.Context())
defer cancel() defer cancel()
cfg := &codersdk.DeploymentValues{}
cliOpts := cfg.Options()
var configDir clibase.String
// This is a hack to get around the fact that the Cobra-defined
// flags are not available.
cliOpts.Add(clibase.Option{
Name: "Global Config",
Flag: config.FlagName,
Description: "Global Config is ignored in server mode.",
Hidden: true,
Default: config.DefaultDir(),
Value: &configDir,
})
err := cliOpts.SetDefaults()
if err != nil {
return xerrors.Errorf("set defaults: %w", err)
}
err = cliOpts.ParseEnv(clibase.ParseEnviron(os.Environ(), envPrefix))
if err != nil {
return xerrors.Errorf("parse env: %w", err)
}
flagSet := cliOpts.FlagSet()
// These parents and children will be moved once we convert the
// rest of the `cli` package to clibase.
flagSet.Usage = usageFn(cmd.ErrOrStderr(), &clibase.Cmd{
Parent: &clibase.Cmd{
Use: "coder",
},
Children: []*clibase.Cmd{
{
Use: "postgres-builtin-url",
Short: "Output the connection URL for the built-in PostgreSQL deployment.",
},
{
Use: "postgres-builtin-serve",
Short: "Run the built-in PostgreSQL deployment.",
},
},
Use: "server [flags]",
Short: "Start a Coder server",
Long: `
The server provides the Coder dashboard, API, and provisioners.
If no options are provided, the server will start with a built-in postgres
and an access URL provided by Coder's cloud service.
Use the following command to print the built-in postgres URL:
$ coder server postgres-builtin-url
Use the following command to manually run the built-in postgres:
$ coder server postgres-builtin-serve
Options may be provided via environment variables prefixed with "CODER_",
flags, and YAML configuration. The precedence is as follows:
1. Defaults
2. YAML configuration
3. Environment variables
4. Flags
`,
Options: cliOpts,
})
err = flagSet.Parse(args)
if err != nil {
return xerrors.Errorf("parse flags: %w", err)
}
if cfg.WriteConfig { if cfg.WriteConfig {
// TODO: this should output to a file. // TODO: this should output to a file.
n, err := cliOpts.ToYAML() n, err := opts.ToYAML()
if err != nil { if err != nil {
return xerrors.Errorf("generate yaml: %w", err) return xerrors.Errorf("generate yaml: %w", err)
} }
enc := yaml.NewEncoder(cmd.ErrOrStderr()) enc := yaml.NewEncoder(inv.Stderr)
err = enc.Encode(n) err = enc.Encode(n)
if err != nil { if err != nil {
return xerrors.Errorf("encode yaml: %w", err) return xerrors.Errorf("encode yaml: %w", err)
@ -255,7 +191,7 @@ flags, and YAML configuration. The precedence is as follows:
} }
// Print deprecation warnings. // Print deprecation warnings.
for _, opt := range cliOpts { for _, opt := range opts {
if opt.UseInstead == nil { if opt.UseInstead == nil {
continue continue
} }
@ -273,8 +209,8 @@ flags, and YAML configuration. The precedence is as follows:
} }
warnStr += "instead.\n" warnStr += "instead.\n"
cmd.PrintErr( cliui.Warn(inv.Stderr,
cliui.Styles.Warn.Render("WARN: ") + warnStr, warnStr,
) )
} }
@ -313,8 +249,8 @@ flags, and YAML configuration. The precedence is as follows:
filesRateLimit = -1 filesRateLimit = -1
} }
printLogo(cmd) printLogo(inv)
logger, logCloser, err := buildLogger(cmd, cfg) logger, logCloser, err := buildLogger(inv, cfg)
if err != nil { if err != nil {
return xerrors.Errorf("make logger: %w", err) return xerrors.Errorf("make logger: %w", err)
} }
@ -360,7 +296,7 @@ flags, and YAML configuration. The precedence is as follows:
shouldCoderTrace := cfg.Telemetry.Enable.Value() && !isTest() shouldCoderTrace := cfg.Telemetry.Enable.Value() && !isTest()
// Only override if telemetryTraceEnable was specifically set. // Only override if telemetryTraceEnable was specifically set.
// By default we want it to be controlled by telemetryEnable. // By default we want it to be controlled by telemetryEnable.
if cmd.Flags().Changed("telemetry-trace") { if inv.ParsedFlags().Changed("telemetry-trace") {
shouldCoderTrace = cfg.Telemetry.Trace.Value() shouldCoderTrace = cfg.Telemetry.Trace.Value()
} }
@ -389,12 +325,13 @@ flags, and YAML configuration. The precedence is as follows:
} }
} }
config := config.Root(configDir) config := r.createConfig()
builtinPostgres := false builtinPostgres := false
// Only use built-in if PostgreSQL URL isn't specified! // Only use built-in if PostgreSQL URL isn't specified!
if !cfg.InMemoryDatabase && cfg.PostgresURL == "" { if !cfg.InMemoryDatabase && cfg.PostgresURL == "" {
var closeFunc func() error var closeFunc func() error
cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath()) cliui.Infof(inv.Stdout, "Using built-in PostgreSQL (%s)", config.PostgresPath())
pgURL, closeFunc, err := startBuiltinPostgres(ctx, config, logger) pgURL, closeFunc, err := startBuiltinPostgres(ctx, config, logger)
if err != nil { if err != nil {
return err return err
@ -406,12 +343,12 @@ flags, and YAML configuration. The precedence is as follows:
} }
builtinPostgres = true builtinPostgres = true
defer func() { defer func() {
cmd.Printf("Stopping built-in PostgreSQL...\n") cliui.Infof(inv.Stdout, "Stopping built-in PostgreSQL...")
// Gracefully shut PostgreSQL down! // Gracefully shut PostgreSQL down!
if err := closeFunc(); err != nil { if err := closeFunc(); err != nil {
cmd.Printf("Failed to stop built-in PostgreSQL: %v\n", err) cliui.Errorf(inv.Stderr, "Failed to stop built-in PostgreSQL: %v", err)
} else { } else {
cmd.Printf("Stopped built-in PostgreSQL\n") cliui.Infof(inv.Stdout, "Stopped built-in PostgreSQL")
} }
}() }()
} }
@ -423,7 +360,7 @@ flags, and YAML configuration. The precedence is as follows:
if cfg.HTTPAddress.String() != "" { if cfg.HTTPAddress.String() != "" {
httpListener, err = net.Listen("tcp", cfg.HTTPAddress.String()) httpListener, err = net.Listen("tcp", cfg.HTTPAddress.String())
if err != nil { if err != nil {
return xerrors.Errorf("listen %q: %w", cfg.HTTPAddress.String(), err) return err
} }
defer httpListener.Close() defer httpListener.Close()
@ -438,7 +375,7 @@ flags, and YAML configuration. The precedence is as follows:
// We want to print out the address the user supplied, not the // We want to print out the address the user supplied, not the
// loopback device. // loopback device.
cmd.Println("Started HTTP listener at", (&url.URL{Scheme: "http", Host: listenAddrStr}).String()) _, _ = fmt.Fprintf(inv.Stdout, "Started HTTP listener at %s\n", (&url.URL{Scheme: "http", Host: listenAddrStr}).String())
// Set the http URL we want to use when connecting to ourselves. // Set the http URL we want to use when connecting to ourselves.
tcpAddr, tcpAddrValid := httpListener.Addr().(*net.TCPAddr) tcpAddr, tcpAddrValid := httpListener.Addr().(*net.TCPAddr)
@ -466,8 +403,8 @@ flags, and YAML configuration. The precedence is as follows:
// DEPRECATED: This redirect used to default to true. // DEPRECATED: This redirect used to default to true.
// It made more sense to have the redirect be opt-in. // It made more sense to have the redirect be opt-in.
if os.Getenv("CODER_TLS_REDIRECT_HTTP") == "true" || cmd.Flags().Changed("tls-redirect-http-to-https") { if inv.Environ.Get("CODER_TLS_REDIRECT_HTTP") == "true" || inv.ParsedFlags().Changed("tls-redirect-http-to-https") {
cmd.PrintErr(cliui.Styles.Warn.Render("WARN:") + " --tls-redirect-http-to-https is deprecated, please use --redirect-to-access-url instead\n") cliui.Warn(inv.Stderr, "--tls-redirect-http-to-https is deprecated, please use --redirect-to-access-url instead")
cfg.RedirectToAccessURL = cfg.TLS.RedirectHTTP cfg.RedirectToAccessURL = cfg.TLS.RedirectHTTP
} }
@ -483,7 +420,7 @@ flags, and YAML configuration. The precedence is as follows:
} }
httpsListenerInner, err := net.Listen("tcp", cfg.TLS.Address.String()) httpsListenerInner, err := net.Listen("tcp", cfg.TLS.Address.String())
if err != nil { if err != nil {
return xerrors.Errorf("listen %q: %w", cfg.TLS.Address.String(), err) return err
} }
defer httpsListenerInner.Close() defer httpsListenerInner.Close()
@ -502,7 +439,7 @@ flags, and YAML configuration. The precedence is as follows:
// We want to print out the address the user supplied, not the // We want to print out the address the user supplied, not the
// loopback device. // loopback device.
cmd.Println("Started TLS/HTTPS listener at", (&url.URL{Scheme: "https", Host: listenAddrStr}).String()) _, _ = fmt.Fprintf(inv.Stdout, "Started TLS/HTTPS listener at %s\n", (&url.URL{Scheme: "https", Host: listenAddrStr}).String())
// Set the https URL we want to use when connecting to // Set the https URL we want to use when connecting to
// ourselves. // ourselves.
@ -547,7 +484,7 @@ flags, and YAML configuration. The precedence is as follows:
tunnelDone <-chan struct{} = make(chan struct{}, 1) tunnelDone <-chan struct{} = make(chan struct{}, 1)
) )
if cfg.AccessURL.String() == "" { if cfg.AccessURL.String() == "" {
cmd.Printf("Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL\n") cliui.Infof(inv.Stderr, "Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL\n")
tunnel, err = devtunnel.New(ctx, logger.Named("devtunnel"), cfg.WgtunnelHost.String()) tunnel, err = devtunnel.New(ctx, logger.Named("devtunnel"), cfg.WgtunnelHost.String())
if err != nil { if err != nil {
return xerrors.Errorf("create tunnel: %w", err) return xerrors.Errorf("create tunnel: %w", err)
@ -586,14 +523,15 @@ flags, and YAML configuration. The precedence is as follows:
if isLocal { if isLocal {
reason = "isn't externally reachable" reason = "isn't externally reachable"
} }
cmd.Printf( cliui.Warnf(
"%s The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL by not specifying an access URL.\n", inv.Stderr,
cliui.Styles.Warn.Render("Warning:"), cliui.Styles.Field.Render(cfg.AccessURL.String()), reason, "The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL by not specifying an access URL.\n",
cliui.Styles.Field.Render(cfg.AccessURL.String()), reason,
) )
} }
// A newline is added before for visibility in terminal output. // A newline is added before for visibility in terminal output.
cmd.Printf("\nView the Web UI: %s\n", cfg.AccessURL.String()) cliui.Infof(inv.Stdout, "\nView the Web UI: %s\n", cfg.AccessURL.String())
// Used for zero-trust instance identity with Google Cloud. // Used for zero-trust instance identity with Google Cloud.
googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication()) googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication())
@ -943,7 +881,7 @@ flags, and YAML configuration. The precedence is as follows:
// than abstracting the Coder API itself. // than abstracting the Coder API itself.
coderAPI, coderAPICloser, err := newAPI(ctx, options) coderAPI, coderAPICloser, err := newAPI(ctx, options)
if err != nil { if err != nil {
return err return xerrors.Errorf("create coder API: %w", err)
} }
client := codersdk.New(localURL) client := codersdk.New(localURL)
@ -981,10 +919,15 @@ flags, and YAML configuration. The precedence is as follows:
_ = daemon.Close() _ = daemon.Close()
} }
}() }()
var provisionerdWaitGroup sync.WaitGroup
defer provisionerdWaitGroup.Wait()
provisionerdMetrics := provisionerd.NewMetrics(options.PrometheusRegistry) provisionerdMetrics := provisionerd.NewMetrics(options.PrometheusRegistry)
for i := int64(0); i < cfg.Provisioner.Daemons.Value(); i++ { for i := int64(0); i < cfg.Provisioner.Daemons.Value(); i++ {
daemonCacheDir := filepath.Join(cacheDir, fmt.Sprintf("provisioner-%d", i)) daemonCacheDir := filepath.Join(cacheDir, fmt.Sprintf("provisioner-%d", i))
daemon, err := newProvisionerDaemon(ctx, coderAPI, provisionerdMetrics, logger, cfg, daemonCacheDir, errCh, false) daemon, err := newProvisionerDaemon(
ctx, coderAPI, provisionerdMetrics, logger, cfg, daemonCacheDir, errCh, false, &provisionerdWaitGroup,
)
if err != nil { if err != nil {
return xerrors.Errorf("create provisioner daemon: %w", err) return xerrors.Errorf("create provisioner daemon: %w", err)
} }
@ -1064,7 +1007,7 @@ flags, and YAML configuration. The precedence is as follows:
} }
}() }()
cmd.Println("\n==> Logs will stream in below (press ctrl+c to gracefully exit):") cliui.Infof(inv.Stdout, "\n==> Logs will stream in below (press ctrl+c to gracefully exit):")
// Updates the systemd status from activating to activated. // Updates the systemd status from activating to activated.
_, err = daemon.SdNotify(false, daemon.SdNotifyReady) _, err = daemon.SdNotify(false, daemon.SdNotifyReady)
@ -1084,7 +1027,7 @@ flags, and YAML configuration. The precedence is as follows:
select { select {
case <-notifyCtx.Done(): case <-notifyCtx.Done():
exitErr = notifyCtx.Err() exitErr = notifyCtx.Err()
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render( _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Bold.Render(
"Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit", "Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit",
)) ))
case <-tunnelDone: case <-tunnelDone:
@ -1092,7 +1035,7 @@ flags, and YAML configuration. The precedence is as follows:
case exitErr = <-errCh: case exitErr = <-errCh:
} }
if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) { if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
cmd.Printf("Unexpected error, shutting down server: %s\n", exitErr) cliui.Errorf(inv.Stderr, "Unexpected error, shutting down server: %s\n", exitErr)
} }
// Begin clean shut down stage, we try to shut down services // Begin clean shut down stage, we try to shut down services
@ -1104,18 +1047,18 @@ flags, and YAML configuration. The precedence is as follows:
_, err = daemon.SdNotify(false, daemon.SdNotifyStopping) _, err = daemon.SdNotify(false, daemon.SdNotifyStopping)
if err != nil { if err != nil {
cmd.Printf("Notify systemd failed: %s", err) cliui.Errorf(inv.Stderr, "Notify systemd failed: %s", err)
} }
// Stop accepting new connections without interrupting // Stop accepting new connections without interrupting
// in-flight requests, give in-flight requests 5 seconds to // in-flight requests, give in-flight requests 5 seconds to
// complete. // complete.
cmd.Println("Shutting down API server...") cliui.Info(inv.Stdout, "Shutting down API server..."+"\n")
err = shutdownWithTimeout(httpServer.Shutdown, 3*time.Second) err = shutdownWithTimeout(httpServer.Shutdown, 3*time.Second)
if err != nil { if err != nil {
cmd.Printf("API server shutdown took longer than 3s: %s\n", err) cliui.Errorf(inv.Stderr, "API server shutdown took longer than 3s: %s\n", err)
} else { } else {
cmd.Printf("Gracefully shut down API server\n") cliui.Info(inv.Stdout, "Gracefully shut down API server\n")
} }
// Cancel any remaining in-flight requests. // Cancel any remaining in-flight requests.
shutdownConns() shutdownConns()
@ -1130,36 +1073,36 @@ flags, and YAML configuration. The precedence is as follows:
go func() { go func() {
defer wg.Done() defer wg.Done()
if ok, _ := cmd.Flags().GetBool(varVerbose); ok { if ok, _ := inv.ParsedFlags().GetBool(varVerbose); ok {
cmd.Printf("Shutting down provisioner daemon %d...\n", id) cliui.Infof(inv.Stdout, "Shutting down provisioner daemon %d...\n", id)
} }
err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second) err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second)
if err != nil { if err != nil {
cmd.PrintErrf("Failed to shutdown provisioner daemon %d: %s\n", id, err) cliui.Errorf(inv.Stderr, "Failed to shutdown provisioner daemon %d: %s\n", id, err)
return return
} }
err = provisionerDaemon.Close() err = provisionerDaemon.Close()
if err != nil { if err != nil {
cmd.PrintErrf("Close provisioner daemon %d: %s\n", id, err) cliui.Errorf(inv.Stderr, "Close provisioner daemon %d: %s\n", id, err)
return return
} }
if ok, _ := cmd.Flags().GetBool(varVerbose); ok { if ok, _ := inv.ParsedFlags().GetBool(varVerbose); ok {
cmd.Printf("Gracefully shut down provisioner daemon %d\n", id) cliui.Infof(inv.Stdout, "Gracefully shut down provisioner daemon %d\n", id)
} }
}() }()
} }
wg.Wait() wg.Wait()
cmd.Println("Waiting for WebSocket connections to close...") cliui.Info(inv.Stdout, "Waiting for WebSocket connections to close..."+"\n")
_ = coderAPICloser.Close() _ = coderAPICloser.Close()
cmd.Println("Done waiting for WebSocket connections") cliui.Info(inv.Stdout, "Done waiting for WebSocket connections"+"\n")
// Close tunnel after we no longer have in-flight connections. // Close tunnel after we no longer have in-flight connections.
if tunnel != nil { if tunnel != nil {
cmd.Println("Waiting for tunnel to close...") cliui.Infof(inv.Stdout, "Waiting for tunnel to close...")
_ = tunnel.Close() _ = tunnel.Close()
<-tunnel.Wait() <-tunnel.Wait()
cmd.Println("Done waiting for tunnel") cliui.Infof(inv.Stdout, "Done waiting for tunnel")
} }
// Ensures a last report can be sent before exit! // Ensures a last report can be sent before exit!
@ -1168,40 +1111,49 @@ flags, and YAML configuration. The precedence is as follows:
// Trigger context cancellation for any remaining services. // Trigger context cancellation for any remaining services.
cancel() cancel()
if xerrors.Is(exitErr, context.Canceled) { switch {
case xerrors.Is(exitErr, context.DeadlineExceeded):
cliui.Warnf(inv.Stderr, "Graceful shutdown timed out")
// Errors here cause a significant number of benign CI failures.
return nil
case xerrors.Is(exitErr, context.Canceled):
return nil
case exitErr != nil:
return xerrors.Errorf("graceful shutdown: %w", exitErr)
default:
return nil return nil
} }
return exitErr
}, },
} }
var pgRawURL bool var pgRawURL bool
postgresBuiltinURLCmd := &cobra.Command{
postgresBuiltinURLCmd := &clibase.Cmd{
Use: "postgres-builtin-url", Use: "postgres-builtin-url",
Short: "Output the connection URL for the built-in PostgreSQL deployment.", Short: "Output the connection URL for the built-in PostgreSQL deployment.",
RunE: func(cmd *cobra.Command, _ []string) error { Handler: func(inv *clibase.Invocation) error {
cfg := createConfig(cmd) url, err := embeddedPostgresURL(r.createConfig())
url, err := embeddedPostgresURL(cfg)
if err != nil { if err != nil {
return err return err
} }
if pgRawURL { if pgRawURL {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", url) _, _ = fmt.Fprintf(inv.Stdout, "%s\n", url)
} else { } else {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url))) _, _ = fmt.Fprintf(inv.Stdout, "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url)))
} }
return nil return nil
}, },
} }
postgresBuiltinServeCmd := &cobra.Command{
postgresBuiltinServeCmd := &clibase.Cmd{
Use: "postgres-builtin-serve", Use: "postgres-builtin-serve",
Short: "Run the built-in PostgreSQL deployment.", Short: "Run the built-in PostgreSQL deployment.",
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
ctx := cmd.Context() ctx := inv.Context()
cfg := createConfig(cmd) cfg := r.createConfig()
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr())) logger := slog.Make(sloghuman.Sink(inv.Stderr))
if ok, _ := cmd.Flags().GetBool(varVerbose); ok { if ok, _ := inv.ParsedFlags().GetBool(varVerbose); ok {
logger = logger.Leveled(slog.LevelDebug) logger = logger.Leveled(slog.LevelDebug)
} }
@ -1215,25 +1167,34 @@ flags, and YAML configuration. The precedence is as follows:
defer func() { _ = closePg() }() defer func() { _ = closePg() }()
if pgRawURL { if pgRawURL {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", url) _, _ = fmt.Fprintf(inv.Stdout, "%s\n", url)
} else { } else {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url))) _, _ = fmt.Fprintf(inv.Stdout, "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url)))
} }
<-ctx.Done() <-ctx.Done()
return nil return nil
}, },
} }
postgresBuiltinURLCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.")
postgresBuiltinServeCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.")
createAdminUserCommand := newCreateAdminUserCommand() createAdminUserCmd := r.newCreateAdminUserCommand()
root.SetHelpFunc(func(cmd *cobra.Command, args []string) {
// Help is handled by clibase in command body.
})
root.AddCommand(postgresBuiltinURLCmd, postgresBuiltinServeCmd, createAdminUserCommand)
return root rawURLOpt := clibase.Option{
Flag: "raw-url",
Value: clibase.BoolOf(&pgRawURL),
Description: "Output the raw connection URL instead of a psql command.",
}
createAdminUserCmd.Options.Add(rawURLOpt)
postgresBuiltinURLCmd.Options.Add(rawURLOpt)
postgresBuiltinServeCmd.Options.Add(rawURLOpt)
serverCmd.Children = append(
serverCmd.Children,
createAdminUserCmd, postgresBuiltinURLCmd, postgresBuiltinServeCmd,
)
return serverCmd
} }
// isLocalURL returns true if the hostname of the provided URL appears to // isLocalURL returns true if the hostname of the provided URL appears to
@ -1269,6 +1230,7 @@ func newProvisionerDaemon(
cacheDir string, cacheDir string,
errCh chan error, errCh chan error,
dev bool, dev bool,
wg *sync.WaitGroup,
) (srv *provisionerd.Server, err error) { ) (srv *provisionerd.Server, err error) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer func() { defer func() {
@ -1283,12 +1245,16 @@ func newProvisionerDaemon(
} }
terraformClient, terraformServer := provisionersdk.MemTransportPipe() terraformClient, terraformServer := provisionersdk.MemTransportPipe()
wg.Add(1)
go func() { go func() {
defer wg.Done()
<-ctx.Done() <-ctx.Done()
_ = terraformClient.Close() _ = terraformClient.Close()
_ = terraformServer.Close() _ = terraformServer.Close()
}() }()
wg.Add(1)
go func() { go func() {
defer wg.Done()
defer cancel() defer cancel()
err := terraform.Serve(ctx, &terraform.ServeOptions{ err := terraform.Serve(ctx, &terraform.ServeOptions{
@ -1317,12 +1283,16 @@ func newProvisionerDaemon(
// include echo provisioner when in dev mode // include echo provisioner when in dev mode
if dev { if dev {
echoClient, echoServer := provisionersdk.MemTransportPipe() echoClient, echoServer := provisionersdk.MemTransportPipe()
wg.Add(1)
go func() { go func() {
defer wg.Done()
<-ctx.Done() <-ctx.Done()
_ = echoClient.Close() _ = echoClient.Close()
_ = echoServer.Close() _ = echoServer.Close()
}() }()
wg.Add(1)
go func() { go func() {
defer wg.Done()
defer cancel() defer cancel()
err := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer}) err := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer})
@ -1355,13 +1325,13 @@ func newProvisionerDaemon(
} }
// nolint: revive // nolint: revive
func printLogo(cmd *cobra.Command) { func printLogo(inv *clibase.Invocation) {
// Only print the logo in TTYs. // Only print the logo in TTYs.
if !isTTYOut(cmd) { if !isTTYOut(inv) {
return return
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s - Your Self-Hosted Remote Development Platform\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version())) _, _ = fmt.Fprintf(inv.Stdout, "%s - Your Self-Hosted Remote Development Platform\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version()))
} }
func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) { func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) {
@ -1760,7 +1730,7 @@ func isLocalhost(host string) bool {
return host == "localhost" || host == "127.0.0.1" || host == "::1" return host == "localhost" || host == "127.0.0.1" || host == "::1"
} }
func buildLogger(cmd *cobra.Command, cfg *codersdk.DeploymentValues) (slog.Logger, func(), error) { func buildLogger(inv *clibase.Invocation, cfg *codersdk.DeploymentValues) (slog.Logger, func(), error) {
var ( var (
sinks = []slog.Sink{} sinks = []slog.Sink{}
closers = []func() error{} closers = []func() error{}
@ -1771,10 +1741,10 @@ func buildLogger(cmd *cobra.Command, cfg *codersdk.DeploymentValues) (slog.Logge
case "": case "":
case "/dev/stdout": case "/dev/stdout":
sinks = append(sinks, sinkFn(cmd.OutOrStdout())) sinks = append(sinks, sinkFn(inv.Stdout))
case "/dev/stderr": case "/dev/stderr":
sinks = append(sinks, sinkFn(cmd.ErrOrStderr())) sinks = append(sinks, sinkFn(inv.Stderr))
default: default:
fi, err := os.OpenFile(loc, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644) fi, err := os.OpenFile(loc, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)

View File

@ -4,16 +4,15 @@ package cli
import ( import (
"fmt" "fmt"
"os"
"os/signal" "os/signal"
"sort" "sort"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman" "cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/gitsshkey"
@ -23,7 +22,7 @@ import (
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func newCreateAdminUserCommand() *cobra.Command { func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd {
var ( var (
newUserDBURL string newUserDBURL string
newUserSSHKeygenAlgorithm string newUserSSHKeygenAlgorithm string
@ -31,36 +30,20 @@ func newCreateAdminUserCommand() *cobra.Command {
newUserEmail string newUserEmail string
newUserPassword string newUserPassword string
) )
createAdminUserCommand := &cobra.Command{ createAdminUserCommand := &clibase.Cmd{
Use: "create-admin-user", Use: "create-admin-user",
Short: "Create a new admin user with the given username, email and password and adds it to every organization.", Short: "Create a new admin user with the given username, email and password and adds it to every organization.",
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
ctx := cmd.Context() ctx := inv.Context()
sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(newUserSSHKeygenAlgorithm) sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(newUserSSHKeygenAlgorithm)
if err != nil { if err != nil {
return xerrors.Errorf("parse ssh keygen algorithm %q: %w", newUserSSHKeygenAlgorithm, err) return xerrors.Errorf("parse ssh keygen algorithm %q: %w", newUserSSHKeygenAlgorithm, err)
} }
if val, exists := os.LookupEnv("CODER_POSTGRES_URL"); exists { cfg := r.createConfig()
newUserDBURL = val logger := slog.Make(sloghuman.Sink(inv.Stderr))
} if r.verbose {
if val, exists := os.LookupEnv("CODER_SSH_KEYGEN_ALGORITHM"); exists {
newUserSSHKeygenAlgorithm = val
}
if val, exists := os.LookupEnv("CODER_USERNAME"); exists {
newUserUsername = val
}
if val, exists := os.LookupEnv("CODER_EMAIL"); exists {
newUserEmail = val
}
if val, exists := os.LookupEnv("CODER_PASSWORD"); exists {
newUserPassword = val
}
cfg := createConfig(cmd)
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()))
if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
logger = logger.Leveled(slog.LevelDebug) logger = logger.Leveled(slog.LevelDebug)
} }
@ -68,7 +51,7 @@ func newCreateAdminUserCommand() *cobra.Command {
defer cancel() defer cancel()
if newUserDBURL == "" { if newUserDBURL == "" {
cmd.Printf("Using built-in PostgreSQL (%s)\n", cfg.PostgresPath()) cliui.Infof(inv.Stdout, "Using built-in PostgreSQL (%s)\n", cfg.PostgresPath())
url, closePg, err := startBuiltinPostgres(ctx, cfg, logger) url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)
if err != nil { if err != nil {
return err return err
@ -110,7 +93,7 @@ func newCreateAdminUserCommand() *cobra.Command {
} }
if newUserUsername == "" { if newUserUsername == "" {
newUserUsername, err = cliui.Prompt(cmd, cliui.PromptOptions{ newUserUsername, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Username", Text: "Username",
Validate: func(val string) error { Validate: func(val string) error {
if val == "" { if val == "" {
@ -124,7 +107,7 @@ func newCreateAdminUserCommand() *cobra.Command {
} }
} }
if newUserEmail == "" { if newUserEmail == "" {
newUserEmail, err = cliui.Prompt(cmd, cliui.PromptOptions{ newUserEmail, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Email", Text: "Email",
Validate: func(val string) error { Validate: func(val string) error {
if val == "" { if val == "" {
@ -138,7 +121,7 @@ func newCreateAdminUserCommand() *cobra.Command {
} }
} }
if newUserPassword == "" { if newUserPassword == "" {
newUserPassword, err = cliui.Prompt(cmd, cliui.PromptOptions{ newUserPassword, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Password", Text: "Password",
Secret: true, Secret: true,
Validate: func(val string) error { Validate: func(val string) error {
@ -153,7 +136,7 @@ func newCreateAdminUserCommand() *cobra.Command {
} }
// Prompt again. // Prompt again.
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm password", Text: "Confirm password",
Secret: true, Secret: true,
Validate: func(val string) error { Validate: func(val string) error {
@ -191,7 +174,7 @@ func newCreateAdminUserCommand() *cobra.Command {
return orgs[i].Name < orgs[j].Name return orgs[i].Name < orgs[j].Name
}) })
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Creating user...") _, _ = fmt.Fprintln(inv.Stderr, "Creating user...")
newUser, err = tx.InsertUser(ctx, database.InsertUserParams{ newUser, err = tx.InsertUser(ctx, database.InsertUserParams{
ID: uuid.New(), ID: uuid.New(),
Email: newUserEmail, Email: newUserEmail,
@ -206,7 +189,7 @@ func newCreateAdminUserCommand() *cobra.Command {
return xerrors.Errorf("insert user: %w", err) return xerrors.Errorf("insert user: %w", err)
} }
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Generating user SSH key...") _, _ = fmt.Fprintln(inv.Stderr, "Generating user SSH key...")
privateKey, publicKey, err := gitsshkey.Generate(sshKeygenAlgorithm) privateKey, publicKey, err := gitsshkey.Generate(sshKeygenAlgorithm)
if err != nil { if err != nil {
return xerrors.Errorf("generate user gitsshkey: %w", err) return xerrors.Errorf("generate user gitsshkey: %w", err)
@ -223,7 +206,7 @@ func newCreateAdminUserCommand() *cobra.Command {
} }
for _, org := range orgs { for _, org := range orgs {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Adding user to organization %q (%s) as admin...\n", org.Name, org.ID.String()) _, _ = fmt.Fprintf(inv.Stderr, "Adding user to organization %q (%s) as admin...\n", org.Name, org.ID.String())
_, err := tx.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{ _, err := tx.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
OrganizationID: org.ID, OrganizationID: org.ID,
UserID: newUser.ID, UserID: newUser.ID,
@ -242,21 +225,50 @@ func newCreateAdminUserCommand() *cobra.Command {
return err return err
} }
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "") _, _ = fmt.Fprintln(inv.Stderr, "")
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "User created successfully.") _, _ = fmt.Fprintln(inv.Stderr, "User created successfully.")
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "ID: "+newUser.ID.String()) _, _ = fmt.Fprintln(inv.Stderr, "ID: "+newUser.ID.String())
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Username: "+newUser.Username) _, _ = fmt.Fprintln(inv.Stderr, "Username: "+newUser.Username)
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Email: "+newUser.Email) _, _ = fmt.Fprintln(inv.Stderr, "Email: "+newUser.Email)
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Password: ********") _, _ = fmt.Fprintln(inv.Stderr, "Password: ********")
return nil return nil
}, },
} }
createAdminUserCommand.Flags().StringVar(&newUserDBURL, "postgres-url", "", "URL of a PostgreSQL database. If empty, the built-in PostgreSQL deployment will be used (Coder must not be already running in this case). Consumes $CODER_POSTGRES_URL.")
createAdminUserCommand.Flags().StringVar(&newUserSSHKeygenAlgorithm, "ssh-keygen-algorithm", "ed25519", "The algorithm to use for generating ssh keys. Accepted values are \"ed25519\", \"ecdsa\", or \"rsa4096\". Consumes $CODER_SSH_KEYGEN_ALGORITHM.") createAdminUserCommand.Options.Add(
createAdminUserCommand.Flags().StringVar(&newUserUsername, "username", "", "The username of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_USERNAME.") clibase.Option{
createAdminUserCommand.Flags().StringVar(&newUserEmail, "email", "", "The email of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_EMAIL.") Env: "CODER_POSTGRES_URL",
createAdminUserCommand.Flags().StringVar(&newUserPassword, "password", "", "The password of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_PASSWORD.") Flag: "postgres-url",
Description: "URL of a PostgreSQL database. If empty, the built-in PostgreSQL deployment will be used (Coder must not be already running in this case).",
Value: clibase.StringOf(&newUserDBURL),
},
clibase.Option{
Env: "CODER_SSH_KEYGEN_ALGORITHM",
Flag: "ssh-keygen-algorithm",
Description: "The algorithm to use for generating ssh keys. Accepted values are \"ed25519\", \"ecdsa\", or \"rsa4096\".",
Default: "ed25519",
Value: clibase.StringOf(&newUserSSHKeygenAlgorithm),
},
clibase.Option{
Env: "CODER_USERNAME",
Flag: "username",
Description: "The username of the new user. If not specified, you will be prompted via stdin.",
Value: clibase.StringOf(&newUserUsername),
},
clibase.Option{
Env: "CODER_EMAIL",
Flag: "email",
Description: "The email of the new user. If not specified, you will be prompted via stdin.",
Value: clibase.StringOf(&newUserEmail),
},
clibase.Option{
Env: "CODER_PASSWORD",
Flag: "password",
Description: "The password of the new user. If not specified, you will be prompted via stdin.",
Value: clibase.StringOf(&newUserPassword),
},
)
return createAdminUserCommand return createAdminUserCommand
} }

View File

@ -92,9 +92,7 @@ func TestServerCreateAdminUser(t *testing.T) {
defer sqlDB.Close() defer sqlDB.Close()
db := database.New(sqlDB) db := database.New(sqlDB)
// Sometimes generating SSH keys takes a really long time if there isn't ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
// enough entropy. We don't want the tests to fail in these cases.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel() defer cancel()
pingCtx, pingCancel := context.WithTimeout(ctx, testutil.WaitShort) pingCtx, pingCancel := context.WithTimeout(ctx, testutil.WaitShort)
@ -120,7 +118,7 @@ func TestServerCreateAdminUser(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "create-admin-user", "server", "create-admin-user",
"--postgres-url", connectionURL, "--postgres-url", connectionURL,
"--ssh-keygen-algorithm", "ed25519", "--ssh-keygen-algorithm", "ed25519",
@ -129,14 +127,9 @@ func TestServerCreateAdminUser(t *testing.T) {
"--password", password, "--password", password,
) )
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) inv.Stdout = pty.Output()
root.SetErr(pty.Output()) inv.Stderr = pty.Output()
errC := make(chan error, 1) clitest.Start(t, inv)
go func() {
err := root.ExecuteContext(ctx)
t.Log("root.ExecuteContext() returned:", err)
errC <- err
}()
pty.ExpectMatchContext(ctx, "Creating user...") pty.ExpectMatchContext(ctx, "Creating user...")
pty.ExpectMatchContext(ctx, "Generating user SSH key...") pty.ExpectMatchContext(ctx, "Generating user SSH key...")
@ -147,13 +140,11 @@ func TestServerCreateAdminUser(t *testing.T) {
pty.ExpectMatchContext(ctx, email) pty.ExpectMatchContext(ctx, email)
pty.ExpectMatchContext(ctx, "****") pty.ExpectMatchContext(ctx, "****")
require.NoError(t, <-errC)
verifyUser(t, connectionURL, username, email, password) verifyUser(t, connectionURL, username, email, password)
}) })
//nolint:paralleltest
t.Run("Env", func(t *testing.T) { t.Run("Env", func(t *testing.T) {
t.Parallel()
if runtime.GOOS != "linux" || testing.Short() { if runtime.GOOS != "linux" || testing.Short() {
// Skip on non-Linux because it spawns a PostgreSQL instance. // Skip on non-Linux because it spawns a PostgreSQL instance.
t.SkipNow() t.SkipNow()
@ -162,35 +153,26 @@ func TestServerCreateAdminUser(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer closeFunc() defer closeFunc()
// Sometimes generating SSH keys takes a really long time if there isn't ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
// enough entropy. We don't want the tests to fail in these cases.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel() defer cancel()
t.Setenv("CODER_POSTGRES_URL", connectionURL) inv, _ := clitest.New(t, "server", "create-admin-user")
t.Setenv("CODER_SSH_KEYGEN_ALGORITHM", "ed25519") inv.Environ.Set("CODER_POSTGRES_URL", connectionURL)
t.Setenv("CODER_USERNAME", username) inv.Environ.Set("CODER_SSH_KEYGEN_ALGORITHM", "ed25519")
t.Setenv("CODER_EMAIL", email) inv.Environ.Set("CODER_USERNAME", username)
t.Setenv("CODER_PASSWORD", password) inv.Environ.Set("CODER_EMAIL", email)
inv.Environ.Set("CODER_PASSWORD", password)
root, _ := clitest.New(t, "server", "create-admin-user")
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) inv.Stdout = pty.Output()
root.SetErr(pty.Output()) inv.Stderr = pty.Output()
errC := make(chan error, 1) clitest.Start(t, inv)
go func() {
err := root.ExecuteContext(ctx)
t.Log("root.ExecuteContext() returned:", err)
errC <- err
}()
pty.ExpectMatchContext(ctx, "User created successfully.") pty.ExpectMatchContext(ctx, "User created successfully.")
pty.ExpectMatchContext(ctx, username) pty.ExpectMatchContext(ctx, username)
pty.ExpectMatchContext(ctx, email) pty.ExpectMatchContext(ctx, email)
pty.ExpectMatchContext(ctx, "****") pty.ExpectMatchContext(ctx, "****")
require.NoError(t, <-errC)
verifyUser(t, connectionURL, username, email, password) verifyUser(t, connectionURL, username, email, password)
}) })
@ -205,34 +187,25 @@ func TestServerCreateAdminUser(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer closeFunc() defer closeFunc()
// Sometimes generating SSH keys takes a really long time if there isn't ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
// enough entropy. We don't want the tests to fail in these cases.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel() defer cancel()
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "create-admin-user", "server", "create-admin-user",
"--postgres-url", connectionURL, "--postgres-url", connectionURL,
"--ssh-keygen-algorithm", "ed25519", "--ssh-keygen-algorithm", "ed25519",
) )
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
root.SetIn(pty.Input())
root.SetOutput(pty.Output())
root.SetErr(pty.Output())
errC := make(chan error, 1)
go func() {
err := root.ExecuteContext(ctx)
t.Log("root.ExecuteContext() returned:", err)
errC <- err
}()
pty.ExpectMatchContext(ctx, "> Username") clitest.Start(t, inv)
pty.ExpectMatchContext(ctx, "Username")
pty.WriteLine(username) pty.WriteLine(username)
pty.ExpectMatchContext(ctx, "> Email") pty.ExpectMatchContext(ctx, "Email")
pty.WriteLine(email) pty.WriteLine(email)
pty.ExpectMatchContext(ctx, "> Password") pty.ExpectMatchContext(ctx, "Password")
pty.WriteLine(password) pty.WriteLine(password)
pty.ExpectMatchContext(ctx, "> Confirm password") pty.ExpectMatchContext(ctx, "Confirm password")
pty.WriteLine(password) pty.WriteLine(password)
pty.ExpectMatchContext(ctx, "User created successfully.") pty.ExpectMatchContext(ctx, "User created successfully.")
@ -240,8 +213,6 @@ func TestServerCreateAdminUser(t *testing.T) {
pty.ExpectMatchContext(ctx, email) pty.ExpectMatchContext(ctx, email)
pty.ExpectMatchContext(ctx, "****") pty.ExpectMatchContext(ctx, "****")
require.NoError(t, <-errC)
verifyUser(t, connectionURL, username, email, password) verifyUser(t, connectionURL, username, email, password)
}) })
@ -267,10 +238,10 @@ func TestServerCreateAdminUser(t *testing.T) {
"--password", "x", "--password", "x",
) )
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) root.Stdout = pty.Output()
root.SetErr(pty.Output()) root.Stderr = pty.Output()
err = root.ExecuteContext(ctx) err = root.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "'email' failed on the 'email' tag") require.ErrorContains(t, err, "'email' failed on the 'email' tag")
require.ErrorContains(t, err, "'username' failed on the 'username' tag") require.ErrorContains(t, err, "'username' failed on the 'username' tag")

View File

@ -8,72 +8,24 @@ import (
"io" "io"
"os" "os"
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd" "github.com/coder/coder/coderd"
) )
func Server(_ func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command { func (r *RootCmd) Server(_ func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *clibase.Cmd {
root := &cobra.Command{ root := &clibase.Cmd{
Use: "server", Use: "server",
Short: "Start a Coder server", Short: "Start a Coder server",
// We accept RawArgs so all commands and flags are accepted.
RawArgs: true,
Hidden: true, Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
serverUnsupported(cmd.ErrOrStderr()) serverUnsupported(inv.Stderr)
return nil return nil
}, },
} }
var pgRawURL bool
postgresBuiltinURLCmd := &cobra.Command{
Use: "postgres-builtin-url",
Short: "Output the connection URL for the built-in PostgreSQL deployment.",
Hidden: true,
RunE: func(cmd *cobra.Command, _ []string) error {
serverUnsupported(cmd.ErrOrStderr())
return nil
},
}
postgresBuiltinServeCmd := &cobra.Command{
Use: "postgres-builtin-serve",
Short: "Run the built-in PostgreSQL deployment.",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
serverUnsupported(cmd.ErrOrStderr())
return nil
},
}
var (
newUserDBURL string
newUserSSHKeygenAlgorithm string
newUserUsername string
newUserEmail string
newUserPassword string
)
createAdminUserCommand := &cobra.Command{
Use: "create-admin-user",
Short: "Create a new admin user with the given username, email and password and adds it to every organization.",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
serverUnsupported(cmd.ErrOrStderr())
return nil
},
}
// We still have to attach the flags to the commands so users don't get
// an error when they try to use them.
postgresBuiltinURLCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.")
postgresBuiltinServeCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.")
createAdminUserCommand.Flags().StringVar(&newUserDBURL, "postgres-url", "", "URL of a PostgreSQL database. If empty, the built-in PostgreSQL deployment will be used (Coder must not be already running in this case). Consumes $CODER_POSTGRES_URL.")
createAdminUserCommand.Flags().StringVar(&newUserSSHKeygenAlgorithm, "ssh-keygen-algorithm", "ed25519", "The algorithm to use for generating ssh keys. Accepted values are \"ed25519\", \"ecdsa\", or \"rsa4096\". Consumes $CODER_SSH_KEYGEN_ALGORITHM.")
createAdminUserCommand.Flags().StringVar(&newUserUsername, "username", "", "The username of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_USERNAME.")
createAdminUserCommand.Flags().StringVar(&newUserEmail, "email", "", "The email of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_EMAIL.")
createAdminUserCommand.Flags().StringVar(&newUserPassword, "password", "", "The password of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_PASSWORD.")
root.AddCommand(postgresBuiltinURLCmd, postgresBuiltinServeCmd, createAdminUserCommand)
return root return root
} }

View File

@ -108,78 +108,66 @@ func TestServer(t *testing.T) {
connectionURL, closeFunc, err := postgres.Open() connectionURL, closeFunc, err := postgres.Open()
require.NoError(t, err) require.NoError(t, err)
defer closeFunc() defer closeFunc()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, cfg := clitest.New(t, // Postgres + race detector + CI = slow.
ctx := testutil.Context(t, testutil.WaitSuperLong*3)
inv, cfg := clitest.New(t,
"server", "server",
"--http-address", ":0", "--http-address", ":0",
"--access-url", "http://example.com", "--access-url", "http://example.com",
"--postgres-url", connectionURL, "--postgres-url", connectionURL,
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t) clitest.Start(t, inv.WithContext(ctx))
root.SetOutput(pty.Output())
root.SetErr(pty.Output())
errC := make(chan error, 1)
go func() {
errC <- root.ExecuteContext(ctx)
}()
accessURL := waitAccessURL(t, cfg) accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL) client := codersdk.New(accessURL)
_, err = client.CreateFirstUser(ctx, coderdtest.FirstUserParams) _, err = client.CreateFirstUser(ctx, coderdtest.FirstUserParams)
require.NoError(t, err) require.NoError(t, err)
cancelFunc()
require.NoError(t, <-errC)
}) })
t.Run("BuiltinPostgres", func(t *testing.T) { t.Run("BuiltinPostgres", func(t *testing.T) {
t.Parallel() t.Parallel()
if testing.Short() { if testing.Short() {
t.SkipNow() t.SkipNow()
} }
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, cfg := clitest.New(t, inv, cfg := clitest.New(t,
"server", "server",
"--http-address", ":0", "--http-address", ":0",
"--access-url", "http://example.com", "--access-url", "http://example.com",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t)
root.SetOutput(pty.Output()) const superDuperLong = testutil.WaitSuperLong * 3
root.SetErr(pty.Output())
errC := make(chan error, 1) ctx := testutil.Context(t, superDuperLong)
go func() { clitest.Start(t, inv.WithContext(ctx))
errC <- root.ExecuteContext(ctx)
}()
//nolint:gocritic // Embedded postgres take a while to fire up. //nolint:gocritic // Embedded postgres take a while to fire up.
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
rawURL, err := cfg.URL().Read() rawURL, err := cfg.URL().Read()
return err == nil && rawURL != "" return err == nil && rawURL != ""
}, 3*time.Minute, testutil.IntervalFast, "failed to get access URL") }, superDuperLong, testutil.IntervalFast, "failed to get access URL")
cancelFunc()
require.NoError(t, <-errC)
}) })
t.Run("BuiltinPostgresURL", func(t *testing.T) { t.Run("BuiltinPostgresURL", func(t *testing.T) {
t.Parallel() t.Parallel()
root, _ := clitest.New(t, "server", "postgres-builtin-url") root, _ := clitest.New(t, "server", "postgres-builtin-url")
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) root.Stdout = pty.Output()
err := root.Execute() err := root.Run()
require.NoError(t, err) require.NoError(t, err)
pty.ExpectMatch("psql") pty.ExpectMatch("psql")
}) })
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) { t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url") root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) root.Stdout = pty.Output()
err := root.ExecuteContext(ctx) err := root.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
got := pty.ReadLine(ctx) got := pty.ReadLine(ctx)
@ -192,93 +180,62 @@ func TestServer(t *testing.T) {
// reachable. // reachable.
t.Run("LocalAccessURL", func(t *testing.T) { t.Run("LocalAccessURL", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) inv, cfg := clitest.New(t,
defer cancelFunc()
root, cfg := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":0", "--http-address", ":0",
"--access-url", "http://localhost:3000/", "--access-url", "http://localhost:3000/",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
root.SetIn(pty.Input()) clitest.Start(t, inv)
root.SetOut(pty.Output())
errC := make(chan error, 1)
go func() {
errC <- root.ExecuteContext(ctx)
}()
// Just wait for startup // Just wait for startup
_ = waitAccessURL(t, cfg) _ = waitAccessURL(t, cfg)
pty.ExpectMatch("this may cause unexpected problems when creating workspaces") pty.ExpectMatch("this may cause unexpected problems when creating workspaces")
pty.ExpectMatch("View the Web UI: http://localhost:3000/") pty.ExpectMatch("View the Web UI: http://localhost:3000/")
cancelFunc()
require.NoError(t, <-errC)
}) })
// Validate that an https scheme is prepended to a remote access URL // Validate that an https scheme is prepended to a remote access URL
// and that a warning is printed for a host that cannot be resolved. // and that a warning is printed for a host that cannot be resolved.
t.Run("RemoteAccessURL", func(t *testing.T) { t.Run("RemoteAccessURL", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, cfg := clitest.New(t, inv, cfg := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":0", "--http-address", ":0",
"--access-url", "https://foobarbaz.mydomain", "--access-url", "https://foobarbaz.mydomain",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
root.SetIn(pty.Input())
root.SetOut(pty.Output()) clitest.Start(t, inv)
errC := make(chan error, 1)
go func() {
errC <- root.ExecuteContext(ctx)
}()
// Just wait for startup // Just wait for startup
_ = waitAccessURL(t, cfg) _ = waitAccessURL(t, cfg)
pty.ExpectMatch("this may cause unexpected problems when creating workspaces") pty.ExpectMatch("this may cause unexpected problems when creating workspaces")
pty.ExpectMatch("View the Web UI: https://foobarbaz.mydomain") pty.ExpectMatch("View the Web UI: https://foobarbaz.mydomain")
cancelFunc()
require.NoError(t, <-errC)
}) })
t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) { t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) inv, cfg := clitest.New(t,
defer cancelFunc()
root, cfg := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":0", "--http-address", ":0",
"--access-url", "https://google.com", "--access-url", "https://google.com",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
root.SetIn(pty.Input()) clitest.Start(t, inv)
root.SetOut(pty.Output())
errC := make(chan error, 1)
go func() {
errC <- root.ExecuteContext(ctx)
}()
// Just wait for startup // Just wait for startup
_ = waitAccessURL(t, cfg) _ = waitAccessURL(t, cfg)
pty.ExpectMatch("View the Web UI: https://google.com") pty.ExpectMatch("View the Web UI: https://google.com")
cancelFunc()
require.NoError(t, <-errC)
}) })
t.Run("NoSchemeAccessURL", func(t *testing.T) { t.Run("NoSchemeAccessURL", func(t *testing.T) {
@ -293,7 +250,7 @@ func TestServer(t *testing.T) {
"--access-url", "google.com", "--access-url", "google.com",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
err := root.ExecuteContext(ctx) err := root.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
}) })
@ -312,7 +269,7 @@ func TestServer(t *testing.T) {
"--tls-min-version", "tls9", "--tls-min-version", "tls9",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
err := root.ExecuteContext(ctx) err := root.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
}) })
t.Run("TLSBadClientAuth", func(t *testing.T) { t.Run("TLSBadClientAuth", func(t *testing.T) {
@ -330,7 +287,7 @@ func TestServer(t *testing.T) {
"--tls-client-auth", "something", "--tls-client-auth", "something",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
err := root.ExecuteContext(ctx) err := root.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
}) })
t.Run("TLSInvalid", func(t *testing.T) { t.Run("TLSInvalid", func(t *testing.T) {
@ -382,7 +339,7 @@ func TestServer(t *testing.T) {
} }
args = append(args, c.args...) args = append(args, c.args...)
root, _ := clitest.New(t, args...) root, _ := clitest.New(t, args...)
err := root.ExecuteContext(ctx) err := root.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
t.Logf("args: %v", args) t.Logf("args: %v", args)
require.ErrorContains(t, err, c.errContains) require.ErrorContains(t, err, c.errContains)
@ -406,7 +363,7 @@ func TestServer(t *testing.T) {
"--tls-key-file", keyPath, "--tls-key-file", keyPath,
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
clitest.Start(ctx, t, root) clitest.Start(t, root.WithContext(ctx))
// Verify HTTPS // Verify HTTPS
accessURL := waitAccessURL(t, cfg) accessURL := waitAccessURL(t, cfg)
@ -445,8 +402,8 @@ func TestServer(t *testing.T) {
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOut(pty.Output()) root.Stdout = pty.Output()
clitest.Start(ctx, t, root) clitest.Start(t, root.WithContext(ctx))
accessURL := waitAccessURL(t, cfg) accessURL := waitAccessURL(t, cfg)
require.Equal(t, "https", accessURL.Scheme) require.Equal(t, "https", accessURL.Scheme)
@ -511,7 +468,7 @@ func TestServer(t *testing.T) {
defer cancelFunc() defer cancelFunc()
certPath, keyPath := generateTLSCertificate(t) certPath, keyPath := generateTLSCertificate(t)
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":0", "--http-address", ":0",
@ -523,14 +480,8 @@ func TestServer(t *testing.T) {
"--tls-key-file", keyPath, "--tls-key-file", keyPath,
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
root.SetOutput(pty.Output()) clitest.Start(t, inv)
root.SetErr(pty.Output())
errC := make(chan error, 1)
go func() {
errC <- root.ExecuteContext(ctx)
}()
// We can't use waitAccessURL as it will only return the HTTP URL. // We can't use waitAccessURL as it will only return the HTTP URL.
const httpLinePrefix = "Started HTTP listener at" const httpLinePrefix = "Started HTTP listener at"
@ -572,9 +523,6 @@ func TestServer(t *testing.T) {
defer client.HTTPClient.CloseIdleConnections() defer client.HTTPClient.CloseIdleConnections()
_, err = client.HasFirstUser(ctx) _, err = client.HasFirstUser(ctx)
require.NoError(t, err) require.NoError(t, err)
cancelFunc()
require.NoError(t, <-errC)
}) })
t.Run("TLSRedirect", func(t *testing.T) { t.Run("TLSRedirect", func(t *testing.T) {
@ -670,15 +618,11 @@ func TestServer(t *testing.T) {
flags = append(flags, "--redirect-to-access-url") flags = append(flags, "--redirect-to-access-url")
} }
root, _ := clitest.New(t, flags...) inv, _ := clitest.New(t, flags...)
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) pty.Attach(inv)
root.SetErr(pty.Output())
errC := make(chan error, 1) clitest.Start(t, inv)
go func() {
errC <- root.ExecuteContext(ctx)
}()
var ( var (
httpAddr string httpAddr string
@ -742,8 +686,6 @@ func TestServer(t *testing.T) {
if err != nil { if err != nil {
require.ErrorContains(t, err, "Invalid application URL") require.ErrorContains(t, err, "Invalid application URL")
} }
cancelFunc()
require.NoError(t, <-errC)
} }
}) })
} }
@ -762,18 +704,19 @@ func TestServer(t *testing.T) {
) )
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) root.Stdout = pty.Output()
root.SetErr(pty.Output()) root.Stderr = pty.Output()
serverStop := make(chan error, 1) serverStop := make(chan error, 1)
go func() { go func() {
err := root.ExecuteContext(ctx) err := root.WithContext(ctx).Run()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
close(serverStop) close(serverStop)
}() }()
pty.ExpectMatch("Started HTTP listener at http://0.0.0.0:") pty.ExpectMatch("Started HTTP listener")
pty.ExpectMatch("http://0.0.0.0:")
cancelFunc() cancelFunc()
<-serverStop <-serverStop
@ -781,32 +724,19 @@ func TestServer(t *testing.T) {
t.Run("CanListenUnspecifiedv6", func(t *testing.T) { t.Run("CanListenUnspecifiedv6", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", "[::]:0", "--http-address", "[::]:0",
"--access-url", "http://example.com", "--access-url", "http://example.com",
) )
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
root.SetOutput(pty.Output()) clitest.Start(t, inv)
root.SetErr(pty.Output())
serverClose := make(chan struct{}, 1)
go func() {
err := root.ExecuteContext(ctx)
if err != nil {
t.Error(err)
}
close(serverClose)
}()
pty.ExpectMatch("Started HTTP listener at http://[::]:") pty.ExpectMatch("Started HTTP listener at")
pty.ExpectMatch("http://[::]:")
cancelFunc()
<-serverClose
}) })
t.Run("NoAddress", func(t *testing.T) { t.Run("NoAddress", func(t *testing.T) {
@ -814,14 +744,14 @@ func TestServer(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":80", "--http-address", ":80",
"--tls-enable=false", "--tls-enable=false",
"--tls-address", "", "--tls-address", "",
) )
err := root.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "tls-address") require.ErrorContains(t, err, "tls-address")
}) })
@ -831,13 +761,13 @@ func TestServer(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--tls-enable=true", "--tls-enable=true",
"--tls-address", "", "--tls-address", "",
) )
err := root.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "must not be empty") require.ErrorContains(t, err, "must not be empty")
}) })
@ -854,7 +784,7 @@ func TestServer(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, cfg := clitest.New(t, inv, cfg := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--address", ":0", "--address", ":0",
@ -862,9 +792,9 @@ func TestServer(t *testing.T) {
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) inv.Stdout = pty.Output()
root.SetErr(pty.Output()) inv.Stderr = pty.Output()
clitest.Start(ctx, t, root) clitest.Start(t, inv.WithContext(ctx))
pty.ExpectMatch("is deprecated") pty.ExpectMatch("is deprecated")
@ -892,9 +822,9 @@ func TestServer(t *testing.T) {
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) root.Stdout = pty.Output()
root.SetErr(pty.Output()) root.Stderr = pty.Output()
clitest.Start(ctx, t, root) clitest.Start(t, root.WithContext(ctx))
pty.ExpectMatch("is deprecated") pty.ExpectMatch("is deprecated")
@ -935,7 +865,7 @@ func TestServer(t *testing.T) {
) )
serverErr := make(chan error, 1) serverErr := make(chan error, 1)
go func() { go func() {
serverErr <- root.ExecuteContext(ctx) serverErr <- root.WithContext(ctx).Run()
}() }()
_ = waitAccessURL(t, cfg) _ = waitAccessURL(t, cfg)
currentProcess, err := os.FindProcess(os.Getpid()) currentProcess, err := os.FindProcess(os.Getpid())
@ -949,10 +879,8 @@ func TestServer(t *testing.T) {
}) })
t.Run("TracerNoLeak", func(t *testing.T) { t.Run("TracerNoLeak", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":0", "--http-address", ":0",
@ -960,18 +888,14 @@ func TestServer(t *testing.T) {
"--trace=true", "--trace=true",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
errC := make(chan error, 1) ctx, cancel := context.WithCancel(context.Background())
go func() { defer cancel()
errC <- root.ExecuteContext(ctx) clitest.Start(t, inv.WithContext(ctx))
}() cancel()
cancelFunc()
require.NoError(t, <-errC)
require.Error(t, goleak.Find()) require.Error(t, goleak.Find())
}) })
t.Run("Telemetry", func(t *testing.T) { t.Run("Telemetry", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
deployment := make(chan struct{}, 64) deployment := make(chan struct{}, 64)
snapshot := make(chan *telemetry.Snapshot, 64) snapshot := make(chan *telemetry.Snapshot, 64)
@ -990,7 +914,7 @@ func TestServer(t *testing.T) {
server := httptest.NewServer(r) server := httptest.NewServer(r)
defer server.Close() defer server.Close()
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":0", "--http-address", ":0",
@ -999,21 +923,13 @@ func TestServer(t *testing.T) {
"--telemetry-url", server.URL, "--telemetry-url", server.URL,
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
errC := make(chan error, 1) clitest.Start(t, inv)
go func() {
errC <- root.ExecuteContext(ctx)
}()
<-deployment <-deployment
<-snapshot <-snapshot
cancelFunc()
<-errC
}) })
t.Run("Prometheus", func(t *testing.T) { t.Run("Prometheus", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
random, err := net.Listen("tcp", "127.0.0.1:0") random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)
_ = random.Close() _ = random.Close()
@ -1021,7 +937,7 @@ func TestServer(t *testing.T) {
require.True(t, valid) require.True(t, valid)
randomPort := tcpAddr.Port randomPort := tcpAddr.Port
root, cfg := clitest.New(t, inv, cfg := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":0", "--http-address", ":0",
@ -1031,10 +947,11 @@ func TestServer(t *testing.T) {
"--prometheus-address", ":"+strconv.Itoa(randomPort), "--prometheus-address", ":"+strconv.Itoa(randomPort),
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
serverErr := make(chan error, 1)
go func() { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
serverErr <- root.ExecuteContext(ctx) defer cancel()
}()
clitest.Start(t, inv)
_ = waitAccessURL(t, cfg) _ = waitAccessURL(t, cfg)
var res *http.Response var res *http.Response
@ -1045,6 +962,7 @@ func TestServer(t *testing.T) {
res, err = http.DefaultClient.Do(req) res, err = http.DefaultClient.Do(req)
return err == nil return err == nil
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
defer res.Body.Close()
scanner := bufio.NewScanner(res.Body) scanner := bufio.NewScanner(res.Body)
hasActiveUsers := false hasActiveUsers := false
@ -1065,16 +983,12 @@ func TestServer(t *testing.T) {
require.NoError(t, scanner.Err()) require.NoError(t, scanner.Err())
require.True(t, hasActiveUsers) require.True(t, hasActiveUsers)
require.True(t, hasWorkspaces) require.True(t, hasWorkspaces)
cancelFunc()
<-serverErr
}) })
t.Run("GitHubOAuth", func(t *testing.T) { t.Run("GitHubOAuth", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
fakeRedirect := "https://fake-url.com" fakeRedirect := "https://fake-url.com"
root, cfg := clitest.New(t, inv, cfg := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
"--http-address", ":0", "--http-address", ":0",
@ -1084,10 +998,7 @@ func TestServer(t *testing.T) {
"--oauth2-github-client-secret", "fake", "--oauth2-github-client-secret", "fake",
"--oauth2-github-enterprise-base-url", fakeRedirect, "--oauth2-github-enterprise-base-url", fakeRedirect,
) )
serverErr := make(chan error, 1) clitest.Start(t, inv)
go func() {
serverErr <- root.ExecuteContext(ctx)
}()
accessURL := waitAccessURL(t, cfg) accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL) client := codersdk.New(accessURL)
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
@ -1095,7 +1006,7 @@ func TestServer(t *testing.T) {
} }
githubURL, err := accessURL.Parse("/api/v2/users/oauth2/github") githubURL, err := accessURL.Parse("/api/v2/users/oauth2/github")
require.NoError(t, err) require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, githubURL.String(), nil) req, err := http.NewRequestWithContext(inv.Context(), http.MethodGet, githubURL.String(), nil)
require.NoError(t, err) require.NoError(t, err)
res, err := client.HTTPClient.Do(req) res, err := client.HTTPClient.Do(req)
require.NoError(t, err) require.NoError(t, err)
@ -1103,8 +1014,6 @@ func TestServer(t *testing.T) {
fakeURL, err := res.Location() fakeURL, err := res.Location()
require.NoError(t, err) require.NoError(t, err)
require.True(t, strings.HasPrefix(fakeURL.String(), fakeRedirect), fakeURL.String()) require.True(t, strings.HasPrefix(fakeURL.String(), fakeRedirect), fakeURL.String())
cancelFunc()
<-serverErr
}) })
t.Run("RateLimit", func(t *testing.T) { t.Run("RateLimit", func(t *testing.T) {
@ -1123,7 +1032,7 @@ func TestServer(t *testing.T) {
) )
serverErr := make(chan error, 1) serverErr := make(chan error, 1)
go func() { go func() {
serverErr <- root.ExecuteContext(ctx) serverErr <- root.WithContext(ctx).Run()
}() }()
accessURL := waitAccessURL(t, cfg) accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL) client := codersdk.New(accessURL)
@ -1152,7 +1061,7 @@ func TestServer(t *testing.T) {
) )
serverErr := make(chan error, 1) serverErr := make(chan error, 1)
go func() { go func() {
serverErr <- root.ExecuteContext(ctx) serverErr <- root.WithContext(ctx).Run()
}() }()
accessURL := waitAccessURL(t, cfg) accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL) client := codersdk.New(accessURL)
@ -1180,7 +1089,7 @@ func TestServer(t *testing.T) {
) )
serverErr := make(chan error, 1) serverErr := make(chan error, 1)
go func() { go func() {
serverErr <- root.ExecuteContext(ctx) serverErr <- root.WithContext(ctx).Run()
}() }()
accessURL := waitAccessURL(t, cfg) accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL) client := codersdk.New(accessURL)
@ -1230,9 +1139,9 @@ func TestServer(t *testing.T) {
"--access-url", "http://example.com", "--access-url", "http://example.com",
"--log-human", fiName, "--log-human", fiName,
) )
clitest.Start(context.Background(), t, root) clitest.Start(t, root)
waitFile(t, fiName, testutil.WaitShort) waitFile(t, fiName, testutil.WaitLong)
}) })
t.Run("Human", func(t *testing.T) { t.Run("Human", func(t *testing.T) {
@ -1247,7 +1156,7 @@ func TestServer(t *testing.T) {
"--access-url", "http://example.com", "--access-url", "http://example.com",
"--log-human", fi, "--log-human", fi,
) )
clitest.Start(context.Background(), t, root) clitest.Start(t, root)
waitFile(t, fi, testutil.WaitShort) waitFile(t, fi, testutil.WaitShort)
}) })
@ -1264,7 +1173,7 @@ func TestServer(t *testing.T) {
"--access-url", "http://example.com", "--access-url", "http://example.com",
"--log-json", fi, "--log-json", fi,
) )
clitest.Start(context.Background(), t, root) clitest.Start(t, root)
waitFile(t, fi, testutil.WaitShort) waitFile(t, fi, testutil.WaitShort)
}) })
@ -1276,7 +1185,7 @@ func TestServer(t *testing.T) {
fi := testutil.TempFile(t, "", "coder-logging-test-*") fi := testutil.TempFile(t, "", "coder-logging-test-*")
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "server",
"--verbose", "--verbose",
"--in-memory", "--in-memory",
@ -1286,18 +1195,9 @@ func TestServer(t *testing.T) {
) )
// Attach pty so we get debug output from the command if this test // Attach pty so we get debug output from the command if this test
// fails. // fails.
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
root.SetOut(pty.Output())
root.SetErr(pty.Output())
serverErr := make(chan error, 1) clitest.Start(t, inv.WithContext(ctx))
go func() {
serverErr <- root.ExecuteContext(ctx)
}()
defer func() {
cancelFunc()
<-serverErr
}()
// Wait for server to listen on HTTP, this is a good // Wait for server to listen on HTTP, this is a good
// starting point for expecting logs. // starting point for expecting logs.
@ -1319,7 +1219,7 @@ func TestServer(t *testing.T) {
// which can take a long time and end up failing the test. // which can take a long time and end up failing the test.
// This is why we wait extra long below for server to listen on // This is why we wait extra long below for server to listen on
// HTTP. // HTTP.
root, _ := clitest.New(t, inv, _ := clitest.New(t,
"server", "server",
"--verbose", "--verbose",
"--in-memory", "--in-memory",
@ -1331,11 +1231,9 @@ func TestServer(t *testing.T) {
) )
// Attach pty so we get debug output from the command if this test // Attach pty so we get debug output from the command if this test
// fails. // fails.
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
root.SetOut(pty.Output())
root.SetErr(pty.Output())
clitest.Start(ctx, t, root) clitest.Start(t, inv)
// Wait for server to listen on HTTP, this is a good // Wait for server to listen on HTTP, this is a good
// starting point for expecting logs. // starting point for expecting logs.

View File

@ -1,32 +1,32 @@
package cli package cli
import ( import (
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
) )
func show() *cobra.Command { func (r *RootCmd) show() *clibase.Cmd {
return &cobra.Command{ client := new(codersdk.Client)
Annotations: workspaceCommand, return &clibase.Cmd{
Use: "show <workspace>", Use: "show <workspace>",
Short: "Display details of a workspace's resources and agents", Short: "Display details of a workspace's resources and agents",
Args: cobra.ExactArgs(1), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(1),
client, err := CreateClient(cmd) r.InitClient(client),
if err != nil { ),
return err Handler: func(inv *clibase.Invocation) error {
} buildInfo, err := client.BuildInfo(inv.Context())
buildInfo, err := client.BuildInfo(cmd.Context())
if err != nil { if err != nil {
return xerrors.Errorf("get server version: %w", err) return xerrors.Errorf("get server version: %w", err)
} }
workspace, err := namedWorkspace(cmd, client, args[0]) workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return xerrors.Errorf("get workspace: %w", err) return xerrors.Errorf("get workspace: %w", err)
} }
return cliui.WorkspaceResources(cmd.OutOrStdout(), workspace.LatestBuild.Resources, cliui.WorkspaceResourcesOptions{ return cliui.WorkspaceResources(inv.Stdout, workspace.LatestBuild.Resources, cliui.WorkspaceResourcesOptions{
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
ServerVersion: buildInfo.Version, ServerVersion: buildInfo.Version,
}) })

View File

@ -31,15 +31,13 @@ func TestShow(t *testing.T) {
"show", "show",
workspace.Name, workspace.Name,
} }
cmd, root := clitest.New(t, args...) inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() { go func() {
defer close(doneChan) defer close(doneChan)
err := cmd.Execute() err := inv.Run()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
matches := []struct { matches := []struct {

View File

@ -6,43 +6,41 @@ import (
"time" "time"
"github.com/jedib0t/go-pretty/v6/table" "github.com/jedib0t/go-pretty/v6/table"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
tsspeedtest "tailscale.com/net/speedtest" tsspeedtest "tailscale.com/net/speedtest"
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman" "cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func speedtest() *cobra.Command { func (r *RootCmd) speedtest() *clibase.Cmd {
var ( var (
direct bool direct bool
duration time.Duration duration time.Duration
direction string direction string
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "speedtest <workspace>", Use: "speedtest <workspace>",
Args: cobra.ExactArgs(1),
Short: "Run upload and download tests from your machine to a workspace", Short: "Run upload and download tests from your machine to a workspace",
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.Chain(
ctx, cancel := context.WithCancel(cmd.Context()) clibase.RequireNArgs(1),
r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel() defer cancel()
client, err := CreateClient(cmd) workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, codersdk.Me, inv.Args[0])
if err != nil {
return xerrors.Errorf("create codersdk client: %w", err)
}
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
if err != nil { if err != nil {
return err return err
} }
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID) return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@ -53,9 +51,9 @@ func speedtest() *cobra.Command {
} }
logger, ok := LoggerFromContext(ctx) logger, ok := LoggerFromContext(ctx)
if !ok { if !ok {
logger = slog.Make(sloghuman.Sink(cmd.ErrOrStderr())) logger = slog.Make(sloghuman.Sink(inv.Stderr))
} }
if cliflag.IsSetBool(cmd, varVerbose) { if r.verbose {
logger = logger.Leveled(slog.LevelDebug) logger = logger.Leveled(slog.LevelDebug)
} }
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{ conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{
@ -84,14 +82,14 @@ func speedtest() *cobra.Command {
} }
peer := status.Peer[status.Peers()[0]] peer := status.Peer[status.Peers()[0]]
if !p2p && direct { if !p2p && direct {
cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay) cliui.Infof(inv.Stdout, "Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay)
continue continue
} }
via := peer.Relay via := peer.Relay
if via == "" { if via == "" {
via = "direct" via = "direct"
} }
cmd.Printf("%dms via %s\n", dur.Milliseconds(), via) cliui.Infof(inv.Stdout, "%dms via %s\n", dur.Milliseconds(), via)
break break
} }
} else { } else {
@ -106,7 +104,7 @@ func speedtest() *cobra.Command {
default: default:
return xerrors.Errorf("invalid direction: %q", direction) return xerrors.Errorf("invalid direction: %q", direction)
} }
cmd.Printf("Starting a %ds %s test...\n", int(duration.Seconds()), tsDir) cliui.Infof(inv.Stdout, "Starting a %ds %s test...\n", int(duration.Seconds()), tsDir)
results, err := conn.Speedtest(ctx, tsDir, duration) results, err := conn.Speedtest(ctx, tsDir, duration)
if err != nil { if err != nil {
return err return err
@ -123,16 +121,31 @@ func speedtest() *cobra.Command {
fmt.Sprintf("%.4f Mbits/sec", r.MBitsPerSecond()), fmt.Sprintf("%.4f Mbits/sec", r.MBitsPerSecond()),
}) })
} }
_, err = fmt.Fprintln(cmd.OutOrStdout(), tableWriter.Render()) _, err = fmt.Fprintln(inv.Stdout, tableWriter.Render())
return err return err
}, },
} }
cliflag.BoolVarP(cmd.Flags(), &direct, "direct", "d", "", false, cmd.Options = clibase.OptionSet{
"Specifies whether to wait for a direct connection before testing speed.") {
cliflag.StringVarP(cmd.Flags(), &direction, "direction", "", "", "down", Description: "Specifies whether to wait for a direct connection before testing speed.",
"Specifies whether to run in reverse mode where the client receives and the server sends. (up|down)", Flag: "direct",
) FlagShorthand: "d",
cmd.Flags().DurationVarP(&duration, "time", "t", tsspeedtest.DefaultDuration,
"Specifies the duration to monitor traffic.") Value: clibase.BoolOf(&direct),
},
{
Description: "Specifies whether to run in reverse mode where the client receives and the server sends.",
Flag: "direction",
Default: "down",
Value: clibase.EnumOf(&direction, "up", "down"),
},
{
Description: "Specifies the duration to monitor traffic.",
Flag: "time",
FlagShorthand: "t",
Default: tsspeedtest.DefaultDuration.String(),
Value: clibase.DurationOf(&duration),
},
}
return cmd return cmd
} }

View File

@ -48,18 +48,18 @@ func TestSpeedtest(t *testing.T) {
a.LifecycleState == codersdk.WorkspaceAgentLifecycleReady a.LifecycleState == codersdk.WorkspaceAgentLifecycleReady
}, testutil.WaitLong, testutil.IntervalFast, "agent is not ready") }, testutil.WaitLong, testutil.IntervalFast, "agent is not ready")
cmd, root := clitest.New(t, "speedtest", workspace.Name) inv, root := clitest.New(t, "speedtest", workspace.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
ctx = cli.ContextWithLogger(ctx, slogtest.Make(t, nil).Named("speedtest").Leveled(slog.LevelDebug)) ctx = cli.ContextWithLogger(ctx, slogtest.Make(t, nil).Named("speedtest").Leveled(slog.LevelDebug))
cmdDone := tGo(t, func() { cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.NoError(t, err) assert.NoError(t, err)
}) })
<-cmdDone <-cmdDone

View File

@ -18,14 +18,13 @@ import (
"github.com/gofrs/flock" "github.com/gofrs/flock"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mattn/go-isatty" "github.com/mattn/go-isatty"
"github.com/spf13/cobra"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
gosshagent "golang.org/x/crypto/ssh/agent" gosshagent "golang.org/x/crypto/ssh/agent"
"golang.org/x/term" "golang.org/x/term"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/agent" "github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/autobuild/notify" "github.com/coder/coder/coderd/autobuild/notify"
"github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/coderd/util/ptr"
@ -38,55 +37,41 @@ var (
autostopNotifyCountdown = []time.Duration{30 * time.Minute} autostopNotifyCountdown = []time.Duration{30 * time.Minute}
) )
func ssh() *cobra.Command { func (r *RootCmd) ssh() *clibase.Cmd {
var ( var (
stdio bool stdio bool
shuffle bool
forwardAgent bool forwardAgent bool
forwardGPG bool forwardGPG bool
identityAgent string identityAgent string
wsPollInterval time.Duration wsPollInterval time.Duration
noWait bool noWait bool
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "ssh <workspace>", Use: "ssh <workspace>",
Short: "Start a shell into a workspace", Short: "Start a shell into a workspace",
Args: cobra.ArbitraryArgs, Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(1),
ctx, cancel := context.WithCancel(cmd.Context()) r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel() defer cancel()
client, err := CreateClient(cmd) workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, codersdk.Me, inv.Args[0])
if err != nil {
return err
}
if shuffle {
err := cobra.ExactArgs(0)(cmd, args)
if err != nil {
return err
}
} else {
err := cobra.MinimumNArgs(1)(cmd, args)
if err != nil {
return err
}
}
workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle)
if err != nil { if err != nil {
return err return err
} }
updateWorkspaceBanner, outdated := verifyWorkspaceOutdated(client, workspace) updateWorkspaceBanner, outdated := verifyWorkspaceOutdated(client, workspace)
if outdated && isTTYErr(cmd) { if outdated && isTTYErr(inv) {
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), updateWorkspaceBanner) _, _ = fmt.Fprintln(inv.Stderr, updateWorkspaceBanner)
} }
// OpenSSH passes stderr directly to the calling TTY. // OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed. // This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID) return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@ -120,9 +105,9 @@ func ssh() *cobra.Command {
defer rawSSH.Close() defer rawSSH.Close()
go func() { go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH) _, _ = io.Copy(inv.Stdout, rawSSH)
}() }()
_, _ = io.Copy(rawSSH, cmd.InOrStdin()) _, _ = io.Copy(rawSSH, inv.Stdin)
return nil return nil
} }
@ -168,15 +153,15 @@ func ssh() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("upload GPG public keys and ownertrust to workspace: %w", err) return xerrors.Errorf("upload GPG public keys and ownertrust to workspace: %w", err)
} }
closer, err := forwardGPGAgent(ctx, cmd.ErrOrStderr(), sshClient) closer, err := forwardGPGAgent(ctx, inv.Stderr, sshClient)
if err != nil { if err != nil {
return xerrors.Errorf("forward GPG socket: %w", err) return xerrors.Errorf("forward GPG socket: %w", err)
} }
defer closer.Close() defer closer.Close()
} }
stdoutFile, validOut := cmd.OutOrStdout().(*os.File) stdoutFile, validOut := inv.Stdout.(*os.File)
stdinFile, validIn := cmd.InOrStdin().(*os.File) stdinFile, validIn := inv.Stdin.(*os.File)
if validOut && validIn && isatty.IsTerminal(stdoutFile.Fd()) { if validOut && validIn && isatty.IsTerminal(stdoutFile.Fd()) {
state, err := term.MakeRaw(int(stdinFile.Fd())) state, err := term.MakeRaw(int(stdinFile.Fd()))
if err != nil { if err != nil {
@ -208,9 +193,9 @@ func ssh() *cobra.Command {
return err return err
} }
sshSession.Stdin = cmd.InOrStdin() sshSession.Stdin = inv.Stdin
sshSession.Stdout = cmd.OutOrStdout() sshSession.Stdout = inv.Stdout
sshSession.Stderr = cmd.ErrOrStderr() sshSession.Stderr = inv.Stderr
err = sshSession.Shell() err = sshSession.Shell()
if err != nil { if err != nil {
@ -243,53 +228,70 @@ func ssh() *cobra.Command {
return nil return nil
}, },
} }
cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.") cmd.Options = clibase.OptionSet{
cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace") {
_ = cmd.Flags().MarkHidden("shuffle") Flag: "stdio",
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK") Env: "CODER_SSH_STDIO",
cliflag.BoolVarP(cmd.Flags(), &forwardGPG, "forward-gpg", "G", "CODER_SSH_FORWARD_GPG", false, "Specifies whether to forward the GPG agent. Unsupported on Windows workspaces, but supports all clients. Requires gnupg (gpg, gpgconf) on both the client and workspace. The GPG agent must already be running locally and will not be started for you. If a GPG agent is already running in the workspace, it will be attempted to be killed.") Description: "Specifies whether to emit SSH output over stdin/stdout.",
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled") Value: clibase.BoolOf(&stdio),
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") },
cliflag.BoolVarP(cmd.Flags(), &noWait, "no-wait", "", "CODER_SSH_NO_WAIT", false, "Specifies whether to wait for a workspace to become ready before logging in (only applicable when the login before ready option has not been enabled). Note that the workspace agent may still be in the process of executing the startup script and the workspace may be in an incomplete state.") {
Flag: "forward-agent",
FlagShorthand: "A",
Env: "CODER_SSH_FORWARD_AGENT",
Description: "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK.",
Value: clibase.BoolOf(&forwardAgent),
},
{
Flag: "forward-gpg",
FlagShorthand: "G",
Env: "CODER_SSH_FORWARD_GPG",
Description: "Specifies whether to forward the GPG agent. Unsupported on Windows workspaces, but supports all clients. Requires gnupg (gpg, gpgconf) on both the client and workspace. The GPG agent must already be running locally and will not be started for you. If a GPG agent is already running in the workspace, it will be attempted to be killed.",
Value: clibase.BoolOf(&forwardGPG),
},
{
Flag: "identity-agent",
Env: "CODER_SSH_IDENTITY_AGENT",
Description: "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled.",
Value: clibase.StringOf(&identityAgent),
},
{
Flag: "workspace-poll-interval",
Env: "CODER_WORKSPACE_POLL_INTERVAL",
Description: "Specifies how often to poll for workspace automated shutdown.",
Default: "1m",
Value: clibase.DurationOf(&wsPollInterval),
},
{
Flag: "no-wait",
Env: "CODER_SSH_NO_WAIT",
Description: "Specifies whether to wait for a workspace to become ready before logging in (only applicable when the login before ready option has not been enabled). Note that the workspace agent may still be in the process of executing the startup script and the workspace may be in an incomplete state.",
Value: clibase.BoolOf(&noWait),
},
}
return cmd return cmd
} }
// getWorkspaceAgent returns the workspace and agent selected using either the // getWorkspaceAgent returns the workspace and agent selected using either the
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent // `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
// if `shuffle` is true. // if `shuffle` is true.
func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive func getWorkspaceAndAgent(ctx context.Context, inv *clibase.Invocation, client *codersdk.Client, userID string, in string) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
var ( var (
workspace codersdk.Workspace workspace codersdk.Workspace
workspaceParts = strings.Split(in, ".") workspaceParts = strings.Split(in, ".")
err error err error
) )
if shuffle {
res, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
Owner: userID,
})
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
}
if len(res.Workspaces) == 0 {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("no workspaces to shuffle")
}
workspace, err = cryptorand.Element(res.Workspaces) workspace, err = namedWorkspace(inv.Context(), client, workspaceParts[0])
if err != nil { if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
} }
} else {
workspace, err = namedWorkspace(cmd, client, workspaceParts[0])
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
}
}
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be in start transition to ssh") return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be in start transition to ssh")
} }
if workspace.LatestBuild.Job.CompletedAt == nil { if workspace.LatestBuild.Job.CompletedAt == nil {
err := cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID) err := cliui.WorkspaceBuild(ctx, inv.Stderr, client, workspace.LatestBuild.ID)
if err != nil { if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
} }
@ -322,9 +324,6 @@ func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *coder
} }
if workspaceAgent.ID == uuid.Nil { if workspaceAgent.ID == uuid.Nil {
if len(agents) > 1 { if len(agents) > 1 {
if !shuffle {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent")
}
workspaceAgent, err = cryptorand.Element(agents) workspaceAgent, err = cryptorand.Element(agents)
if err != nil { if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err

View File

@ -87,18 +87,15 @@ func TestSSH(t *testing.T) {
t.Parallel() t.Parallel()
client, workspace, agentToken := setupWorkspaceForAgent(t, nil) client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
cmd, root := clitest.New(t, "ssh", workspace.Name) inv, root := clitest.New(t, "ssh", workspace.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetErr(pty.Output())
cmd.SetOut(pty.Output())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
cmdDone := tGo(t, func() { cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.NoError(t, err) assert.NoError(t, err)
}) })
pty.ExpectMatch("Waiting") pty.ExpectMatch("Waiting")
@ -128,18 +125,18 @@ func TestSSH(t *testing.T) {
a[0].TroubleshootingUrl = wantURL a[0].TroubleshootingUrl = wantURL
return a return a
}) })
cmd, root := clitest.New(t, "ssh", workspace.Name) inv, root := clitest.New(t, "ssh", workspace.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
cmdDone := tGo(t, func() { cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.ErrorIs(t, err, cliui.Canceled) assert.ErrorIs(t, err, cliui.Canceled)
}) })
pty.ExpectMatch(wantURL) pty.ExpectMatch(wantURL)
@ -173,13 +170,13 @@ func TestSSH(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name) inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetIn(clientOutput) inv.Stdin = clientOutput
cmd.SetOut(serverInput) inv.Stdout = serverInput
cmd.SetErr(io.Discard) inv.Stderr = io.Discard
cmdDone := tGo(t, func() { cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -262,19 +259,17 @@ func TestSSH(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
cmd, root := clitest.New(t, inv, root := clitest.New(t,
"ssh", "ssh",
workspace.Name, workspace.Name,
"--forward-agent", "--forward-agent",
"--identity-agent", agentSock, // Overrides $SSH_AUTH_SOCK. "--identity-agent", agentSock, // Overrides $SSH_AUTH_SOCK.
) )
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input()) inv.Stderr = pty.Output()
cmd.SetOut(pty.Output())
cmd.SetErr(pty.Output())
cmdDone := tGo(t, func() { cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.NoError(t, err, "ssh command failed") assert.NoError(t, err, "ssh command failed")
}) })
@ -466,18 +461,18 @@ Expire-Date: 0
}) })
defer agentCloser.Close() defer agentCloser.Close()
cmd, root := clitest.New(t, inv, root := clitest.New(t,
"ssh", "ssh",
workspace.Name, workspace.Name,
"--forward-gpg", "--forward-gpg",
) )
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
tpty := ptytest.New(t) tpty := ptytest.New(t)
cmd.SetIn(tpty.Input()) inv.Stdin = tpty.Input()
cmd.SetOut(tpty.Output()) inv.Stdout = tpty.Output()
cmd.SetErr(tpty.Output()) inv.Stderr = tpty.Output()
cmdDone := tGo(t, func() { cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
assert.NoError(t, err, "ssh command failed") assert.NoError(t, err, "ssh command failed")
}) })
// Prevent the test from hanging if the asserts below kill the test // Prevent the test from hanging if the asserts below kill the test

View File

@ -4,43 +4,44 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func start() *cobra.Command { func (r *RootCmd) start() *clibase.Cmd {
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "start <workspace>", Use: "start <workspace>",
Short: "Start a workspace", Short: "Start a workspace",
Args: cobra.ExactArgs(1), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(1),
client, err := CreateClient(cmd) r.InitClient(client),
),
Options: clibase.OptionSet{
cliui.SkipPromptOption(),
},
Handler: func(inv *clibase.Invocation) error {
workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return err return err
} }
workspace, err := namedWorkspace(cmd, client, args[0]) build, err := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
if err != nil {
return err
}
build, err := client.CreateWorkspaceBuild(cmd.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionStart, Transition: codersdk.WorkspaceTransitionStart,
}) })
if err != nil { if err != nil {
return err return err
} }
err = cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStdout(), client, build.ID) err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, build.ID)
if err != nil { if err != nil {
return err return err
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThe %s workspace has been started at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) _, _ = fmt.Fprintf(inv.Stdout, "\nThe %s workspace has been started at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp)))
return nil return nil
}, },
} }
cliui.AllowSkipPrompt(cmd)
return cmd return cmd
} }

View File

@ -6,78 +6,92 @@ import (
"os" "os"
"strconv" "strconv"
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func state() *cobra.Command { func (r *RootCmd) state() *clibase.Cmd {
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "state", Use: "state",
Short: "Manually manage Terraform state to fix broken workspaces", Short: "Manually manage Terraform state to fix broken workspaces",
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
return cmd.Help() return inv.Command.HelpHandler(inv)
},
Children: []*clibase.Cmd{
r.statePull(),
r.statePush(),
}, },
} }
cmd.AddCommand(statePull(), statePush())
return cmd return cmd
} }
func statePull() *cobra.Command { func (r *RootCmd) statePull() *clibase.Cmd {
var buildNumber int var buildNumber int64
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "pull <workspace> [file]", Use: "pull <workspace> [file]",
Short: "Pull a Terraform state file from a workspace.", Short: "Pull a Terraform state file from a workspace.",
Args: cobra.MinimumNArgs(1), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireRangeArgs(1, 2),
client, err := CreateClient(cmd) r.InitClient(client),
if err != nil { ),
return err Handler: func(inv *clibase.Invocation) error {
} var err error
var build codersdk.WorkspaceBuild var build codersdk.WorkspaceBuild
if buildNumber == 0 { if buildNumber == 0 {
workspace, err := namedWorkspace(cmd, client, args[0]) workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return err return err
} }
build = workspace.LatestBuild build = workspace.LatestBuild
} else { } else {
build, err = client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(cmd.Context(), codersdk.Me, args[0], strconv.Itoa(buildNumber)) build, err = client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(inv.Context(), codersdk.Me, inv.Args[0], strconv.FormatInt(buildNumber, 10))
if err != nil { if err != nil {
return err return err
} }
} }
state, err := client.WorkspaceBuildState(cmd.Context(), build.ID) state, err := client.WorkspaceBuildState(inv.Context(), build.ID)
if err != nil { if err != nil {
return err return err
} }
if len(args) < 2 { if len(inv.Args) < 2 {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), string(state)) _, _ = fmt.Fprintln(inv.Stdout, string(state))
return nil return nil
} }
return os.WriteFile(args[1], state, 0o600) return os.WriteFile(inv.Args[1], state, 0o600)
}, },
} }
cmd.Flags().IntVarP(&buildNumber, "build", "b", 0, "Specify a workspace build to target by name.") cmd.Options = clibase.OptionSet{
buildNumberOption(&buildNumber),
}
return cmd return cmd
} }
func statePush() *cobra.Command { func buildNumberOption(n *int64) clibase.Option {
var buildNumber int return clibase.Option{
cmd := &cobra.Command{ Flag: "build",
Use: "push <workspace> <file>", FlagShorthand: "b",
Args: cobra.ExactArgs(2), Description: "Specify a workspace build to target by name. Defaults to latest.",
Short: "Push a Terraform state file to a workspace.", Value: clibase.Int64Of(n),
RunE: func(cmd *cobra.Command, args []string) error {
client, err := CreateClient(cmd)
if err != nil {
return err
} }
workspace, err := namedWorkspace(cmd, client, args[0]) }
func (r *RootCmd) statePush() *clibase.Cmd {
var buildNumber int64
client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "push <workspace> <file>",
Short: "Push a Terraform state file to a workspace.",
Middleware: clibase.Chain(
clibase.RequireNArgs(2),
r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return err return err
} }
@ -85,23 +99,23 @@ func statePush() *cobra.Command {
if buildNumber == 0 { if buildNumber == 0 {
build = workspace.LatestBuild build = workspace.LatestBuild
} else { } else {
build, err = client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(cmd.Context(), codersdk.Me, args[0], strconv.Itoa(buildNumber)) build, err = client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(inv.Context(), codersdk.Me, inv.Args[0], strconv.FormatInt((buildNumber), 10))
if err != nil { if err != nil {
return err return err
} }
} }
var state []byte var state []byte
if args[1] == "-" { if inv.Args[1] == "-" {
state, err = io.ReadAll(cmd.InOrStdin()) state, err = io.ReadAll(inv.Stdin)
} else { } else {
state, err = os.ReadFile(args[1]) state, err = os.ReadFile(inv.Args[1])
} }
if err != nil { if err != nil {
return err return err
} }
build, err = client.CreateWorkspaceBuild(cmd.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ build, err = client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
TemplateVersionID: build.TemplateVersionID, TemplateVersionID: build.TemplateVersionID,
Transition: build.Transition, Transition: build.Transition,
ProvisionerState: state, ProvisionerState: state,
@ -109,9 +123,11 @@ func statePush() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
return cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStderr(), client, build.ID) return cliui.WorkspaceBuild(inv.Context(), inv.Stderr, client, build.ID)
}, },
} }
cmd.Flags().IntVarP(&buildNumber, "build", "b", 0, "Specify a workspace build to target by name.") cmd.Options = clibase.OptionSet{
buildNumberOption(&buildNumber),
}
return cmd return cmd
} }

View File

@ -38,9 +38,9 @@ func TestStatePull(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
statefilePath := filepath.Join(t.TempDir(), "state") statefilePath := filepath.Join(t.TempDir(), "state")
cmd, root := clitest.New(t, "state", "pull", workspace.Name, statefilePath) inv, root := clitest.New(t, "state", "pull", workspace.Name, statefilePath)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
err := cmd.Execute() err := inv.Run()
require.NoError(t, err) require.NoError(t, err)
gotState, err := os.ReadFile(statefilePath) gotState, err := os.ReadFile(statefilePath)
require.NoError(t, err) require.NoError(t, err)
@ -65,11 +65,11 @@ func TestStatePull(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, root := clitest.New(t, "state", "pull", workspace.Name) inv, root := clitest.New(t, "state", "pull", workspace.Name)
var gotState bytes.Buffer var gotState bytes.Buffer
cmd.SetOut(&gotState) inv.Stdout = &gotState
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
err := cmd.Execute() err := inv.Run()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, wantState, bytes.TrimSpace(gotState.Bytes())) require.Equal(t, wantState, bytes.TrimSpace(gotState.Bytes()))
}) })
@ -96,9 +96,9 @@ func TestStatePush(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
err = stateFile.Close() err = stateFile.Close()
require.NoError(t, err) require.NoError(t, err)
cmd, root := clitest.New(t, "state", "push", workspace.Name, stateFile.Name()) inv, root := clitest.New(t, "state", "push", workspace.Name, stateFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
err = cmd.Execute() err = inv.Run()
require.NoError(t, err) require.NoError(t, err)
}) })
@ -114,10 +114,10 @@ func TestStatePush(t *testing.T) {
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, root := clitest.New(t, "state", "push", "--build", strconv.Itoa(int(workspace.LatestBuild.BuildNumber)), workspace.Name, "-") inv, root := clitest.New(t, "state", "push", "--build", strconv.Itoa(int(workspace.LatestBuild.BuildNumber)), workspace.Name, "-")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
cmd.SetIn(strings.NewReader("some magic state")) inv.Stdin = strings.NewReader("some magic state")
err := cmd.Execute() err := inv.Run()
require.NoError(t, err) require.NoError(t, err)
}) })
} }

View File

@ -4,20 +4,26 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func stop() *cobra.Command { func (r *RootCmd) stop() *clibase.Cmd {
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "stop <workspace>", Use: "stop <workspace>",
Short: "Stop a workspace", Short: "Stop a workspace",
Args: cobra.ExactArgs(1), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireNArgs(1),
_, err := cliui.Prompt(cmd, cliui.PromptOptions{ r.InitClient(client),
),
Options: clibase.OptionSet{
cliui.SkipPromptOption(),
},
Handler: func(inv *clibase.Invocation) error {
_, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm stop workspace?", Text: "Confirm stop workspace?",
IsConfirm: true, IsConfirm: true,
}) })
@ -25,30 +31,25 @@ func stop() *cobra.Command {
return err return err
} }
client, err := CreateClient(cmd) workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil { if err != nil {
return err return err
} }
workspace, err := namedWorkspace(cmd, client, args[0]) build, err := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
if err != nil {
return err
}
build, err := client.CreateWorkspaceBuild(cmd.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionStop, Transition: codersdk.WorkspaceTransitionStop,
}) })
if err != nil { if err != nil {
return err return err
} }
err = cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStdout(), client, build.ID) err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, build.ID)
if err != nil { if err != nil {
return err return err
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThe %s workspace has been stopped at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) _, _ = fmt.Fprintf(inv.Stdout, "\nThe %s workspace has been stopped at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp)))
return nil return nil
}, },
} }
cliui.AllowSkipPrompt(cmd)
return cmd return cmd
} }

View File

@ -11,9 +11,9 @@ import (
"unicode/utf8" "unicode/utf8"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/coderd/util/ptr"
@ -21,7 +21,7 @@ import (
"github.com/coder/coder/provisionerd" "github.com/coder/coder/provisionerd"
) )
func templateCreate() *cobra.Command { func (r *RootCmd) templateCreate() *clibase.Cmd {
var ( var (
provisioner string provisioner string
provisionerTags []string provisionerTags []string
@ -32,22 +32,21 @@ func templateCreate() *cobra.Command {
uploadFlags templateUploadFlags uploadFlags templateUploadFlags
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "create [name]", Use: "create [name]",
Short: "Create a template from the current directory or as specified by flag", Short: "Create a template from the current directory or as specified by flag",
Args: cobra.MaximumNArgs(1), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireRangeArgs(0, 1),
client, err := CreateClient(cmd) r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
organization, err := CurrentOrganization(inv, client)
if err != nil { if err != nil {
return err return err
} }
organization, err := CurrentOrganization(cmd, client) templateName, err := uploadFlags.templateName(inv.Args)
if err != nil {
return err
}
templateName, err := uploadFlags.templateName(args)
if err != nil { if err != nil {
return err return err
} }
@ -56,13 +55,13 @@ func templateCreate() *cobra.Command {
return xerrors.Errorf("Template name must be less than 32 characters") return xerrors.Errorf("Template name must be less than 32 characters")
} }
_, err = client.TemplateByName(cmd.Context(), organization.ID, templateName) _, err = client.TemplateByName(inv.Context(), organization.ID, templateName)
if err == nil { if err == nil {
return xerrors.Errorf("A template already exists named %q!", templateName) return xerrors.Errorf("A template already exists named %q!", templateName)
} }
// Confirm upload of the directory. // Confirm upload of the directory.
resp, err := uploadFlags.upload(cmd, client) resp, err := uploadFlags.upload(inv, client)
if err != nil { if err != nil {
return err return err
} }
@ -72,7 +71,7 @@ func templateCreate() *cobra.Command {
return err return err
} }
job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{ job, _, err := createValidTemplateVersion(inv, createValidTemplateVersionArgs{
Client: client, Client: client,
Organization: organization, Organization: organization,
Provisioner: database.ProvisionerType(provisioner), Provisioner: database.ProvisionerType(provisioner),
@ -87,7 +86,7 @@ func templateCreate() *cobra.Command {
} }
if !uploadFlags.stdin() { if !uploadFlags.stdin() {
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Confirm create?", Text: "Confirm create?",
IsConfirm: true, IsConfirm: true,
}) })
@ -102,34 +101,58 @@ func templateCreate() *cobra.Command {
DefaultTTLMillis: ptr.Ref(defaultTTL.Milliseconds()), DefaultTTLMillis: ptr.Ref(defaultTTL.Milliseconds()),
} }
_, err = client.CreateTemplate(cmd.Context(), organization.ID, createReq) _, err = client.CreateTemplate(inv.Context(), organization.ID, createReq)
if err != nil { if err != nil {
return err return err
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "\n"+cliui.Styles.Wrap.Render( _, _ = fmt.Fprintln(inv.Stdout, "\n"+cliui.Styles.Wrap.Render(
"The "+cliui.Styles.Keyword.Render(templateName)+" template has been created at "+cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))+"! "+ "The "+cliui.Styles.Keyword.Render(templateName)+" template has been created at "+cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))+"! "+
"Developers can provision a workspace with this template using:")+"\n") "Developers can provision a workspace with this template using:")+"\n")
_, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+cliui.Styles.Code.Render(fmt.Sprintf("coder create --template=%q [workspace name]", templateName))) _, _ = fmt.Fprintln(inv.Stdout, " "+cliui.Styles.Code.Render(fmt.Sprintf("coder create --template=%q [workspace name]", templateName)))
_, _ = fmt.Fprintln(cmd.OutOrStdout()) _, _ = fmt.Fprintln(inv.Stdout)
return nil return nil
}, },
} }
cmd.Flags().StringVarP(&parameterFile, "parameter-file", "", "", "Specify a file path with parameter values.") cmd.Options = clibase.OptionSet{
cmd.Flags().StringVarP(&variablesFile, "variables-file", "", "", "Specify a file path with values for Terraform-managed variables.") {
cmd.Flags().StringArrayVarP(&variables, "variable", "", []string{}, "Specify a set of values for Terraform-managed variables.") Flag: "parameter-file",
cmd.Flags().StringArrayVarP(&provisionerTags, "provisioner-tag", "", []string{}, "Specify a set of tags to target provisioner daemons.") Description: "Specify a file path with parameter values.",
cmd.Flags().DurationVarP(&defaultTTL, "default-ttl", "", 24*time.Hour, "Specify a default TTL for workspaces created from this template.") Value: clibase.StringOf(&parameterFile),
uploadFlags.register(cmd.Flags()) },
cmd.Flags().StringVarP(&provisioner, "test.provisioner", "", "terraform", "Customize the provisioner backend") {
// This is for testing! Flag: "variables-file",
err := cmd.Flags().MarkHidden("test.provisioner") Description: "Specify a file path with values for Terraform-managed variables.",
if err != nil { Value: clibase.StringOf(&variablesFile),
panic(err) },
{
Flag: "variable",
Description: "Specify a set of values for Terraform-managed variables.",
Value: clibase.StringArrayOf(&variables),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: clibase.StringArrayOf(&provisionerTags),
},
{
Flag: "default-ttl",
Description: "Specify a default TTL for workspaces created from this template.",
Default: "24h",
Value: clibase.DurationOf(&defaultTTL),
},
uploadFlags.option(),
{
Flag: "test.provisioner",
Description: "Customize the provisioner backend.",
Default: "terraform",
Value: clibase.StringOf(&provisioner),
Hidden: true,
},
cliui.SkipPromptOption(),
} }
cliui.AllowSkipPrompt(cmd)
return cmd return cmd
} }
@ -153,7 +176,7 @@ type createValidTemplateVersionArgs struct {
ProvisionerTags map[string]string ProvisionerTags map[string]string
} }
func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVersionArgs, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) { func createValidTemplateVersion(inv *clibase.Invocation, args createValidTemplateVersionArgs, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) {
client := args.Client client := args.Client
variableValues, err := loadVariableValuesFromFile(args.VariablesFile) variableValues, err := loadVariableValuesFromFile(args.VariablesFile)
@ -179,21 +202,21 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
if args.Template != nil { if args.Template != nil {
req.TemplateID = args.Template.ID req.TemplateID = args.Template.ID
} }
version, err := client.CreateTemplateVersion(cmd.Context(), args.Organization.ID, req) version, err := client.CreateTemplateVersion(inv.Context(), args.Organization.ID, req)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
version, err := client.TemplateVersion(cmd.Context(), version.ID) version, err := client.TemplateVersion(inv.Context(), version.ID)
return version.Job, err return version.Job, err
}, },
Cancel: func() error { Cancel: func() error {
return client.CancelTemplateVersion(cmd.Context(), version.ID) return client.CancelTemplateVersion(inv.Context(), version.ID)
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
return client.TemplateVersionLogsAfter(cmd.Context(), version.ID, 0) return client.TemplateVersionLogsAfter(inv.Context(), version.ID, 0)
}, },
}) })
if err != nil { if err != nil {
@ -202,15 +225,15 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
return nil, nil, err return nil, nil, err
} }
} }
version, err = client.TemplateVersion(cmd.Context(), version.ID) version, err = client.TemplateVersion(inv.Context(), version.ID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
parameterSchemas, err := client.TemplateVersionSchema(cmd.Context(), version.ID) parameterSchemas, err := client.TemplateVersionSchema(inv.Context(), version.ID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
parameterValues, err := client.TemplateVersionParameters(cmd.Context(), version.ID) parameterValues, err := client.TemplateVersionParameters(inv.Context(), version.ID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -220,13 +243,13 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
// version instead of prompting if we are updating template versions. // version instead of prompting if we are updating template versions.
lastParameterValues := make(map[string]codersdk.Parameter) lastParameterValues := make(map[string]codersdk.Parameter)
if args.ReuseParameters && args.Template != nil { if args.ReuseParameters && args.Template != nil {
activeVersion, err := client.TemplateVersion(cmd.Context(), args.Template.ActiveVersionID) activeVersion, err := client.TemplateVersion(inv.Context(), args.Template.ActiveVersionID)
if err != nil { if err != nil {
return nil, nil, xerrors.Errorf("Fetch current active template version: %w", err) return nil, nil, xerrors.Errorf("Fetch current active template version: %w", err)
} }
// We don't want to compute the params, we only want to copy from this scope // We don't want to compute the params, we only want to copy from this scope
values, err := client.Parameters(cmd.Context(), codersdk.ParameterImportJob, activeVersion.Job.ID) values, err := client.Parameters(inv.Context(), codersdk.ParameterImportJob, activeVersion.Job.ID)
if err != nil { if err != nil {
return nil, nil, xerrors.Errorf("Fetch previous version parameters: %w", err) return nil, nil, xerrors.Errorf("Fetch previous version parameters: %w", err)
} }
@ -244,7 +267,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
// parameterMapFromFile can be nil if parameter file is not specified // parameterMapFromFile can be nil if parameter file is not specified
var parameterMapFromFile map[string]string var parameterMapFromFile map[string]string
if args.ParameterFile != "" { if args.ParameterFile != "" {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n") _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n")
parameterMapFromFile, err = createParameterMapFromFile(args.ParameterFile) parameterMapFromFile, err = createParameterMapFromFile(args.ParameterFile)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -275,15 +298,15 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
missingSchemas = append(missingSchemas, parameterSchema) missingSchemas = append(missingSchemas, parameterSchema)
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("This template has required variables! They are scoped to the template, and not viewable after being set.")) _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("This template has required variables! They are scoped to the template, and not viewable after being set."))
if len(pulled) > 0 { if len(pulled) > 0 {
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render(fmt.Sprintf("The following parameter values are being pulled from the latest template version: %s.", strings.Join(pulled, ", ")))) _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render(fmt.Sprintf("The following parameter values are being pulled from the latest template version: %s.", strings.Join(pulled, ", "))))
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Use \"--always-prompt\" flag to change the values.")) _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("Use \"--always-prompt\" flag to change the values."))
} }
_, _ = fmt.Fprint(cmd.OutOrStdout(), "\r\n") _, _ = fmt.Fprint(inv.Stdout, "\r\n")
for _, parameterSchema := range missingSchemas { for _, parameterSchema := range missingSchemas {
parameterValue, err := getParameterValueFromMapOrInput(cmd, parameterMapFromFile, parameterSchema) parameterValue, err := getParameterValueFromMapOrInput(inv, parameterMapFromFile, parameterSchema)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -293,19 +316,19 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
SourceScheme: codersdk.ParameterSourceSchemeData, SourceScheme: codersdk.ParameterSourceSchemeData,
DestinationScheme: parameterSchema.DefaultDestinationScheme, DestinationScheme: parameterSchema.DefaultDestinationScheme,
}) })
_, _ = fmt.Fprintln(cmd.OutOrStdout()) _, _ = fmt.Fprintln(inv.Stdout)
} }
// This recursion is only 1 level deep in practice. // This recursion is only 1 level deep in practice.
// The first pass populates the missing parameters, so it does not enter this `if` block again. // The first pass populates the missing parameters, so it does not enter this `if` block again.
return createValidTemplateVersion(cmd, args, parameters...) return createValidTemplateVersion(inv, args, parameters...)
} }
if version.Job.Status != codersdk.ProvisionerJobSucceeded { if version.Job.Status != codersdk.ProvisionerJobSucceeded {
return nil, nil, xerrors.New(version.Job.Error) return nil, nil, xerrors.New(version.Job.Error)
} }
resources, err := client.TemplateVersionResources(cmd.Context(), version.ID) resources, err := client.TemplateVersionResources(inv.Context(), version.ID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -317,7 +340,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
startResources = append(startResources, r) startResources = append(startResources, r)
} }
} }
err = cliui.WorkspaceResources(cmd.OutOrStdout(), startResources, cliui.WorkspaceResourcesOptions{ err = cliui.WorkspaceResources(inv.Stdout, startResources, cliui.WorkspaceResourcesOptions{
HideAgentState: true, HideAgentState: true,
HideAccess: true, HideAccess: true,
Title: "Template Preview", Title: "Template Preview",

View File

@ -55,16 +55,11 @@ func TestTemplateCreate(t *testing.T) {
"--test.provisioner", string(database.ProvisionerTypeEcho), "--test.provisioner", string(database.ProvisionerTypeEcho),
"--default-ttl", "24h", "--default-ttl", "24h",
} }
cmd, root := clitest.New(t, args...) inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error) clitest.Start(t, inv)
go func() {
execDone <- cmd.Execute()
}()
matches := []struct { matches := []struct {
match string match string
@ -81,8 +76,6 @@ func TestTemplateCreate(t *testing.T) {
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
} }
require.NoError(t, <-execDone)
}) })
t.Run("CreateStdin", func(t *testing.T) { t.Run("CreateStdin", func(t *testing.T) {
@ -103,18 +96,13 @@ func TestTemplateCreate(t *testing.T) {
"--test.provisioner", string(database.ProvisionerTypeEcho), "--test.provisioner", string(database.ProvisionerTypeEcho),
"--default-ttl", "24h", "--default-ttl", "24h",
} }
cmd, root := clitest.New(t, args...) inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(bytes.NewReader(source)) inv.Stdin = bytes.NewReader(source)
cmd.SetOut(pty.Output()) inv.Stdout = pty.Output()
execDone := make(chan error) require.NoError(t, inv.Run())
go func() {
execDone <- cmd.Execute()
}()
require.NoError(t, <-execDone)
}) })
t.Run("WithParameter", func(t *testing.T) { t.Run("WithParameter", func(t *testing.T) {
@ -126,17 +114,11 @@ func TestTemplateCreate(t *testing.T) {
ProvisionApply: echo.ProvisionComplete, ProvisionApply: echo.ProvisionComplete,
ProvisionPlan: echo.ProvisionComplete, ProvisionPlan: echo.ProvisionComplete,
}) })
cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho)) inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho))
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error)
go func() {
execDone <- cmd.Execute()
}()
clitest.Start(t, inv)
matches := []struct { matches := []struct {
match string match string
write string write string
@ -149,8 +131,6 @@ func TestTemplateCreate(t *testing.T) {
pty.ExpectMatch(m.match) pty.ExpectMatch(m.match)
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
require.NoError(t, <-execDone)
}) })
t.Run("WithParameterFileContainingTheValue", func(t *testing.T) { t.Run("WithParameterFileContainingTheValue", func(t *testing.T) {
@ -166,16 +146,11 @@ func TestTemplateCreate(t *testing.T) {
removeTmpDirUntilSuccessAfterTest(t, tempDir) removeTmpDirUntilSuccessAfterTest(t, tempDir)
parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml") parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml")
_, _ = parameterFile.WriteString("region: \"bananas\"") _, _ = parameterFile.WriteString("region: \"bananas\"")
cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--parameter-file", parameterFile.Name()) inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--parameter-file", parameterFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error) clitest.Start(t, inv)
go func() {
execDone <- cmd.Execute()
}()
matches := []struct { matches := []struct {
match string match string
@ -188,8 +163,6 @@ func TestTemplateCreate(t *testing.T) {
pty.ExpectMatch(m.match) pty.ExpectMatch(m.match)
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
require.NoError(t, <-execDone)
}) })
t.Run("WithParameterFileNotContainingTheValue", func(t *testing.T) { t.Run("WithParameterFileNotContainingTheValue", func(t *testing.T) {
@ -205,16 +178,11 @@ func TestTemplateCreate(t *testing.T) {
removeTmpDirUntilSuccessAfterTest(t, tempDir) removeTmpDirUntilSuccessAfterTest(t, tempDir)
parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml") parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml")
_, _ = parameterFile.WriteString("zone: \"bananas\"") _, _ = parameterFile.WriteString("zone: \"bananas\"")
cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--parameter-file", parameterFile.Name()) inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--parameter-file", parameterFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error) clitest.Start(t, inv)
go func() {
execDone <- cmd.Execute()
}()
matches := []struct { matches := []struct {
match string match string
@ -237,8 +205,6 @@ func TestTemplateCreate(t *testing.T) {
pty.ExpectMatch(m.match) pty.ExpectMatch(m.match)
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
require.NoError(t, <-execDone)
}) })
t.Run("Recreate template with same name (create, delete, create)", func(t *testing.T) { t.Run("Recreate template with same name (create, delete, create)", func(t *testing.T) {
@ -259,10 +225,10 @@ func TestTemplateCreate(t *testing.T) {
"--directory", source, "--directory", source,
"--test.provisioner", string(database.ProvisionerTypeEcho), "--test.provisioner", string(database.ProvisionerTypeEcho),
} }
cmd, root := clitest.New(t, args...) inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
return cmd.Execute() return inv.Run()
} }
del := func() error { del := func() error {
args := []string{ args := []string{
@ -271,10 +237,10 @@ func TestTemplateCreate(t *testing.T) {
"my-template", "my-template",
"--yes", "--yes",
} }
cmd, root := clitest.New(t, args...) inv, root := clitest.New(t, args...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
return cmd.Execute() return inv.Run()
} }
err := create() err := create()
@ -289,15 +255,10 @@ func TestTemplateCreate(t *testing.T) {
t.Parallel() t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
coderdtest.CreateFirstUser(t, client) coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.New(t, "templates", "create", "1234567890123456789012345678901234567891", "--test.provisioner", string(database.ProvisionerTypeEcho)) inv, root := clitest.New(t, "templates", "create", "1234567890123456789012345678901234567891", "--test.provisioner", string(database.ProvisionerTypeEcho))
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
execDone := make(chan error) clitest.StartWithWaiter(t, inv).RequireContains("Template name must be less than 32 characters")
go func() {
execDone <- cmd.Execute()
}()
require.EqualError(t, <-execDone, "Template name must be less than 32 characters")
}) })
t.Run("WithVariablesFileWithoutRequiredValue", func(t *testing.T) { t.Run("WithVariablesFileWithoutRequiredValue", func(t *testing.T) {
@ -309,7 +270,7 @@ func TestTemplateCreate(t *testing.T) {
templateVariables := []*proto.TemplateVariable{ templateVariables := []*proto.TemplateVariable{
{ {
Name: "first_variable", Name: "first_variable",
Description: "This is the first variable", Description: "This is the first variable.",
Type: "string", Type: "string",
Required: true, Required: true,
Sensitive: true, Sensitive: true,
@ -329,17 +290,11 @@ func TestTemplateCreate(t *testing.T) {
removeTmpDirUntilSuccessAfterTest(t, tempDir) removeTmpDirUntilSuccessAfterTest(t, tempDir)
variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml") variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml")
_, _ = variablesFile.WriteString(`second_variable: foobar`) _, _ = variablesFile.WriteString(`second_variable: foobar`)
cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name()) inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error)
go func() {
execDone <- cmd.Execute()
}()
clitest.Start(t, inv)
matches := []struct { matches := []struct {
match string match string
write string write string
@ -352,8 +307,6 @@ func TestTemplateCreate(t *testing.T) {
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
} }
require.Error(t, <-execDone)
}) })
t.Run("WithVariablesFileWithTheRequiredValue", func(t *testing.T) { t.Run("WithVariablesFileWithTheRequiredValue", func(t *testing.T) {
@ -365,7 +318,7 @@ func TestTemplateCreate(t *testing.T) {
templateVariables := []*proto.TemplateVariable{ templateVariables := []*proto.TemplateVariable{
{ {
Name: "first_variable", Name: "first_variable",
Description: "This is the first variable", Description: "This is the first variable.",
Type: "string", Type: "string",
Required: true, Required: true,
Sensitive: true, Sensitive: true,
@ -385,16 +338,11 @@ func TestTemplateCreate(t *testing.T) {
removeTmpDirUntilSuccessAfterTest(t, tempDir) removeTmpDirUntilSuccessAfterTest(t, tempDir)
variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml") variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml")
_, _ = variablesFile.WriteString(`first_variable: foobar`) _, _ = variablesFile.WriteString(`first_variable: foobar`)
cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name()) inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name())
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error) clitest.Start(t, inv)
go func() {
execDone <- cmd.Execute()
}()
matches := []struct { matches := []struct {
match string match string
@ -409,8 +357,6 @@ func TestTemplateCreate(t *testing.T) {
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
} }
require.NoError(t, <-execDone)
}) })
t.Run("WithVariableOption", func(t *testing.T) { t.Run("WithVariableOption", func(t *testing.T) {
t.Parallel() t.Parallel()
@ -421,7 +367,7 @@ func TestTemplateCreate(t *testing.T) {
templateVariables := []*proto.TemplateVariable{ templateVariables := []*proto.TemplateVariable{
{ {
Name: "first_variable", Name: "first_variable",
Description: "This is the first variable", Description: "This is the first variable.",
Type: "string", Type: "string",
Required: true, Required: true,
Sensitive: true, Sensitive: true,
@ -429,16 +375,11 @@ func TestTemplateCreate(t *testing.T) {
} }
source := clitest.CreateTemplateVersionSource(t, source := clitest.CreateTemplateVersionSource(t,
createEchoResponsesWithTemplateVariables(templateVariables)) createEchoResponsesWithTemplateVariables(templateVariables))
cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variable", "first_variable=foobar") inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variable", "first_variable=foobar")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error) clitest.Start(t, inv)
go func() {
execDone <- cmd.Execute()
}()
matches := []struct { matches := []struct {
match string match string
@ -451,8 +392,6 @@ func TestTemplateCreate(t *testing.T) {
pty.ExpectMatch(m.match) pty.ExpectMatch(m.match)
pty.WriteLine(m.write) pty.WriteLine(m.write)
} }
require.NoError(t, <-execDone)
}) })
} }

View File

@ -5,35 +5,38 @@ import (
"strings" "strings"
"time" "time"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func templateDelete() *cobra.Command { func (r *RootCmd) templateDelete() *clibase.Cmd {
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "delete [name...]", Use: "delete [name...]",
Short: "Delete templates", Short: "Delete templates",
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.Chain(
r.InitClient(client),
),
Options: clibase.OptionSet{
cliui.SkipPromptOption(),
},
Handler: func(inv *clibase.Invocation) error {
var ( var (
ctx = cmd.Context() ctx = inv.Context()
templateNames = []string{} templateNames = []string{}
templates = []codersdk.Template{} templates = []codersdk.Template{}
) )
client, err := CreateClient(cmd) organization, err := CurrentOrganization(inv, client)
if err != nil {
return err
}
organization, err := CurrentOrganization(cmd, client)
if err != nil { if err != nil {
return err return err
} }
if len(args) > 0 { if len(inv.Args) > 0 {
templateNames = args templateNames = inv.Args
for _, templateName := range templateNames { for _, templateName := range templateNames {
template, err := client.TemplateByName(ctx, organization.ID, templateName) template, err := client.TemplateByName(ctx, organization.ID, templateName)
@ -57,7 +60,7 @@ func templateDelete() *cobra.Command {
opts = append(opts, template.Name) opts = append(opts, template.Name)
} }
selection, err := cliui.Select(cmd, cliui.SelectOptions{ selection, err := cliui.Select(inv, cliui.SelectOptions{
Options: opts, Options: opts,
}) })
if err != nil { if err != nil {
@ -73,7 +76,7 @@ func templateDelete() *cobra.Command {
} }
// Confirm deletion of the template. // Confirm deletion of the template.
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(strings.Join(templateNames, ", "))), Text: fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(strings.Join(templateNames, ", "))),
IsConfirm: true, IsConfirm: true,
Default: cliui.ConfirmNo, Default: cliui.ConfirmNo,
@ -88,13 +91,12 @@ func templateDelete() *cobra.Command {
return xerrors.Errorf("delete template %q: %w", template.Name, err) return xerrors.Errorf("delete template %q: %w", template.Name, err)
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Deleted template "+cliui.Styles.Code.Render(template.Name)+" at "+cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))+"!") _, _ = fmt.Fprintln(inv.Stdout, "Deleted template "+cliui.Styles.Code.Render(template.Name)+" at "+cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))+"!")
} }
return nil return nil
}, },
} }
cliui.AllowSkipPrompt(cmd)
return cmd return cmd
} }

View File

@ -27,16 +27,14 @@ func TestTemplateDelete(t *testing.T) {
_ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "templates", "delete", template.Name) inv, root := clitest.New(t, "templates", "delete", template.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error) execDone := make(chan error)
go func() { go func() {
execDone <- cmd.Execute() execDone <- inv.Run()
}() }()
pty.ExpectMatch(fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(template.Name))) pty.ExpectMatch(fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(template.Name)))
@ -65,9 +63,9 @@ func TestTemplateDelete(t *testing.T) {
templateNames = append(templateNames, template.Name) templateNames = append(templateNames, template.Name)
} }
cmd, root := clitest.New(t, append([]string{"templates", "delete", "--yes"}, templateNames...)...) inv, root := clitest.New(t, append([]string{"templates", "delete", "--yes"}, templateNames...)...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
require.NoError(t, cmd.Execute()) require.NoError(t, inv.Run())
for _, template := range templates { for _, template := range templates {
_, err := client.Template(context.Background(), template.ID) _, err := client.Template(context.Background(), template.ID)
@ -92,15 +90,13 @@ func TestTemplateDelete(t *testing.T) {
templateNames = append(templateNames, template.Name) templateNames = append(templateNames, template.Name)
} }
cmd, root := clitest.New(t, append([]string{"templates", "delete"}, templateNames...)...) inv, root := clitest.New(t, append([]string{"templates", "delete"}, templateNames...)...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error) execDone := make(chan error)
go func() { go func() {
execDone <- cmd.Execute() execDone <- inv.Run()
}() }()
pty.ExpectMatch(fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(strings.Join(templateNames, ", ")))) pty.ExpectMatch(fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(strings.Join(templateNames, ", "))))
@ -123,16 +119,14 @@ func TestTemplateDelete(t *testing.T) {
_ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
cmd, root := clitest.New(t, "templates", "delete") inv, root := clitest.New(t, "templates", "delete")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
execDone := make(chan error) execDone := make(chan error)
go func() { go func() {
execDone <- cmd.Execute() execDone <- inv.Run()
}() }()
pty.WriteLine("yes") pty.WriteLine("yes")

View File

@ -5,14 +5,14 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func templateEdit() *cobra.Command { func (r *RootCmd) templateEdit() *clibase.Cmd {
var ( var (
name string name string
displayName string displayName string
@ -22,19 +22,18 @@ func templateEdit() *cobra.Command {
maxTTL time.Duration maxTTL time.Duration
allowUserCancelWorkspaceJobs bool allowUserCancelWorkspaceJobs bool
) )
client := new(codersdk.Client)
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "edit <template> [flags]", Use: "edit <template>",
Args: cobra.ExactArgs(1), Middleware: clibase.Chain(
clibase.RequireNArgs(1),
r.InitClient(client),
),
Short: "Edit the metadata of a template by name.", Short: "Edit the metadata of a template by name.",
RunE: func(cmd *cobra.Command, args []string) error { Handler: func(inv *clibase.Invocation) error {
client, err := CreateClient(cmd)
if err != nil {
return xerrors.Errorf("create client: %w", err)
}
if maxTTL != 0 { if maxTTL != 0 {
entitlements, err := client.Entitlements(cmd.Context()) entitlements, err := client.Entitlements(inv.Context())
var sdkErr *codersdk.Error var sdkErr *codersdk.Error
if xerrors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusNotFound { if xerrors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusNotFound {
return xerrors.Errorf("your deployment appears to be an AGPL deployment, so you cannot set --max-ttl") return xerrors.Errorf("your deployment appears to be an AGPL deployment, so you cannot set --max-ttl")
@ -47,11 +46,11 @@ func templateEdit() *cobra.Command {
} }
} }
organization, err := CurrentOrganization(cmd, client) organization, err := CurrentOrganization(inv, client)
if err != nil { if err != nil {
return xerrors.Errorf("get current organization: %w", err) return xerrors.Errorf("get current organization: %w", err)
} }
template, err := client.TemplateByName(cmd.Context(), organization.ID, args[0]) template, err := client.TemplateByName(inv.Context(), organization.ID, inv.Args[0])
if err != nil { if err != nil {
return xerrors.Errorf("get workspace template: %w", err) return xerrors.Errorf("get workspace template: %w", err)
} }
@ -67,23 +66,54 @@ func templateEdit() *cobra.Command {
AllowUserCancelWorkspaceJobs: allowUserCancelWorkspaceJobs, AllowUserCancelWorkspaceJobs: allowUserCancelWorkspaceJobs,
} }
_, err = client.UpdateTemplateMeta(cmd.Context(), template.ID, req) _, err = client.UpdateTemplateMeta(inv.Context(), template.ID, req)
if err != nil { if err != nil {
return xerrors.Errorf("update template metadata: %w", err) return xerrors.Errorf("update template metadata: %w", err)
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Updated template metadata at %s!\n", cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) _, _ = fmt.Fprintf(inv.Stdout, "Updated template metadata at %s!\n", cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp)))
return nil return nil
}, },
} }
cmd.Flags().StringVarP(&name, "name", "", "", "Edit the template name.") cmd.Options = clibase.OptionSet{
cmd.Flags().StringVarP(&displayName, "display-name", "", "", "Edit the template display name.") {
cmd.Flags().StringVarP(&description, "description", "", "", "Edit the template description.") Flag: "name",
cmd.Flags().StringVarP(&icon, "icon", "", "", "Edit the template icon path.") Description: "Edit the template name.",
cmd.Flags().DurationVarP(&defaultTTL, "default-ttl", "", 0, "Edit the template default time before shutdown - workspaces created from this template default to this value.") Value: clibase.StringOf(&name),
cmd.Flags().DurationVarP(&maxTTL, "max-ttl", "", 0, "Edit the template maximum time before shutdown - workspaces created from this template must shutdown within the given duration after starting. This is an enterprise-only feature.") },
cmd.Flags().BoolVarP(&allowUserCancelWorkspaceJobs, "allow-user-cancel-workspace-jobs", "", true, "Allow users to cancel in-progress workspace jobs.") {
cliui.AllowSkipPrompt(cmd) Flag: "display-name",
Description: "Edit the template display name.",
Value: clibase.StringOf(&displayName),
},
{
Flag: "description",
Description: "Edit the template description.",
Value: clibase.StringOf(&description),
},
{
Flag: "icon",
Description: "Edit the template icon path.",
Value: clibase.StringOf(&icon),
},
{
Flag: "default-ttl",
Description: "Edit the template default time before shutdown - workspaces created from this template default to this value.",
Value: clibase.DurationOf(&defaultTTL),
},
{
Flag: "max-ttl",
Description: "Edit the template maximum time before shutdown - workspaces created from this template must shutdown within the given duration after starting. This is an enterprise-only feature.",
Value: clibase.DurationOf(&maxTTL),
},
{
Flag: "allow-user-cancel-workspace-jobs",
Description: "Allow users to cancel in-progress workspace jobs.",
Default: "true",
Value: clibase.BoolOf(&allowUserCancelWorkspaceJobs),
},
cliui.SkipPromptOption(),
}
return cmd return cmd
} }

View File

@ -55,11 +55,11 @@ func TestTemplateEdit(t *testing.T) {
"--default-ttl", defaultTTL.String(), "--default-ttl", defaultTTL.String(),
"--allow-user-cancel-workspace-jobs=" + strconv.FormatBool(allowUserCancelWorkspaceJobs), "--allow-user-cancel-workspace-jobs=" + strconv.FormatBool(allowUserCancelWorkspaceJobs),
} }
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
@ -92,11 +92,11 @@ func TestTemplateEdit(t *testing.T) {
"--default-ttl", (time.Duration(template.DefaultTTLMillis) * time.Millisecond).String(), "--default-ttl", (time.Duration(template.DefaultTTLMillis) * time.Millisecond).String(),
"--allow-user-cancel-workspace-jobs=" + strconv.FormatBool(template.AllowUserCancelWorkspaceJobs), "--allow-user-cancel-workspace-jobs=" + strconv.FormatBool(template.AllowUserCancelWorkspaceJobs),
} }
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.ErrorContains(t, err, "not modified") require.ErrorContains(t, err, "not modified")
@ -125,11 +125,11 @@ func TestTemplateEdit(t *testing.T) {
"--name", template.Name, "--name", template.Name,
"--display-name", " a-b-c", "--display-name", " a-b-c",
} }
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.Error(t, err, "client call must fail") require.Error(t, err, "client call must fail")
_, isSdkError := codersdk.AsError(err) _, isSdkError := codersdk.AsError(err)
@ -175,11 +175,11 @@ func TestTemplateEdit(t *testing.T) {
"--display-name", displayName, "--display-name", displayName,
"--icon", icon, "--icon", icon,
} }
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
@ -221,11 +221,11 @@ func TestTemplateEdit(t *testing.T) {
"edit", "edit",
template.Name, template.Name,
} }
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
@ -260,11 +260,11 @@ func TestTemplateEdit(t *testing.T) {
template.Name, template.Name,
"--max-ttl", "1h", "--max-ttl", "1h",
} }
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "appears to be an AGPL deployment") require.ErrorContains(t, err, "appears to be an AGPL deployment")
@ -332,11 +332,11 @@ func TestTemplateEdit(t *testing.T) {
template.Name, template.Name,
"--max-ttl", "1h", "--max-ttl", "1h",
} }
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, proxyClient, root) clitest.SetupConfig(t, proxyClient, root)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "license is not entitled") require.ErrorContains(t, err, "license is not entitled")
@ -419,11 +419,11 @@ func TestTemplateEdit(t *testing.T) {
template.Name, template.Name,
"--max-ttl", "1h", "--max-ttl", "1h",
} }
cmd, root := clitest.New(t, cmdArgs...) inv, root := clitest.New(t, cmdArgs...)
clitest.SetupConfig(t, proxyClient, root) clitest.SetupConfig(t, proxyClient, root)
ctx, _ := testutil.Context(t) ctx := testutil.Context(t, testutil.WaitLong)
err = cmd.ExecuteContext(ctx) err = inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
require.EqualValues(t, 1, atomic.LoadInt64(&updateTemplateCalled)) require.EqualValues(t, 1, atomic.LoadInt64(&updateTemplateCalled))

View File

@ -6,19 +6,19 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/spf13/cobra" "github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"github.com/coder/coder/examples" "github.com/coder/coder/examples"
"github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk"
) )
func templateInit() *cobra.Command { func (*RootCmd) templateInit() *clibase.Cmd {
return &cobra.Command{ return &clibase.Cmd{
Use: "init [directory]", Use: "init [directory]",
Short: "Get started with a templated template.", Short: "Get started with a templated template.",
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.RequireRangeArgs(0, 1),
Handler: func(inv *clibase.Invocation) error {
exampleList, err := examples.List() exampleList, err := examples.List()
if err != nil { if err != nil {
return err return err
@ -36,10 +36,10 @@ func templateInit() *cobra.Command {
exampleByName[name] = example exampleByName[name] = example
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Wrap.Render( _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Wrap.Render(
"A template defines infrastructure as code to be provisioned "+ "A template defines infrastructure as code to be provisioned "+
"for individual developer workspaces. Select an example to be copied to the active directory:\n")) "for individual developer workspaces. Select an example to be copied to the active directory:\n"))
option, err := cliui.Select(cmd, cliui.SelectOptions{ option, err := cliui.Select(inv, cliui.SelectOptions{
Options: exampleNames, Options: exampleNames,
}) })
if err != nil { if err != nil {
@ -55,8 +55,8 @@ func templateInit() *cobra.Command {
return err return err
} }
var directory string var directory string
if len(args) > 0 { if len(inv.Args) > 0 {
directory = args[0] directory = inv.Args[0]
} else { } else {
directory = filepath.Join(workingDir, selectedTemplate.ID) directory = filepath.Join(workingDir, selectedTemplate.ID)
} }
@ -66,7 +66,7 @@ func templateInit() *cobra.Command {
} else { } else {
relPath = "./" + relPath relPath = "./" + relPath
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Extracting %s to %s...\n", cliui.Styles.Field.Render(selectedTemplate.ID), relPath) _, _ = fmt.Fprintf(inv.Stdout, "Extracting %s to %s...\n", cliui.Styles.Field.Render(selectedTemplate.ID), relPath)
err = os.MkdirAll(directory, 0o700) err = os.MkdirAll(directory, 0o700)
if err != nil { if err != nil {
return err return err
@ -75,9 +75,9 @@ func templateInit() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Create your template by running:") _, _ = fmt.Fprintln(inv.Stdout, "Create your template by running:")
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render(cliui.Styles.Code.Render("cd "+relPath+" && coder templates create"))+"\n") _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render(cliui.Styles.Code.Render("cd "+relPath+" && coder templates create"))+"\n")
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Wrap.Render("Examples provide a starting point and are expected to be edited! 🎨")) _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Wrap.Render("Examples provide a starting point and are expected to be edited! 🎨"))
return nil return nil
}, },
} }

View File

@ -15,12 +15,9 @@ func TestTemplateInit(t *testing.T) {
t.Run("Extract", func(t *testing.T) { t.Run("Extract", func(t *testing.T) {
t.Parallel() t.Parallel()
tempDir := t.TempDir() tempDir := t.TempDir()
cmd, _ := clitest.New(t, "templates", "init", tempDir) inv, _ := clitest.New(t, "templates", "init", tempDir)
pty := ptytest.New(t) ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input()) clitest.Run(t, inv)
cmd.SetOut(pty.Output())
err := cmd.Execute()
require.NoError(t, err)
files, err := os.ReadDir(tempDir) files, err := os.ReadDir(tempDir)
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, len(files), 0) require.Greater(t, len(files), 0)

View File

@ -4,52 +4,53 @@ import (
"fmt" "fmt"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/spf13/cobra"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
) )
func templateList() *cobra.Command { func (r *RootCmd) templateList() *clibase.Cmd {
formatter := cliui.NewOutputFormatter( formatter := cliui.NewOutputFormatter(
cliui.TableFormat([]templateTableRow{}, []string{"name", "last updated", "used by"}), cliui.TableFormat([]templateTableRow{}, []string{"name", "last updated", "used by"}),
cliui.JSONFormat(), cliui.JSONFormat(),
) )
cmd := &cobra.Command{ client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "list", Use: "list",
Short: "List all the templates available for the organization", Short: "List all the templates available for the organization",
Aliases: []string{"ls"}, Aliases: []string{"ls"},
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.Chain(
client, err := CreateClient(cmd) r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
organization, err := CurrentOrganization(inv, client)
if err != nil { if err != nil {
return err return err
} }
organization, err := CurrentOrganization(cmd, client) templates, err := client.TemplatesByOrganization(inv.Context(), organization.ID)
if err != nil {
return err
}
templates, err := client.TemplatesByOrganization(cmd.Context(), organization.ID)
if err != nil { if err != nil {
return err return err
} }
if len(templates) == 0 { if len(templates) == 0 {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "%s No templates found in %s! Create one:\n\n", Caret, color.HiWhiteString(organization.Name)) _, _ = fmt.Fprintf(inv.Stderr, "%s No templates found in %s! Create one:\n\n", Caret, color.HiWhiteString(organization.Name))
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), color.HiMagentaString(" $ coder templates create <directory>\n")) _, _ = fmt.Fprintln(inv.Stderr, color.HiMagentaString(" $ coder templates create <directory>\n"))
return nil return nil
} }
rows := templatesToRows(templates...) rows := templatesToRows(templates...)
out, err := formatter.Format(cmd.Context(), rows) out, err := formatter.Format(inv.Context(), rows)
if err != nil { if err != nil {
return err return err
} }
_, err = fmt.Fprintln(cmd.OutOrStdout(), out) _, err = fmt.Fprintln(inv.Stdout, out)
return err return err
}, },
} }
formatter.AttachFlags(cmd) formatter.AttachOptions(&cmd.Options)
return cmd return cmd
} }

View File

@ -30,19 +30,17 @@ func TestTemplateList(t *testing.T) {
_ = coderdtest.AwaitTemplateVersionJob(t, client, secondVersion.ID) _ = coderdtest.AwaitTemplateVersionJob(t, client, secondVersion.ID)
secondTemplate := coderdtest.CreateTemplate(t, client, user.OrganizationID, secondVersion.ID) secondTemplate := coderdtest.CreateTemplate(t, client, user.OrganizationID, secondVersion.ID)
cmd, root := clitest.New(t, "templates", "list") inv, root := clitest.New(t, "templates", "list")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancelFunc() defer cancelFunc()
errC := make(chan error) errC := make(chan error)
go func() { go func() {
errC <- cmd.ExecuteContext(ctx) errC <- inv.WithContext(ctx).Run()
}() }()
// expect that templates are listed alphabetically // expect that templates are listed alphabetically
@ -67,15 +65,15 @@ func TestTemplateList(t *testing.T) {
_ = coderdtest.AwaitTemplateVersionJob(t, client, secondVersion.ID) _ = coderdtest.AwaitTemplateVersionJob(t, client, secondVersion.ID)
_ = coderdtest.CreateTemplate(t, client, user.OrganizationID, secondVersion.ID) _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, secondVersion.ID)
cmd, root := clitest.New(t, "templates", "list", "--output=json") inv, root := clitest.New(t, "templates", "list", "--output=json")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancelFunc() defer cancelFunc()
out := bytes.NewBuffer(nil) out := bytes.NewBuffer(nil)
cmd.SetOut(out) inv.Stdout = out
err := cmd.ExecuteContext(ctx) err := inv.WithContext(ctx).Run()
require.NoError(t, err) require.NoError(t, err)
var templates []codersdk.Template var templates []codersdk.Template
@ -87,19 +85,19 @@ func TestTemplateList(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{}) client := coderdtest.New(t, &coderdtest.Options{})
coderdtest.CreateFirstUser(t, client) coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.New(t, "templates", "list") inv, root := clitest.New(t, "templates", "list")
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) inv.Stdin = pty.Input()
cmd.SetErr(pty.Output()) inv.Stderr = pty.Output()
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancelFunc() defer cancelFunc()
errC := make(chan error) errC := make(chan error)
go func() { go func() {
errC <- cmd.ExecuteContext(ctx) errC <- inv.WithContext(ctx).Run()
}() }()
require.NoError(t, <-errC) require.NoError(t, <-errC)

18
cli/templateplan.go Normal file
View File

@ -0,0 +1,18 @@
package cli
import (
"github.com/coder/coder/cli/clibase"
)
func (*RootCmd) templatePlan() *clibase.Cmd {
return &clibase.Cmd{
Use: "plan <directory>",
Middleware: clibase.Chain(
clibase.RequireNArgs(1),
),
Short: "Plan a template push from the current directory",
Handler: func(inv *clibase.Invocation) error {
return nil
},
}
}

View File

@ -7,37 +7,37 @@ import (
"sort" "sort"
"github.com/codeclysm/extract" "github.com/codeclysm/extract"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func templatePull() *cobra.Command { func (r *RootCmd) templatePull() *clibase.Cmd {
var tarMode bool var tarMode bool
cmd := &cobra.Command{
client := new(codersdk.Client)
cmd := &clibase.Cmd{
Use: "pull <name> [destination]", Use: "pull <name> [destination]",
Short: "Download the latest version of a template to a path.", Short: "Download the latest version of a template to a path.",
Args: cobra.RangeArgs(1, 2), Middleware: clibase.Chain(
RunE: func(cmd *cobra.Command, args []string) error { clibase.RequireRangeArgs(1, 2),
r.InitClient(client),
),
Handler: func(inv *clibase.Invocation) error {
var ( var (
ctx = cmd.Context() ctx = inv.Context()
templateName = args[0] templateName = inv.Args[0]
dest string dest string
) )
if len(args) > 1 { if len(inv.Args) > 1 {
dest = args[1] dest = inv.Args[1]
}
client, err := CreateClient(cmd)
if err != nil {
return xerrors.Errorf("create client: %w", err)
} }
// TODO(JonA): Do we need to add a flag for organization? // TODO(JonA): Do we need to add a flag for organization?
organization, err := CurrentOrganization(cmd, client) organization, err := CurrentOrganization(inv, client)
if err != nil { if err != nil {
return xerrors.Errorf("current organization: %w", err) return xerrors.Errorf("current organization: %w", err)
} }
@ -78,7 +78,7 @@ func templatePull() *cobra.Command {
} }
if tarMode { if tarMode {
_, err = cmd.OutOrStdout().Write(raw) _, err = inv.Stdout.Write(raw)
return err return err
} }
@ -97,7 +97,7 @@ func templatePull() *cobra.Command {
} }
if len(ents) > 0 { if len(ents) > 0 {
_, err = cliui.Prompt(cmd, cliui.PromptOptions{ _, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Directory %q is not empty, existing files may be overwritten.\nContinue extracting?", dest), Text: fmt.Sprintf("Directory %q is not empty, existing files may be overwritten.\nContinue extracting?", dest),
Default: "No", Default: "No",
Secret: false, Secret: false,
@ -108,14 +108,21 @@ func templatePull() *cobra.Command {
} }
} }
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Extracting template to %q\n", dest) _, _ = fmt.Fprintf(inv.Stderr, "Extracting template to %q\n", dest)
err = extract.Tar(ctx, bytes.NewReader(raw), dest, nil) err = extract.Tar(ctx, bytes.NewReader(raw), dest, nil)
return err return err
}, },
} }
cmd.Flags().BoolVar(&tarMode, "tar", false, "output the template as a tar archive to stdout") cmd.Options = clibase.OptionSet{
cliui.AllowSkipPrompt(cmd) {
Description: "Output the template as a tar archive to stdout.",
Flag: "tar",
Value: clibase.BoolOf(&tarMode),
},
cliui.SkipPromptOption(),
}
return cmd return cmd
} }

View File

@ -46,8 +46,8 @@ func TestTemplatePull(t *testing.T) {
t.Run("NoName", func(t *testing.T) { t.Run("NoName", func(t *testing.T) {
t.Parallel() t.Parallel()
cmd, _ := clitest.New(t, "templates", "pull") inv, _ := clitest.New(t, "templates", "pull")
err := cmd.Execute() err := inv.Run()
require.Error(t, err) require.Error(t, err)
}) })
@ -77,13 +77,13 @@ func TestTemplatePull(t *testing.T) {
// are being sorted correctly. // are being sorted correctly.
_ = coderdtest.UpdateTemplateVersion(t, client, user.OrganizationID, source2, template.ID) _ = coderdtest.UpdateTemplateVersion(t, client, user.OrganizationID, source2, template.ID)
cmd, root := clitest.New(t, "templates", "pull", "--tar", template.Name) inv, root := clitest.New(t, "templates", "pull", "--tar", template.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
var buf bytes.Buffer var buf bytes.Buffer
cmd.SetOut(&buf) inv.Stdout = &buf
err = cmd.Execute() err = inv.Run()
require.NoError(t, err) require.NoError(t, err)
require.True(t, bytes.Equal(expected, buf.Bytes()), "tar files differ") require.True(t, bytes.Equal(expected, buf.Bytes()), "tar files differ")
@ -124,20 +124,12 @@ func TestTemplatePull(t *testing.T) {
err = extract.Tar(ctx, bytes.NewReader(expected), expectedDest, nil) err = extract.Tar(ctx, bytes.NewReader(expected), expectedDest, nil)
require.NoError(t, err) require.NoError(t, err)
cmd, root := clitest.New(t, "templates", "pull", template.Name, actualDest) inv, root := clitest.New(t, "templates", "pull", template.Name, actualDest)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
errChan := make(chan error) require.NoError(t, inv.Run())
go func() {
defer close(errChan)
errChan <- cmd.Execute()
}()
require.NoError(t, <-errChan)
require.Equal(t, require.Equal(t,
dirSum(t, expectedDest), dirSum(t, expectedDest),
@ -190,23 +182,17 @@ func TestTemplatePull(t *testing.T) {
err = extract.Tar(ctx, bytes.NewReader(expected), expectedDest, nil) err = extract.Tar(ctx, bytes.NewReader(expected), expectedDest, nil)
require.NoError(t, err) require.NoError(t, err)
cmd, root := clitest.New(t, "templates", "pull", template.Name, conflictDest) inv, root := clitest.New(t, "templates", "pull", template.Name, conflictDest)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
pty := ptytest.New(t) pty := ptytest.New(t).Attach(inv)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
errChan := make(chan error) waiter := clitest.StartWithWaiter(t, inv)
go func() {
defer close(errChan)
errChan <- cmd.Execute()
}()
pty.ExpectMatch("not empty") pty.ExpectMatch("not empty")
pty.WriteLine("no") pty.WriteLine("no")
require.Error(t, <-errChan) waiter.RequireError()
ents, err := os.ReadDir(conflictDest) ents, err := os.ReadDir(conflictDest)
require.NoError(t, err) require.NoError(t, err)

View File

@ -4,15 +4,13 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io" "io"
"os"
"path/filepath" "path/filepath"
"time" "time"
"github.com/briandowns/spinner" "github.com/briandowns/spinner"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
@ -24,22 +22,27 @@ type templateUploadFlags struct {
directory string directory string
} }
func (pf *templateUploadFlags) register(f *pflag.FlagSet) { func (pf *templateUploadFlags) option() clibase.Option {
currentDirectory, _ := os.Getwd() return clibase.Option{
f.StringVarP(&pf.directory, "directory", "d", currentDirectory, "Specify the directory to create from, use '-' to read tar from stdin") Flag: "directory",
FlagShorthand: "d",
Description: "Specify the directory to create from, use '-' to read tar from stdin.",
Default: ".",
Value: clibase.StringOf(&pf.directory),
}
} }
func (pf *templateUploadFlags) stdin() bool { func (pf *templateUploadFlags) stdin() bool {
return pf.directory == "-" return pf.directory == "-"
} }
func (pf *templateUploadFlags) upload(cmd *cobra.Command, client *codersdk.Client) (*codersdk.UploadResponse, error) { func (pf *templateUploadFlags) upload(inv *clibase.Invocation, client *codersdk.Client) (*codersdk.UploadResponse, error) {
var content io.Reader var content io.Reader
if pf.stdin() { if pf.stdin() {
content = cmd.InOrStdin() content = inv.Stdin
} else { } else {
prettyDir := prettyDirectoryPath(pf.directory) prettyDir := prettyDirectoryPath(pf.directory)
_, err := cliui.Prompt(cmd, cliui.PromptOptions{ _, err := cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Upload %q?", prettyDir), Text: fmt.Sprintf("Upload %q?", prettyDir),
IsConfirm: true, IsConfirm: true,
Default: cliui.ConfirmYes, Default: cliui.ConfirmYes,
@ -58,12 +61,12 @@ func (pf *templateUploadFlags) upload(cmd *cobra.Command, client *codersdk.Clien
} }
spin := spinner.New(spinner.CharSets[5], 100*time.Millisecond) spin := spinner.New(spinner.CharSets[5], 100*time.Millisecond)
spin.Writer = cmd.OutOrStdout() spin.Writer = inv.Stdout
spin.Suffix = cliui.Styles.Keyword.Render(" Uploading directory...") spin.Suffix = cliui.Styles.Keyword.Render(" Uploading directory...")
spin.Start() spin.Start()
defer spin.Stop() defer spin.Stop()
resp, err := client.Upload(cmd.Context(), codersdk.ContentTypeTar, bufio.NewReader(content)) resp, err := client.Upload(inv.Context(), codersdk.ContentTypeTar, bufio.NewReader(content))
if err != nil { if err != nil {
return nil, xerrors.Errorf("upload: %w", err) return nil, xerrors.Errorf("upload: %w", err)
} }
@ -79,14 +82,14 @@ func (pf *templateUploadFlags) templateName(args []string) (string, error) {
return args[0], nil return args[0], nil
} }
name := filepath.Base(pf.directory)
if len(args) > 0 { if len(args) > 0 {
name = args[0] return args[0], nil
} }
return name, nil // If no name is provided, use the directory name.
return filepath.Base(pf.directory), nil
} }
func templatePush() *cobra.Command { func (r *RootCmd) templatePush() *clibase.Cmd {
var ( var (
versionName string versionName string
provisioner string provisioner string
@ -97,32 +100,31 @@ func templatePush() *cobra.Command {
provisionerTags []string provisionerTags []string
uploadFlags templateUploadFlags uploadFlags templateUploadFlags
) )
client := new(codersdk.Client)
cmd := &cobra.Command{ cmd := &clibase.Cmd{
Use: "push [template]", Use: "push [template]",
Args: cobra.MaximumNArgs(1),
Short: "Push a new template version from the current directory or as specified by flag", Short: "Push a new template version from the current directory or as specified by flag",
RunE: func(cmd *cobra.Command, args []string) error { Middleware: clibase.Chain(
client, err := CreateClient(cmd) clibase.RequireRangeArgs(0, 1),
if err != nil { r.InitClient(client),
return err ),
} Handler: func(inv *clibase.Invocation) error {
organization, err := CurrentOrganization(cmd, client) organization, err := CurrentOrganization(inv, client)
if err != nil { if err != nil {
return err return err
} }
name, err := uploadFlags.templateName(args) name, err := uploadFlags.templateName(inv.Args)
if err != nil { if err != nil {
return err return err
} }
template, err := client.TemplateByName(cmd.Context(), organization.ID, name) template, err := client.TemplateByName(inv.Context(), organization.ID, name)
if err != nil { if err != nil {
return err return err
} }
resp, err := uploadFlags.upload(cmd, client) resp, err := uploadFlags.upload(inv, client)
if err != nil { if err != nil {
return err return err
} }
@ -132,7 +134,7 @@ func templatePush() *cobra.Command {
return err return err
} }
job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{ job, _, err := createValidTemplateVersion(inv, createValidTemplateVersionArgs{
Name: versionName, Name: versionName,
Client: client, Client: client,
Organization: organization, Organization: organization,
@ -153,32 +155,60 @@ func templatePush() *cobra.Command {
return xerrors.Errorf("job failed: %s", job.Job.Status) return xerrors.Errorf("job failed: %s", job.Job.Status)
} }
err = client.UpdateActiveTemplateVersion(cmd.Context(), template.ID, codersdk.UpdateActiveTemplateVersion{ err = client.UpdateActiveTemplateVersion(inv.Context(), template.ID, codersdk.UpdateActiveTemplateVersion{
ID: job.ID, ID: job.ID,
}) })
if err != nil { if err != nil {
return err return err
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Updated version at %s!\n", cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) _, _ = fmt.Fprintf(inv.Stdout, "Updated version at %s!\n", cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp)))
return nil return nil
}, },
} }
cmd.Flags().StringVarP(&provisioner, "test.provisioner", "", "terraform", "Customize the provisioner backend") cmd.Options = clibase.OptionSet{
cmd.Flags().StringVarP(&parameterFile, "parameter-file", "", "", "Specify a file path with parameter values.") {
cmd.Flags().StringVarP(&variablesFile, "variables-file", "", "", "Specify a file path with values for Terraform-managed variables.") Flag: "test.provisioner",
cmd.Flags().StringArrayVarP(&variables, "variable", "", []string{}, "Specify a set of values for Terraform-managed variables.") FlagShorthand: "p",
cmd.Flags().StringVarP(&versionName, "name", "", "", "Specify a name for the new template version. It will be automatically generated if not provided.") Description: "Customize the provisioner backend.",
cmd.Flags().StringArrayVarP(&provisionerTags, "provisioner-tag", "", []string{}, "Specify a set of tags to target provisioner daemons.") Default: "terraform",
cmd.Flags().BoolVar(&alwaysPrompt, "always-prompt", false, "Always prompt all parameters. Does not pull parameter values from active template version") Value: clibase.StringOf(&provisioner),
uploadFlags.register(cmd.Flags())
cliui.AllowSkipPrompt(cmd)
// This is for testing! // This is for testing!
err := cmd.Flags().MarkHidden("test.provisioner") Hidden: true,
if err != nil { },
panic(err) {
Flag: "parameter-file",
Description: "Specify a file path with parameter values.",
Value: clibase.StringOf(&parameterFile),
},
{
Flag: "variables-file",
Description: "Specify a file path with values for Terraform-managed variables.",
Value: clibase.StringOf(&variablesFile),
},
{
Flag: "variable",
Description: "Specify a set of values for Terraform-managed variables.",
Value: clibase.StringArrayOf(&variables),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: clibase.StringArrayOf(&provisionerTags),
},
{
Flag: "name",
Description: "Specify a name for the new template version. It will be automatically generated if not provided.",
Value: clibase.StringOf(&versionName),
},
{
Flag: "always-prompt",
Description: "Always prompt all parameters. Does not pull parameter values from active template version.",
Value: clibase.BoolOf(&alwaysPrompt),
},
cliui.SkipPromptOption(),
uploadFlags.option(),
} }
return cmd return cmd
} }

Some files were not shown because too many files have changed in this diff Show More