coder/cli/clibase/cmd.go

548 lines
13 KiB
Go

package clibase
import (
"context"
"errors"
"flag"
"fmt"
"io"
"os"
"strings"
"unicode"
"github.com/spf13/pflag"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
)
// Cmd describes an executable command.
//
// Commands form a tree through Parent and Children. AddSubcommands links
// children to their parent and keeps both fields in sync.
type Cmd struct {
// Parent is the direct parent of the command.
Parent *Cmd
// Children is a list of direct descendants.
Children []*Cmd
// Use is provided in form "command [flags] [args...]".
// The text before the first space is the command's name (see Name).
Use string
// Aliases is a list of alternative names for the command.
Aliases []string
// Short is a one-line description of the command.
Short string
// Hidden determines whether the command should be hidden from help.
Hidden bool
// RawArgs determines whether the command should receive unparsed arguments.
// No flags are parsed when set, and the command is responsible for parsing
// its own flags.
RawArgs bool
// Long is a detailed description of the command,
// presented on its help page. It may contain examples.
Long string
// Options is the set of options the command accepts. When the command runs,
// defaults, environment values and flags are all resolved from it.
Options OptionSet
// Annotations holds additional metadata attached to the command.
// NOTE(review): the Annotations type is declared elsewhere; confirm its
// exact semantics there.
Annotations Annotations
// Middleware is called before the Handler.
// Use Chain() to combine multiple middlewares.
Middleware MiddlewareFunc
// Handler is the function executed when the command is run.
Handler HandlerFunc
// HelpHandler is called instead of Handler when Handler is nil or when
// flag parsing reports pflag.ErrHelp.
HelpHandler HandlerFunc
}
// AddSubcommands registers each of the given commands as a child of c:
// the command's Parent is pointed at c and it is appended to c.Children,
// in the order given.
func (c *Cmd) AddSubcommands(cmds ...*Cmd) {
	for idx := 0; idx < len(cmds); idx++ {
		sub := cmds[idx]
		sub.Parent = c
		c.Children = append(c.Children, sub)
	}
}
// Walk visits the command tree rooted at c in pre-order: fn is called for c
// first and then, recursively, for each child in Children order. As a side
// effect every visited child's Parent is (re)set to the command it was
// reached from.
func (c *Cmd) Walk(fn func(*Cmd)) {
	fn(c)
	// Snapshot the slice after fn(c) so the set of children visited is fixed
	// before descending.
	kids := c.Children
	for idx := 0; idx < len(kids); idx++ {
		kid := kids[idx]
		kid.Parent = c
		kid.Walk(fn)
	}
}
// PrepareAll performs initialization and linting on the command and all its children.
//
// Initialization sorts c.Options by Flag and c.Children by Name, gives every
// unnamed option a Name taken from its Flag, Env or YAML field (the first
// non-empty one, in that order), and sets each child's Parent to c.
//
// Linting requires that the command has a Use string, that every option can
// be named, and that every non-empty option description does not start with
// a lowercase letter and ends with a period.
//
// All problems found, including those reported by descendants, are joined
// into the returned error; nil means the whole tree passed. A command with an
// empty Use returns immediately without inspecting its options or children.
func (c *Cmd) PrepareAll() error {
	if c.Use == "" {
		return xerrors.New("command must have a Use field so that it has a name")
	}
	var merr error

	slices.SortFunc(c.Options, func(a, b Option) bool {
		return a.Flag < b.Flag
	})
	// Index into the slice instead of ranging by value: the defaulted Name
	// must be stored on the command's own option, not on a per-iteration copy
	// that is discarded at the end of the loop body.
	for idx := range c.Options {
		opt := &c.Options[idx]
		if opt.Name == "" {
			switch {
			case opt.Flag != "":
				opt.Name = opt.Flag
			case opt.Env != "":
				opt.Name = opt.Env
			case opt.YAML != "":
				opt.Name = opt.YAML
			default:
				merr = errors.Join(merr, xerrors.Errorf("option must have a Name, Flag, Env or YAML field"))
			}
		}
		if opt.Description == "" {
			continue
		}
		// Enforce that description uses sentence form.
		// Only the first byte is inspected, so this check is effectively
		// limited to descriptions that start with an ASCII character.
		if unicode.IsLower(rune(opt.Description[0])) {
			merr = errors.Join(merr, xerrors.Errorf("option %q description should start with a capital letter", opt.Name))
		}
		if !strings.HasSuffix(opt.Description, ".") {
			merr = errors.Join(merr, xerrors.Errorf("option %q description should end with a period", opt.Name))
		}
	}

	slices.SortFunc(c.Children, func(a, b *Cmd) bool {
		return a.Name() < b.Name()
	})
	for _, child := range c.Children {
		child.Parent = c
		if err := child.PrepareAll(); err != nil {
			merr = errors.Join(merr, xerrors.Errorf("command %v: %w", child.Name(), err))
		}
	}
	return merr
}
// Name returns the command's name, which is everything in Use that precedes
// the first space (or all of Use when it contains no space).
func (c *Cmd) Name() string {
	name, _, _ := strings.Cut(c.Use, " ")
	return name
}
// FullName returns the command's name prefixed by the names of all of its
// ancestors, root first and separated by single spaces — i.e. the way the
// command is spelled on the command line.
func (c *Cmd) FullName() string {
	if c.Parent == nil {
		return c.Name()
	}
	return strings.Join([]string{c.Parent.FullName(), c.Name()}, " ")
}
// FullUsage returns the command's Use string preceded by the full name of
// its parent (see FullName), separated by a single space. A command without
// a parent returns Use unchanged.
func (c *Cmd) FullUsage() string {
	if c.Parent == nil {
		return c.Use
	}
	return strings.Join([]string{c.Parent.FullName(), c.Use}, " ")
}
// Invoke builds a new Invocation of the command with the given arguments.
// Output written to Stdout and Stderr is discarded and Stdin is an empty
// reader.
//
// The returned invocation is not live until Run() is called.
func (c *Cmd) Invoke(args ...string) *Invocation {
	inv := new(Invocation)
	inv.Command = c
	inv.Args = args
	inv.Stdin = strings.NewReader("")
	inv.Stdout, inv.Stderr = io.Discard, io.Discard
	return inv
}
// Invocation represents an instance of a command being executed.
type Invocation struct {
// ctx is the invocation's context. When nil, Context() falls back to
// context.Background().
ctx context.Context
// Command is the command being invoked. While running, it is advanced to
// the subcommand selected by the arguments.
Command *Cmd
// parsedFlags accumulates the flags of every command on the path to the one
// being run. It is nil until the invocation has been run (see ParsedFlags).
parsedFlags *pflag.FlagSet
// Args holds the command-line arguments. Run starts from the full list; by
// the time a handler is called it has been narrowed to the arguments that
// belong to the selected command.
Args []string
// Environ is a list of environment variables. Use EnvsWithPrefix to parse
// os.Environ.
Environ Environ
// Stdout, Stderr and Stdin are the standard streams made available to the
// command.
Stdout io.Writer
Stderr io.Writer
Stdin io.Reader
}
// WithOS returns a copy of the invocation wired to the current process, as
// needed when running from a main package: Stdout, Stderr and Stdin are set
// to the os streams, Args to os.Args without the program name, and Environ
// to the parsed process environment. Any values previously set on those
// fields are replaced in the copy; the receiver is left untouched.
func (i *Invocation) WithOS() *Invocation {
	useProcess := func(inv *Invocation) {
		inv.Stdout, inv.Stderr, inv.Stdin = os.Stdout, os.Stderr, os.Stdin
		inv.Args = os.Args[1:]
		inv.Environ = ParseEnviron(os.Environ(), "")
	}
	return i.with(useProcess)
}
// Context returns the context attached to the invocation, or
// context.Background() when none has been set.
func (i *Invocation) Context() context.Context {
	if ctx := i.ctx; ctx != nil {
		return ctx
	}
	return context.Background()
}
// ParsedFlags returns the flag set built while running the invocation.
// It panics when no flag set exists yet, which is the case before Run() has
// been called.
func (i *Invocation) ParsedFlags() *pflag.FlagSet {
	if fs := i.parsedFlags; fs != nil {
		return fs
	}
	panic("flags not parsed, has Run() been called?")
}
// runState carries the state shared by the recursive calls of
// Invocation.run as it descends from the root command to the selected
// subcommand.
type runState struct {
// allArgs is the complete argument list the invocation was started with.
// It is re-parsed in full at every command level.
allArgs []string
// commandDepth counts the subcommand names consumed so far; it is the index
// into the parsed positional arguments where the next subcommand name (or
// the first real argument) is expected.
commandDepth int
// flagParseErr is the error from the most recent flag parse. It is only
// acted upon once no further child command matches.
flagParseErr error
}
// copyFlagSetWithout returns a new, unnamed flag set (ContinueOnError, no-op
// Usage) containing every flag of fs except the one named without. The flags
// themselves are shared with fs, not duplicated.
func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
	filtered := pflag.NewFlagSet("", pflag.ContinueOnError)
	filtered.Usage = func() {}
	fs.VisitAll(func(fl *pflag.Flag) {
		if fl.Name != without {
			filtered.AddFlag(fl)
		}
	})
	return filtered
}
// run recursively executes the command and its children.
// allArgs is wired through the stack so that global flags can be accepted
// anywhere in the command invocation.
//
// At each level it applies option defaults, parses the environment, merges
// the current command's flags into i.parsedFlags and (unless RawArgs is set)
// parses state.allArgs. If the positional argument at state.commandDepth
// names a child command or one of its aliases, run recurses into that child.
// Otherwise the current command's Handler is called through its Middleware,
// or its HelpHandler when there is no Handler or help was requested.
//
// Errors returned by the Handler are wrapped in *RunCommandError; errors from
// the HelpHandler and from the setup steps are returned as they are produced
// here.
func (i *Invocation) run(state *runState) error {
err := i.Command.Options.SetDefaults()
if err != nil {
return xerrors.Errorf("setting defaults: %w", err)
}
// If we set the Default of an array but later see a flag for it, we
// don't want to append, we want to replace. So, we need to keep the state
// of defaulted array options.
// The recorded length is taken after defaults are applied and before the
// environment is parsed.
defaultedArrays := make(map[string]int)
for _, opt := range i.Command.Options {
sv, ok := opt.Value.(pflag.SliceValue)
if !ok {
continue
}
if opt.Flag == "" {
continue
}
defaultedArrays[opt.Flag] = len(sv.GetSlice())
}
err = i.Command.Options.ParseEnv(i.Environ)
if err != nil {
return xerrors.Errorf("parsing env: %w", err)
}
// Now the fun part, argument parsing!
// Index the children by name and by every alias; a name claimed twice is an
// error.
children := make(map[string]*Cmd)
for _, child := range i.Command.Children {
child.Parent = i.Command
for _, name := range append(child.Aliases, child.Name()) {
if _, ok := children[name]; ok {
return xerrors.Errorf("duplicate command name: %s", name)
}
children[name] = child
}
}
// The flag set is created once, at the first level, and then extended as
// run descends into subcommands.
if i.parsedFlags == nil {
i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
// We handle Usage ourselves.
i.parsedFlags.Usage = func() {}
}
// If we find a duplicate flag, we want the deeper command's flag to override
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
// have to create a copy of the flagset without a value.
i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
if i.parsedFlags.Lookup(f.Name) != nil {
i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name)
}
i.parsedFlags.AddFlag(f)
})
var parsedArgs []string
if !i.Command.RawArgs {
// Flag parsing will fail on intermediate commands in the command tree,
// so we check the error after looking for a child command.
state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
parsedArgs = i.parsedFlags.Args()
i.parsedFlags.VisitAll(func(f *pflag.Flag) {
// NOTE: inside this closure i shadows the Invocation receiver; it is the
// slice length recorded in defaultedArrays for this flag.
i, ok := defaultedArrays[f.Name]
if !ok {
return
}
if !f.Changed {
return
}
sv, ok := f.Value.(pflag.SliceValue)
if !ok {
panic("defaulted array option is not a slice value")
}
// The flag was given explicitly: drop the first i elements so that only
// the values that came after the recorded defaults remain.
err := sv.Replace(sv.GetSlice()[i:])
if err != nil {
panic(err)
}
})
}
// Run child command if found (next child only)
// We must do subcommand detection after flag parsing so we don't mistake flag
// values for subcommand names.
if len(parsedArgs) > state.commandDepth {
nextArg := parsedArgs[state.commandDepth]
if child, ok := children[nextArg]; ok {
child.Parent = i.Command
i.Command = child
state.commandDepth++
return i.run(state)
}
}
// Flag parse errors are irrelevant for raw args commands.
// pflag.ErrHelp is not reported as an error either; it selects the
// HelpHandler further down.
if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf(
"parsing flags (%v) for %q: %w",
state.allArgs,
i.Command.FullName(), state.flagParseErr,
)
}
if i.Command.RawArgs {
// If we're at the root command, then the name is omitted
// from the arguments, so we can just use the entire slice.
if state.commandDepth == 0 {
i.Args = state.allArgs
} else {
// Locate the command's own name in the raw arguments; everything after it
// is handed to the command untouched.
argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags)
if err != nil {
panic(err)
}
i.Args = state.allArgs[argPos+1:]
}
} else {
// In non-raw-arg mode, we want to skip over flags.
i.Args = parsedArgs[state.commandDepth:]
}
mw := i.Command.Middleware
if mw == nil {
// Chain() without arguments yields a middleware that returns the handler
// unchanged.
mw = Chain()
}
ctx := i.ctx
if ctx == nil {
ctx = context.Background()
}
// The handler runs with a context that is canceled as soon as run returns.
// WithContext returns a copy, so from here on i is no longer the caller's
// Invocation.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
i = i.WithContext(ctx)
// Fall back to help when there is no handler or when help was requested.
// The Middleware is not applied to the HelpHandler.
if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
if i.Command.HelpHandler == nil {
return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
}
return i.Command.HelpHandler(i)
}
err = mw(i.Command.Handler)(i)
if err != nil {
return &RunCommandError{
Cmd: i.Command,
Err: err,
}
}
return nil
}
// RunCommandError pairs an error returned by a command's handler with the
// command that produced it.
type RunCommandError struct {
	Cmd *Cmd
	Err error
}

// Unwrap exposes the underlying handler error so that errors.Is and
// errors.As can see through the wrapper.
func (r *RunCommandError) Unwrap() error {
	return r.Err
}

// Error renders the full name of the failed command, quoted, followed by the
// underlying error formatted with %+v.
func (r *RunCommandError) Error() string {
	cmdName := r.Cmd.FullName()
	return fmt.Sprintf("running command %q: %+v", cmdName, r.Err)
}
// findArg returns the index of the first positional argument in args that
// equals want, skipping over flags and the values they consume.
//
// fs is consulted to decide whether a flag takes its value from the next
// argument: flags written as "--name=value" and flags with a NoOptDefVal
// occupy a single argument, all others also swallow the argument that
// follows. Flags are resolved by long name first and, for single-dash,
// single-letter arguments such as "-v", by shorthand.
//
// An error is returned when a flag is not defined in fs, when a value-taking
// flag is the last argument, or when want does not occur as a positional
// argument.
func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
	for i := 0; i < len(args); i++ {
		arg := args[i]
		if !strings.HasPrefix(arg, "-") {
			if arg == want {
				return i, nil
			}
			continue
		}

		// This is a flag!
		if strings.Contains(arg, "=") {
			// The flag contains the value in the same arg, just skip.
			continue
		}

		// We need to check if NoOptValue is set, then we should not wait
		// for the next arg to be the value.
		name := strings.TrimLeft(arg, "-")
		f := fs.Lookup(name)
		if f == nil && len(name) == 1 && !strings.HasPrefix(arg, "--") {
			// Lookup only knows long names. ShorthandLookup must only be
			// called with a single character, hence the length guard.
			f = fs.ShorthandLookup(name)
		}
		if f == nil {
			return -1, xerrors.Errorf("unknown flag: %s", arg)
		}
		if f.NoOptDefVal != "" {
			continue
		}
		if i == len(args)-1 {
			return -1, xerrors.Errorf("flag %s requires a value", arg)
		}

		// Skip the value.
		i++
	}
	return -1, xerrors.Errorf("arg %s not found", want)
}
// Run executes the command, starting from the invocation's Command and
// descending into the subcommand selected by Args.
// If two commands on that path share a flag name, the flag of the deeper
// command replaces the one of the shallower command.
//
// When running under `go test`, a panic raised during execution is re-raised
// wrapped with the name of the command being run.
//
//nolint:revive
func (i *Invocation) Run() (err error) {
	defer func() {
		// Pflag is panicky, so additional context is helpful in tests.
		underTest := flag.Lookup("test.v") != nil
		if !underTest {
			return
		}
		r := recover()
		if r == nil {
			return
		}
		err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r)
		panic(err)
	}()
	return i.run(&runState{allArgs: i.Args})
}
// WithContext returns a copy of the Invocation whose context is ctx.
// The receiver itself is not modified.
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
	attach := func(cp *Invocation) {
		cp.ctx = ctx
	}
	return i.with(attach)
}
// with makes a shallow copy of the Invocation, lets fn modify the copy, and
// returns it. The receiver's own fields are not reassigned.
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
	clone := new(Invocation)
	*clone = *i
	fn(clone)
	return clone
}
// MiddlewareFunc wraps a HandlerFunc: it receives the next handler in the
// chain and returns the handler that should run in its place, which can do
// its own work (for example argument validation) before calling next.
// NOTE(review): whether a middleware may legitimately return nil is not
// established by the code in this file — confirm before relying on it.
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
// chain folds ms into a single MiddlewareFunc. The middlewares are applied
// to the handler in slice order, so ms[0] wraps the handler first and the
// last element ends up outermost. With no middlewares the handler is
// returned unchanged.
func chain(ms ...MiddlewareFunc) MiddlewareFunc {
	return func(next HandlerFunc) HandlerFunc {
		wrapped := next
		for _, m := range ms {
			wrapped = m(wrapped)
		}
		return wrapped
	}
}
// Chain combines the given middlewares into one MiddlewareFunc that runs
// them in the order they are listed: the first middleware is the outermost
// and the wrapped handler runs last. The ms slice is not modified.
//
//nolint:revive
func Chain(ms ...MiddlewareFunc) MiddlewareFunc {
	// chain applies its arguments innermost-first, so feed it the list
	// back-to-front to get top-to-bottom execution order.
	total := len(ms)
	reversed := make([]MiddlewareFunc, 0, total)
	for idx := total - 1; idx >= 0; idx-- {
		reversed = append(reversed, ms[idx])
	}
	return chain(reversed...)
}
// RequireNArgs returns a Middleware that requires exactly want arguments.
// It is shorthand for RequireRangeArgs with both bounds set to want.
func RequireNArgs(want int) MiddlewareFunc {
	lower, upper := want, want
	return RequireRangeArgs(lower, upper)
}
// RequireRangeArgs returns a Middleware that requires the number of arguments
// to be between start and end (inclusive). If end is -1, then the number of
// arguments must be at least start.
//
// It panics immediately if start is negative. A start greater than a
// non-negative end is a programming error that panics when the returned
// handler is invoked. When the argument count is acceptable the next handler
// is called; otherwise an error describing the mismatch is returned and next
// is not called.
func RequireRangeArgs(start, end int) MiddlewareFunc {
	if start < 0 {
		panic("start must be >= 0")
	}
	return func(next HandlerFunc) HandlerFunc {
		return func(i *Invocation) error {
			got := len(i.Args)
			switch {
			case start == end && got != start:
				if start != 0 {
					return xerrors.Errorf(
						"wanted %v args but got %v %v",
						start,
						got,
						i.Args,
					)
				}
				// got > 0 here, so i.Args[0] exists. For a command that has
				// subcommands, a stray argument is most likely a mistyped
				// subcommand name.
				if len(i.Command.Children) > 0 {
					return xerrors.Errorf("unrecognized subcommand %q", i.Args[0])
				}
				return xerrors.Errorf("wanted no args but got %v %v", got, i.Args)
			case end == -1:
				// Open-ended range. This must also cover start == 0 ("any
				// number of arguments"), which would otherwise fall through
				// to the start > end check below and panic.
				if got < start {
					return xerrors.Errorf(
						"wanted at least %v args but got %v",
						start,
						got,
					)
				}
				return next(i)
			case start > end:
				panic("start must be <= end")
			case got < start || got > end:
				return xerrors.Errorf(
					"wanted between %v and %v args but got %v",
					start, end,
					got,
				)
			default:
				return next(i)
			}
		}
	}
}
// HandlerFunc handles an Invocation of a command.
// It is the signature of both Cmd.Handler and Cmd.HelpHandler. A non-nil
// error returned by a command's Handler is reported to the caller of Run
// wrapped in a *RunCommandError.
type HandlerFunc func(i *Invocation) error