mirror of https://github.com/coder/coder.git
548 lines
13 KiB
Go
548 lines
13 KiB
Go
package clibase
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/spf13/pflag"
|
|
"golang.org/x/exp/slices"
|
|
"golang.org/x/xerrors"
|
|
)
|
|
|
|
// Cmd describes an executable command.
|
|
type Cmd struct {
|
|
// Parent is the direct parent of the command.
|
|
Parent *Cmd
|
|
// Children is a list of direct descendants.
|
|
Children []*Cmd
|
|
// Use is provided in form "command [flags] [args...]".
|
|
Use string
|
|
|
|
// Aliases is a list of alternative names for the command.
|
|
Aliases []string
|
|
|
|
// Short is a one-line description of the command.
|
|
Short string
|
|
|
|
// Hidden determines whether the command should be hidden from help.
|
|
Hidden bool
|
|
|
|
// RawArgs determines whether the command should receive unparsed arguments.
|
|
// No flags are parsed when set, and the command is responsible for parsing
|
|
// its own flags.
|
|
RawArgs bool
|
|
|
|
// Long is a detailed description of the command,
|
|
// presented on its help page. It may contain examples.
|
|
Long string
|
|
Options OptionSet
|
|
Annotations Annotations
|
|
|
|
// Middleware is called before the Handler.
|
|
// Use Chain() to combine multiple middlewares.
|
|
Middleware MiddlewareFunc
|
|
Handler HandlerFunc
|
|
HelpHandler HandlerFunc
|
|
}
|
|
|
|
// AddSubcommands adds the given subcommands, setting their
|
|
// Parent field automatically.
|
|
func (c *Cmd) AddSubcommands(cmds ...*Cmd) {
|
|
for _, cmd := range cmds {
|
|
cmd.Parent = c
|
|
c.Children = append(c.Children, cmd)
|
|
}
|
|
}
|
|
|
|
// Walk calls fn for the command and all its children.
|
|
func (c *Cmd) Walk(fn func(*Cmd)) {
|
|
fn(c)
|
|
for _, child := range c.Children {
|
|
child.Parent = c
|
|
child.Walk(fn)
|
|
}
|
|
}
|
|
|
|
// PrepareAll performs initialization and linting on the command and all its children.
|
|
func (c *Cmd) PrepareAll() error {
|
|
if c.Use == "" {
|
|
return xerrors.New("command must have a Use field so that it has a name")
|
|
}
|
|
var merr error
|
|
|
|
slices.SortFunc(c.Options, func(a, b Option) bool {
|
|
return a.Flag < b.Flag
|
|
})
|
|
for _, opt := range c.Options {
|
|
if opt.Name == "" {
|
|
switch {
|
|
case opt.Flag != "":
|
|
opt.Name = opt.Flag
|
|
case opt.Env != "":
|
|
opt.Name = opt.Env
|
|
case opt.YAML != "":
|
|
opt.Name = opt.YAML
|
|
default:
|
|
merr = errors.Join(merr, xerrors.Errorf("option must have a Name, Flag, Env or YAML field"))
|
|
}
|
|
}
|
|
if opt.Description != "" {
|
|
// Enforce that description uses sentence form.
|
|
if unicode.IsLower(rune(opt.Description[0])) {
|
|
merr = errors.Join(merr, xerrors.Errorf("option %q description should start with a capital letter", opt.Name))
|
|
}
|
|
if !strings.HasSuffix(opt.Description, ".") {
|
|
merr = errors.Join(merr, xerrors.Errorf("option %q description should end with a period", opt.Name))
|
|
}
|
|
}
|
|
}
|
|
slices.SortFunc(c.Children, func(a, b *Cmd) bool {
|
|
return a.Name() < b.Name()
|
|
})
|
|
for _, child := range c.Children {
|
|
child.Parent = c
|
|
err := child.PrepareAll()
|
|
if err != nil {
|
|
merr = errors.Join(merr, xerrors.Errorf("command %v: %w", child.Name(), err))
|
|
}
|
|
}
|
|
return merr
|
|
}
|
|
|
|
// Name returns the first word in the Use string.
|
|
func (c *Cmd) Name() string {
|
|
return strings.Split(c.Use, " ")[0]
|
|
}
|
|
|
|
// FullName returns the full invocation name of the command,
|
|
// as seen on the command line.
|
|
func (c *Cmd) FullName() string {
|
|
var names []string
|
|
if c.Parent != nil {
|
|
names = append(names, c.Parent.FullName())
|
|
}
|
|
names = append(names, c.Name())
|
|
return strings.Join(names, " ")
|
|
}
|
|
|
|
// FullName returns usage of the command, preceded
|
|
// by the usage of its parents.
|
|
func (c *Cmd) FullUsage() string {
|
|
var uses []string
|
|
if c.Parent != nil {
|
|
uses = append(uses, c.Parent.FullName())
|
|
}
|
|
uses = append(uses, c.Use)
|
|
return strings.Join(uses, " ")
|
|
}
|
|
|
|
// Invoke creates a new invocation of the command, with
|
|
// stdio discarded.
|
|
//
|
|
// The returned invocation is not live until Run() is called.
|
|
func (c *Cmd) Invoke(args ...string) *Invocation {
|
|
return &Invocation{
|
|
Command: c,
|
|
Args: args,
|
|
Stdout: io.Discard,
|
|
Stderr: io.Discard,
|
|
Stdin: strings.NewReader(""),
|
|
}
|
|
}
|
|
|
|
// Invocation represents an instance of a command being executed.
|
|
type Invocation struct {
|
|
ctx context.Context
|
|
Command *Cmd
|
|
parsedFlags *pflag.FlagSet
|
|
Args []string
|
|
// Environ is a list of environment variables. Use EnvsWithPrefix to parse
|
|
// os.Environ.
|
|
Environ Environ
|
|
Stdout io.Writer
|
|
Stderr io.Writer
|
|
Stdin io.Reader
|
|
}
|
|
|
|
// WithOS returns the invocation as a main package, filling in the invocation's unset
|
|
// fields with OS defaults.
|
|
func (i *Invocation) WithOS() *Invocation {
|
|
return i.with(func(i *Invocation) {
|
|
i.Stdout = os.Stdout
|
|
i.Stderr = os.Stderr
|
|
i.Stdin = os.Stdin
|
|
i.Args = os.Args[1:]
|
|
i.Environ = ParseEnviron(os.Environ(), "")
|
|
})
|
|
}
|
|
|
|
func (i *Invocation) Context() context.Context {
|
|
if i.ctx == nil {
|
|
return context.Background()
|
|
}
|
|
return i.ctx
|
|
}
|
|
|
|
func (i *Invocation) ParsedFlags() *pflag.FlagSet {
|
|
if i.parsedFlags == nil {
|
|
panic("flags not parsed, has Run() been called?")
|
|
}
|
|
return i.parsedFlags
|
|
}
|
|
|
|
// runState is the bookkeeping shared by every level of the recursive run.
type runState struct {
	// allArgs is the complete argument list the run was started with.
	allArgs []string
	// commandDepth counts the subcommand names matched so far.
	commandDepth int

	// flagParseErr is the result of the most recent flag parse. It is only
	// acted upon once no further subcommand is found.
	flagParseErr error
}
|
|
|
|
func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
|
|
fs2 := pflag.NewFlagSet("", pflag.ContinueOnError)
|
|
fs2.Usage = func() {}
|
|
fs.VisitAll(func(f *pflag.Flag) {
|
|
if f.Name == without {
|
|
return
|
|
}
|
|
fs2.AddFlag(f)
|
|
})
|
|
return fs2
|
|
}
|
|
|
|
// run recursively executes the command and its children.
|
|
// allArgs is wired through the stack so that global flags can be accepted
|
|
// anywhere in the command invocation.
|
|
func (i *Invocation) run(state *runState) error {
|
|
err := i.Command.Options.SetDefaults()
|
|
if err != nil {
|
|
return xerrors.Errorf("setting defaults: %w", err)
|
|
}
|
|
|
|
// If we set the Default of an array but later see a flag for it, we
|
|
// don't want to append, we want to replace. So, we need to keep the state
|
|
// of defaulted array options.
|
|
defaultedArrays := make(map[string]int)
|
|
for _, opt := range i.Command.Options {
|
|
sv, ok := opt.Value.(pflag.SliceValue)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if opt.Flag == "" {
|
|
continue
|
|
}
|
|
|
|
defaultedArrays[opt.Flag] = len(sv.GetSlice())
|
|
}
|
|
|
|
err = i.Command.Options.ParseEnv(i.Environ)
|
|
if err != nil {
|
|
return xerrors.Errorf("parsing env: %w", err)
|
|
}
|
|
|
|
// Now the fun part, argument parsing!
|
|
|
|
children := make(map[string]*Cmd)
|
|
for _, child := range i.Command.Children {
|
|
child.Parent = i.Command
|
|
for _, name := range append(child.Aliases, child.Name()) {
|
|
if _, ok := children[name]; ok {
|
|
return xerrors.Errorf("duplicate command name: %s", name)
|
|
}
|
|
children[name] = child
|
|
}
|
|
}
|
|
|
|
if i.parsedFlags == nil {
|
|
i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
|
|
// We handle Usage ourselves.
|
|
i.parsedFlags.Usage = func() {}
|
|
}
|
|
|
|
// If we find a duplicate flag, we want the deeper command's flag to override
|
|
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
|
|
// have to create a copy of the flagset without a value.
|
|
i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
|
|
if i.parsedFlags.Lookup(f.Name) != nil {
|
|
i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name)
|
|
}
|
|
i.parsedFlags.AddFlag(f)
|
|
})
|
|
|
|
var parsedArgs []string
|
|
|
|
if !i.Command.RawArgs {
|
|
// Flag parsing will fail on intermediate commands in the command tree,
|
|
// so we check the error after looking for a child command.
|
|
state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
|
|
parsedArgs = i.parsedFlags.Args()
|
|
|
|
i.parsedFlags.VisitAll(func(f *pflag.Flag) {
|
|
i, ok := defaultedArrays[f.Name]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if !f.Changed {
|
|
return
|
|
}
|
|
|
|
sv, ok := f.Value.(pflag.SliceValue)
|
|
if !ok {
|
|
panic("defaulted array option is not a slice value")
|
|
}
|
|
err := sv.Replace(sv.GetSlice()[i:])
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Run child command if found (next child only)
|
|
// We must do subcommand detection after flag parsing so we don't mistake flag
|
|
// values for subcommand names.
|
|
if len(parsedArgs) > state.commandDepth {
|
|
nextArg := parsedArgs[state.commandDepth]
|
|
if child, ok := children[nextArg]; ok {
|
|
child.Parent = i.Command
|
|
i.Command = child
|
|
state.commandDepth++
|
|
return i.run(state)
|
|
}
|
|
}
|
|
|
|
// Flag parse errors are irrelevant for raw args commands.
|
|
if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
|
|
return xerrors.Errorf(
|
|
"parsing flags (%v) for %q: %w",
|
|
state.allArgs,
|
|
i.Command.FullName(), state.flagParseErr,
|
|
)
|
|
}
|
|
|
|
if i.Command.RawArgs {
|
|
// If we're at the root command, then the name is omitted
|
|
// from the arguments, so we can just use the entire slice.
|
|
if state.commandDepth == 0 {
|
|
i.Args = state.allArgs
|
|
} else {
|
|
argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
i.Args = state.allArgs[argPos+1:]
|
|
}
|
|
} else {
|
|
// In non-raw-arg mode, we want to skip over flags.
|
|
i.Args = parsedArgs[state.commandDepth:]
|
|
}
|
|
|
|
mw := i.Command.Middleware
|
|
if mw == nil {
|
|
mw = Chain()
|
|
}
|
|
|
|
ctx := i.ctx
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
i = i.WithContext(ctx)
|
|
|
|
if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
|
|
if i.Command.HelpHandler == nil {
|
|
return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
|
|
}
|
|
return i.Command.HelpHandler(i)
|
|
}
|
|
|
|
err = mw(i.Command.Handler)(i)
|
|
if err != nil {
|
|
return &RunCommandError{
|
|
Cmd: i.Command,
|
|
Err: err,
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type RunCommandError struct {
|
|
Cmd *Cmd
|
|
Err error
|
|
}
|
|
|
|
func (e *RunCommandError) Unwrap() error {
|
|
return e.Err
|
|
}
|
|
|
|
func (e *RunCommandError) Error() string {
|
|
return fmt.Sprintf("running command %q: %+v", e.Cmd.FullName(), e.Err)
|
|
}
|
|
|
|
// findArg returns the index of the first occurrence of arg in args, skipping
|
|
// over all flags.
|
|
func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
|
|
for i := 0; i < len(args); i++ {
|
|
arg := args[i]
|
|
if !strings.HasPrefix(arg, "-") {
|
|
if arg == want {
|
|
return i, nil
|
|
}
|
|
continue
|
|
}
|
|
|
|
// This is a flag!
|
|
if strings.Contains(arg, "=") {
|
|
// The flag contains the value in the same arg, just skip.
|
|
continue
|
|
}
|
|
|
|
// We need to check if NoOptValue is set, then we should not wait
|
|
// for the next arg to be the value.
|
|
f := fs.Lookup(strings.TrimLeft(arg, "-"))
|
|
if f == nil {
|
|
return -1, xerrors.Errorf("unknown flag: %s", arg)
|
|
}
|
|
if f.NoOptDefVal != "" {
|
|
continue
|
|
}
|
|
|
|
if i == len(args)-1 {
|
|
return -1, xerrors.Errorf("flag %s requires a value", arg)
|
|
}
|
|
|
|
// Skip the value.
|
|
i++
|
|
}
|
|
|
|
return -1, xerrors.Errorf("arg %s not found", want)
|
|
}
|
|
|
|
// Run executes the command.
|
|
// If two command share a flag name, the first command wins.
|
|
//
|
|
//nolint:revive
|
|
func (i *Invocation) Run() (err error) {
|
|
defer func() {
|
|
// Pflag is panicky, so additional context is helpful in tests.
|
|
if flag.Lookup("test.v") == nil {
|
|
return
|
|
}
|
|
if r := recover(); r != nil {
|
|
err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r)
|
|
panic(err)
|
|
}
|
|
}()
|
|
err = i.run(&runState{
|
|
allArgs: i.Args,
|
|
})
|
|
return err
|
|
}
|
|
|
|
// WithContext returns a copy of the Invocation with the given context.
|
|
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
|
|
return i.with(func(i *Invocation) {
|
|
i.ctx = ctx
|
|
})
|
|
}
|
|
|
|
// with returns a copy of the Invocation with the given function applied.
|
|
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
|
|
i2 := *i
|
|
fn(&i2)
|
|
return &i2
|
|
}
|
|
|
|
// MiddlewareFunc returns the next handler in the chain,
|
|
// or nil if there are no more.
|
|
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
|
|
|
|
func chain(ms ...MiddlewareFunc) MiddlewareFunc {
|
|
return MiddlewareFunc(func(next HandlerFunc) HandlerFunc {
|
|
if len(ms) > 0 {
|
|
return chain(ms[1:]...)(ms[0](next))
|
|
}
|
|
return next
|
|
})
|
|
}
|
|
|
|
// Chain returns a Handler that first calls middleware in order.
|
|
//
|
|
//nolint:revive
|
|
func Chain(ms ...MiddlewareFunc) MiddlewareFunc {
|
|
// We need to reverse the array to provide top-to-bottom execution
|
|
// order when defining a command.
|
|
reversed := make([]MiddlewareFunc, len(ms))
|
|
for i := range ms {
|
|
reversed[len(ms)-1-i] = ms[i]
|
|
}
|
|
return chain(reversed...)
|
|
}
|
|
|
|
func RequireNArgs(want int) MiddlewareFunc {
|
|
return RequireRangeArgs(want, want)
|
|
}
|
|
|
|
// RequireRangeArgs returns a Middleware that requires the number of arguments
|
|
// to be between start and end (inclusive). If end is -1, then the number of
|
|
// arguments must be at least start.
|
|
func RequireRangeArgs(start, end int) MiddlewareFunc {
|
|
if start < 0 {
|
|
panic("start must be >= 0")
|
|
}
|
|
return func(next HandlerFunc) HandlerFunc {
|
|
return func(i *Invocation) error {
|
|
got := len(i.Args)
|
|
switch {
|
|
case start == end && got != start:
|
|
switch start {
|
|
case 0:
|
|
if len(i.Command.Children) > 0 {
|
|
return xerrors.Errorf("unrecognized subcommand %q", i.Args[0])
|
|
}
|
|
return xerrors.Errorf("wanted no args but got %v %v", got, i.Args)
|
|
default:
|
|
return xerrors.Errorf(
|
|
"wanted %v args but got %v %v",
|
|
start,
|
|
got,
|
|
i.Args,
|
|
)
|
|
}
|
|
case start > 0 && end == -1:
|
|
switch {
|
|
case got < start:
|
|
return xerrors.Errorf(
|
|
"wanted at least %v args but got %v",
|
|
start,
|
|
got,
|
|
)
|
|
default:
|
|
return next(i)
|
|
}
|
|
case start > end:
|
|
panic("start must be <= end")
|
|
case got < start || got > end:
|
|
return xerrors.Errorf(
|
|
"wanted between %v and %v args but got %v",
|
|
start, end,
|
|
got,
|
|
)
|
|
default:
|
|
return next(i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// HandlerFunc handles an Invocation of a command.
|
|
type HandlerFunc func(i *Invocation) error
|