mirror of https://github.com/coder/coder.git
chore: remove middleware to request version and entitlement warnings (#12750)
This cleans up `root.go` a bit, adds tests for middleware HTTP transport functions, and removes two HTTP requests we always always performed previously when executing *any* client command. It should improve CLI performance (especially for users with higher latency).
This commit is contained in:
parent
ba3879ac47
commit
03ab37b343
13
cli/login.go
13
cli/login.go
|
@ -18,7 +18,6 @@ import (
|
||||||
|
|
||||||
"github.com/coder/pretty"
|
"github.com/coder/pretty"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/buildinfo"
|
|
||||||
"github.com/coder/coder/v2/cli/cliui"
|
"github.com/coder/coder/v2/cli/cliui"
|
||||||
"github.com/coder/coder/v2/coderd/userpassword"
|
"github.com/coder/coder/v2/coderd/userpassword"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
@ -180,21 +179,11 @@ func (r *RootCmd) login() *serpent.Command {
|
||||||
serverURL.Scheme = "https"
|
serverURL.Scheme = "https"
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := r.createUnauthenticatedClient(ctx, serverURL)
|
client, err := r.createUnauthenticatedClient(ctx, serverURL, inv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to check the version of the server prior to logging in.
|
|
||||||
// It may be useful to warn the user if they are trying to login
|
|
||||||
// on a very old client.
|
|
||||||
err = r.checkVersions(inv, client, buildinfo.Version())
|
|
||||||
if err != nil {
|
|
||||||
// Checking versions isn't a fatal error so we print a warning
|
|
||||||
// and proceed.
|
|
||||||
_, _ = fmt.Fprintln(inv.Stderr, pretty.Sprint(cliui.DefaultStyles.Warn, err.Error()))
|
|
||||||
}
|
|
||||||
|
|
||||||
hasFirstUser, err := client.HasFirstUser(ctx)
|
hasFirstUser, err := client.HasFirstUser(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err)
|
return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err)
|
||||||
|
|
|
@ -97,33 +97,6 @@ func TestLogout(t *testing.T) {
|
||||||
|
|
||||||
<-logoutChan
|
<-logoutChan
|
||||||
})
|
})
|
||||||
t.Run("NoSessionFile", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
pty := ptytest.New(t)
|
|
||||||
config := login(t, pty)
|
|
||||||
|
|
||||||
// Ensure session files exist.
|
|
||||||
require.FileExists(t, string(config.URL()))
|
|
||||||
require.FileExists(t, string(config.Session()))
|
|
||||||
|
|
||||||
err := os.Remove(string(config.Session()))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
logoutChan := make(chan struct{})
|
|
||||||
logout, _ := clitest.New(t, "logout", "--global-config", string(config))
|
|
||||||
|
|
||||||
logout.Stdin = pty.Input()
|
|
||||||
logout.Stdout = pty.Output()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(logoutChan)
|
|
||||||
err = logout.Run()
|
|
||||||
assert.ErrorContains(t, err, "You are not logged in. Try logging in using 'coder login'.")
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-logoutChan
|
|
||||||
})
|
|
||||||
t.Run("CannotDeleteFiles", func(t *testing.T) {
|
t.Run("CannotDeleteFiles", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
362
cli/root.go
362
cli/root.go
|
@ -9,8 +9,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
@ -20,6 +18,7 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/trace"
|
"runtime/trace"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"text/tabwriter"
|
"text/tabwriter"
|
||||||
"time"
|
"time"
|
||||||
|
@ -27,6 +26,7 @@ import (
|
||||||
"github.com/mattn/go-isatty"
|
"github.com/mattn/go-isatty"
|
||||||
"github.com/mitchellh/go-wordwrap"
|
"github.com/mitchellh/go-wordwrap"
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
"golang.org/x/mod/semver"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/coder/pretty"
|
"github.com/coder/pretty"
|
||||||
|
@ -67,8 +67,7 @@ const (
|
||||||
varOrganizationSelect = "organization"
|
varOrganizationSelect = "organization"
|
||||||
varDisableDirect = "disable-direct-connections"
|
varDisableDirect = "disable-direct-connections"
|
||||||
|
|
||||||
notLoggedInMessage = "You are not logged in. Try logging in using 'coder login <url>'."
|
notLoggedInMessage = "You are not logged in. Try logging in using 'coder login <url>'."
|
||||||
notLoggedInURLSavedMessage = "You are not logged in. Try logging in using 'coder login'."
|
|
||||||
|
|
||||||
envNoVersionCheck = "CODER_NO_VERSION_WARNING"
|
envNoVersionCheck = "CODER_NO_VERSION_WARNING"
|
||||||
envNoFeatureWarning = "CODER_NO_FEATURE_WARNING"
|
envNoFeatureWarning = "CODER_NO_FEATURE_WARNING"
|
||||||
|
@ -80,12 +79,7 @@ const (
|
||||||
envURL = "CODER_URL"
|
envURL = "CODER_URL"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
func (r *RootCmd) CoreSubcommands() []*serpent.Command {
|
||||||
errUnauthenticated = xerrors.New(notLoggedInMessage)
|
|
||||||
errUnauthenticatedURLSaved = xerrors.New(notLoggedInURLSavedMessage)
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *RootCmd) Core() []*serpent.Command {
|
|
||||||
// Please re-sort this list alphabetically if you change it!
|
// Please re-sort this list alphabetically if you change it!
|
||||||
return []*serpent.Command{
|
return []*serpent.Command{
|
||||||
r.dotfiles(),
|
r.dotfiles(),
|
||||||
|
@ -134,12 +128,13 @@ func (r *RootCmd) Core() []*serpent.Command {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RootCmd) AGPL() []*serpent.Command {
|
func (r *RootCmd) AGPL() []*serpent.Command {
|
||||||
all := append(r.Core(), r.Server( /* Do not import coderd here. */ nil))
|
all := append(r.CoreSubcommands(), r.Server( /* Do not import coderd here. */ nil))
|
||||||
return all
|
return all
|
||||||
}
|
}
|
||||||
|
|
||||||
// Main is the entrypoint for the Coder CLI.
|
// RunWithSubcommands runs the root command with the given subcommands.
|
||||||
func (r *RootCmd) RunMain(subcommands []*serpent.Command) {
|
// It is abstracted to enable the Enterprise code to add commands.
|
||||||
|
func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) {
|
||||||
// This configuration is not available as a standard option because we
|
// This configuration is not available as a standard option because we
|
||||||
// want to trace the entire program, including Options parsing.
|
// want to trace the entire program, including Options parsing.
|
||||||
goTraceFilePath, ok := os.LookupEnv("CODER_GO_TRACE")
|
goTraceFilePath, ok := os.LookupEnv("CODER_GO_TRACE")
|
||||||
|
@ -156,8 +151,6 @@ func (r *RootCmd) RunMain(subcommands []*serpent.Command) {
|
||||||
defer trace.Stop()
|
defer trace.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
rand.Seed(time.Now().UnixMicro())
|
|
||||||
|
|
||||||
cmd, err := r.Command(subcommands)
|
cmd, err := r.Command(subcommands)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -503,79 +496,20 @@ type RootCmd struct {
|
||||||
noFeatureWarning bool
|
noFeatureWarning bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func addTelemetryHeader(client *codersdk.Client, inv *serpent.Invocation) {
|
// InitClient authenticates the client with files from disk
|
||||||
transport, ok := client.HTTPClient.Transport.(*codersdk.HeaderTransport)
|
// and injects header middlewares for telemetry, authentication,
|
||||||
if !ok {
|
// and version checks.
|
||||||
transport = &codersdk.HeaderTransport{
|
|
||||||
Transport: client.HTTPClient.Transport,
|
|
||||||
Header: http.Header{},
|
|
||||||
}
|
|
||||||
client.HTTPClient.Transport = transport
|
|
||||||
}
|
|
||||||
|
|
||||||
var topts []telemetry.Option
|
|
||||||
for _, opt := range inv.Command.FullOptions() {
|
|
||||||
if opt.ValueSource == serpent.ValueSourceNone || opt.ValueSource == serpent.ValueSourceDefault {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
topts = append(topts, telemetry.Option{
|
|
||||||
Name: opt.Name,
|
|
||||||
ValueSource: string(opt.ValueSource),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
ti := telemetry.Invocation{
|
|
||||||
Command: inv.Command.FullName(),
|
|
||||||
Options: topts,
|
|
||||||
InvokedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
byt, err := json.Marshal(ti)
|
|
||||||
if err != nil {
|
|
||||||
// Should be impossible
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Per https://stackoverflow.com/questions/686217/maximum-on-http-header-values,
|
|
||||||
// we don't want to send headers that are too long.
|
|
||||||
s := base64.StdEncoding.EncodeToString(byt)
|
|
||||||
if len(s) > 4096 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
transport.Header.Add(codersdk.CLITelemetryHeader, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitClient sets client to a new client.
|
|
||||||
// It reads from global configuration files if flags are not set.
|
|
||||||
func (r *RootCmd) InitClient(client *codersdk.Client) serpent.MiddlewareFunc {
|
func (r *RootCmd) InitClient(client *codersdk.Client) serpent.MiddlewareFunc {
|
||||||
return serpent.Chain(
|
|
||||||
r.initClientInternal(client, false),
|
|
||||||
// By default, we should print warnings in addition to initializing the client
|
|
||||||
r.PrintWarnings(client),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RootCmd) InitClientMissingTokenOK(client *codersdk.Client) serpent.MiddlewareFunc {
|
|
||||||
return r.initClientInternal(client, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint: revive
|
|
||||||
func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing bool) serpent.MiddlewareFunc {
|
|
||||||
if client == nil {
|
|
||||||
panic("client is nil")
|
|
||||||
}
|
|
||||||
if r == nil {
|
|
||||||
panic("root is nil")
|
|
||||||
}
|
|
||||||
return func(next serpent.HandlerFunc) serpent.HandlerFunc {
|
return func(next serpent.HandlerFunc) serpent.HandlerFunc {
|
||||||
return func(inv *serpent.Invocation) error {
|
return func(inv *serpent.Invocation) error {
|
||||||
conf := r.createConfig()
|
conf := r.createConfig()
|
||||||
var err error
|
var err error
|
||||||
|
// Read the client URL stored on disk.
|
||||||
if r.clientURL == nil || r.clientURL.String() == "" {
|
if r.clientURL == nil || r.clientURL.String() == "" {
|
||||||
rawURL, err := conf.URL().Read()
|
rawURL, err := conf.URL().Read()
|
||||||
// If the configuration files are absent, the user is logged out
|
// If the configuration files are absent, the user is logged out
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return errUnauthenticated
|
return xerrors.New(notLoggedInMessage)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -586,25 +520,20 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Read the token stored on disk.
|
||||||
if r.token == "" {
|
if r.token == "" {
|
||||||
r.token, err = conf.Session().Read()
|
r.token, err = conf.Session().Read()
|
||||||
// If the configuration files are absent, the user is logged out
|
// Even if there isn't a token, we don't care.
|
||||||
if os.IsNotExist(err) {
|
// Some API routes can be unauthenticated.
|
||||||
if !allowTokenMissing {
|
if err != nil && !os.IsNotExist(err) {
|
||||||
return errUnauthenticatedURLSaved
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = r.setClient(inv.Context(), client, r.clientURL)
|
|
||||||
|
err = r.configureClient(inv.Context(), client, r.clientURL, inv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
addTelemetryHeader(client, inv)
|
|
||||||
|
|
||||||
client.SetSessionToken(r.token)
|
client.SetSessionToken(r.token)
|
||||||
|
|
||||||
if r.debugHTTP {
|
if r.debugHTTP {
|
||||||
|
@ -617,48 +546,8 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RootCmd) PrintWarnings(client *codersdk.Client) serpent.MiddlewareFunc {
|
// HeaderTransport creates a new transport that executes `--header-command`
|
||||||
if client == nil {
|
// if it is set to add headers for all outbound requests.
|
||||||
panic("client is nil")
|
|
||||||
}
|
|
||||||
if r == nil {
|
|
||||||
panic("root is nil")
|
|
||||||
}
|
|
||||||
return func(next serpent.HandlerFunc) serpent.HandlerFunc {
|
|
||||||
return func(inv *serpent.Invocation) error {
|
|
||||||
// We send these requests in parallel to minimize latency.
|
|
||||||
var (
|
|
||||||
versionErr = make(chan error)
|
|
||||||
warningErr = make(chan error)
|
|
||||||
)
|
|
||||||
go func() {
|
|
||||||
versionErr <- r.checkVersions(inv, client, buildinfo.Version())
|
|
||||||
close(versionErr)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
warningErr <- r.checkWarnings(inv, client)
|
|
||||||
close(warningErr)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := <-versionErr; err != nil {
|
|
||||||
// Just log the error here. We never want to fail a command
|
|
||||||
// due to a pre-run.
|
|
||||||
pretty.Fprintf(inv.Stderr, cliui.DefaultStyles.Warn, "check versions error: %s", err)
|
|
||||||
_, _ = fmt.Fprintln(inv.Stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := <-warningErr; err != nil {
|
|
||||||
// Same as above
|
|
||||||
pretty.Fprintf(inv.Stderr, cliui.DefaultStyles.Warn, "check entitlement warnings error: %s", err)
|
|
||||||
_, _ = fmt.Fprintln(inv.Stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return next(inv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) {
|
func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) {
|
||||||
transport := &codersdk.HeaderTransport{
|
transport := &codersdk.HeaderTransport{
|
||||||
Transport: http.DefaultTransport,
|
Transport: http.DefaultTransport,
|
||||||
|
@ -700,22 +589,38 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
|
||||||
return transport, nil
|
return transport, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL) error {
|
func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error {
|
||||||
transport, err := r.HeaderTransport(ctx, serverURL)
|
transport := http.DefaultTransport
|
||||||
|
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||||
|
if !r.noVersionCheck {
|
||||||
|
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||||
|
// Create a new client without any wrapped transport
|
||||||
|
// otherwise it creates an infinite loop!
|
||||||
|
basicClient := codersdk.New(serverURL)
|
||||||
|
return basicClient.BuildInfo(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if !r.noFeatureWarning {
|
||||||
|
transport = wrapTransportWithEntitlementsCheck(transport, inv.Stderr)
|
||||||
|
}
|
||||||
|
headerTransport, err := r.HeaderTransport(ctx, serverURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create header transport: %w", err)
|
return xerrors.Errorf("create header transport: %w", err)
|
||||||
}
|
}
|
||||||
|
// The header transport has to come last.
|
||||||
client.URL = serverURL
|
// codersdk checks for the header transport to get headers
|
||||||
|
// to clone on the DERP client.
|
||||||
|
headerTransport.Transport = transport
|
||||||
client.HTTPClient = &http.Client{
|
client.HTTPClient = &http.Client{
|
||||||
Transport: transport,
|
Transport: headerTransport,
|
||||||
}
|
}
|
||||||
|
client.URL = serverURL
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *url.URL) (*codersdk.Client, error) {
|
func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *url.URL, inv *serpent.Invocation) (*codersdk.Client, error) {
|
||||||
var client codersdk.Client
|
var client codersdk.Client
|
||||||
err := r.setClient(ctx, &client, serverURL)
|
err := r.configureClient(ctx, &client, serverURL, inv)
|
||||||
return &client, err
|
return &client, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -879,70 +784,6 @@ func formatExamples(examples ...example) string {
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkVersions checks to see if there's a version mismatch between the client
|
|
||||||
// and server and prints a message nudging the user to upgrade if a mismatch
|
|
||||||
// is detected. forceCheck is a test flag and should always be false in production.
|
|
||||||
//
|
|
||||||
//nolint:revive
|
|
||||||
func (r *RootCmd) checkVersions(i *serpent.Invocation, client *codersdk.Client, clientVersion string) error {
|
|
||||||
if r.noVersionCheck {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverInfo, err := client.BuildInfo(ctx)
|
|
||||||
// Avoid printing errors that are connection-related.
|
|
||||||
if isConnectionError(err) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("build info: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !buildinfo.VersionsMatch(clientVersion, serverInfo.Version) {
|
|
||||||
upgradeMessage := defaultUpgradeMessage(serverInfo.CanonicalVersion())
|
|
||||||
if serverInfo.UpgradeMessage != "" {
|
|
||||||
upgradeMessage = serverInfo.UpgradeMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
fmtWarningText := "version mismatch: client %s, server %s\n%s"
|
|
||||||
fmtWarn := pretty.Sprint(cliui.DefaultStyles.Warn, fmtWarningText)
|
|
||||||
warning := fmt.Sprintf(fmtWarn, clientVersion, serverInfo.Version, upgradeMessage)
|
|
||||||
|
|
||||||
_, _ = fmt.Fprint(i.Stderr, warning)
|
|
||||||
_, _ = fmt.Fprintln(i.Stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RootCmd) checkWarnings(i *serpent.Invocation, client *codersdk.Client) error {
|
|
||||||
if r.noFeatureWarning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
user, err := client.User(ctx, codersdk.Me)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("get user me: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
entitlements, err := client.Entitlements(ctx)
|
|
||||||
if err == nil {
|
|
||||||
// Don't show warning to regular users.
|
|
||||||
if len(user.Roles) > 0 {
|
|
||||||
for _, w := range entitlements.Warnings {
|
|
||||||
_, _ = fmt.Fprintln(i.Stderr, pretty.Sprint(cliui.DefaultStyles.Warn, w))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verbosef logs a message if verbose mode is enabled.
|
// Verbosef logs a message if verbose mode is enabled.
|
||||||
func (r *RootCmd) Verbosef(inv *serpent.Invocation, fmtStr string, args ...interface{}) {
|
func (r *RootCmd) Verbosef(inv *serpent.Invocation, fmtStr string, args ...interface{}) {
|
||||||
if r.verbose {
|
if r.verbose {
|
||||||
|
@ -1068,19 +909,6 @@ func ExitError(code int, err error) error {
|
||||||
return &exitError{code: code, err: err}
|
return &exitError{code: code, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IiConnectionErr is a convenience function for checking if the source of an
|
|
||||||
// error is due to a 'connection refused', 'no such host', etc.
|
|
||||||
func isConnectionError(err error) bool {
|
|
||||||
var (
|
|
||||||
// E.g. no such host
|
|
||||||
dnsErr *net.DNSError
|
|
||||||
// Eg. connection refused
|
|
||||||
opErr *net.OpError
|
|
||||||
)
|
|
||||||
|
|
||||||
return xerrors.As(err, &dnsErr) || xerrors.As(err, &opErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
type prettyErrorFormatter struct {
|
type prettyErrorFormatter struct {
|
||||||
w io.Writer
|
w io.Writer
|
||||||
// verbose turns on more detailed error logs, such as stack traces.
|
// verbose turns on more detailed error logs, such as stack traces.
|
||||||
|
@ -1305,3 +1133,105 @@ func defaultUpgradeMessage(version string) string {
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
|
||||||
|
// that checks for entitlement warnings and prints them to the user.
|
||||||
|
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
|
||||||
|
var once sync.Once
|
||||||
|
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
res, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
once.Do(func() {
|
||||||
|
for _, warning := range res.Header.Values(codersdk.EntitlementsWarningHeader) {
|
||||||
|
_, _ = fmt.Fprintln(w, pretty.Sprint(cliui.DefaultStyles.Warn, warning))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return res, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
|
||||||
|
// that checks for version mismatches between the client and server. If a mismatch
|
||||||
|
// is detected, a warning is printed to the user.
|
||||||
|
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||||
|
var once sync.Once
|
||||||
|
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
res, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
once.Do(func() {
|
||||||
|
serverVersion := res.Header.Get(codersdk.BuildVersionHeader)
|
||||||
|
if serverVersion == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
|
||||||
|
serverInfo, err := getBuildInfo(inv.Context())
|
||||||
|
if err == nil && serverInfo.UpgradeMessage != "" {
|
||||||
|
upgradeMessage = serverInfo.UpgradeMessage
|
||||||
|
}
|
||||||
|
fmtWarningText := "version mismatch: client %s, server %s\n%s"
|
||||||
|
fmtWarn := pretty.Sprint(cliui.DefaultStyles.Warn, fmtWarningText)
|
||||||
|
warning := fmt.Sprintf(fmtWarn, clientVersion, serverVersion, upgradeMessage)
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintln(inv.Stderr, warning)
|
||||||
|
})
|
||||||
|
return res, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapTransportWithTelemetryHeader adds telemetry headers to report command usage
|
||||||
|
// to an HTTP transport.
|
||||||
|
func wrapTransportWithTelemetryHeader(transport http.RoundTripper, inv *serpent.Invocation) http.RoundTripper {
|
||||||
|
var (
|
||||||
|
value string
|
||||||
|
once sync.Once
|
||||||
|
)
|
||||||
|
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
once.Do(func() {
|
||||||
|
// We only want to compute this header once when a request
|
||||||
|
// first goes out, hence the complexity with locking here.
|
||||||
|
var topts []telemetry.Option
|
||||||
|
for _, opt := range inv.Command.FullOptions() {
|
||||||
|
if opt.ValueSource == serpent.ValueSourceNone || opt.ValueSource == serpent.ValueSourceDefault {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
topts = append(topts, telemetry.Option{
|
||||||
|
Name: opt.Name,
|
||||||
|
ValueSource: string(opt.ValueSource),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ti := telemetry.Invocation{
|
||||||
|
Command: inv.Command.FullName(),
|
||||||
|
Options: topts,
|
||||||
|
InvokedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
byt, err := json.Marshal(ti)
|
||||||
|
if err != nil {
|
||||||
|
// Should be impossible
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
s := base64.StdEncoding.EncodeToString(byt)
|
||||||
|
// Don't send the header if it's too long!
|
||||||
|
if len(s) <= 4096 {
|
||||||
|
value = s
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if value != "" {
|
||||||
|
req.Header.Add(codersdk.CLITelemetryHeader, value)
|
||||||
|
}
|
||||||
|
return transport.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type roundTripper func(req *http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return r(req)
|
||||||
|
}
|
||||||
|
|
|
@ -2,10 +2,13 @@ package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -13,14 +16,30 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/goleak"
|
"go.uber.org/goleak"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/buildinfo"
|
|
||||||
"github.com/coder/coder/v2/cli/cliui"
|
"github.com/coder/coder/v2/cli/cliui"
|
||||||
"github.com/coder/coder/v2/coderd"
|
"github.com/coder/coder/v2/cli/telemetry"
|
||||||
"github.com/coder/coder/v2/coderd/httpapi"
|
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/pretty"
|
"github.com/coder/pretty"
|
||||||
|
"github.com/coder/serpent"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
// Don't run goleak on windows tests, they're super flaky right now.
|
||||||
|
// See: https://github.com/coder/coder/issues/8954
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
goleak.VerifyTestMain(m,
|
||||||
|
// The lumberjack library is used by by agent and seems to leave
|
||||||
|
// goroutines after Close(), fails TestGitSSH tests.
|
||||||
|
// https://github.com/natefinch/lumberjack/pull/100
|
||||||
|
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
|
||||||
|
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"),
|
||||||
|
// The pq library appears to leave around a goroutine after Close().
|
||||||
|
goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_formatExamples(t *testing.T) {
|
func Test_formatExamples(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -80,49 +99,37 @@ func Test_formatExamples(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
// Don't run goleak on windows tests, they're super flaky right now.
|
|
||||||
// See: https://github.com/coder/coder/issues/8954
|
|
||||||
os.Exit(m.Run())
|
|
||||||
}
|
|
||||||
goleak.VerifyTestMain(m,
|
|
||||||
// The lumberjack library is used by by agent and seems to leave
|
|
||||||
// goroutines after Close(), fails TestGitSSH tests.
|
|
||||||
// https://github.com/natefinch/lumberjack/pull/100
|
|
||||||
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
|
|
||||||
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"),
|
|
||||||
// The pq library appears to leave around a goroutine after Close().
|
|
||||||
goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_checkVersions(t *testing.T) {
|
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("NoOutput", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
r := &RootCmd{}
|
||||||
|
cmd, err := r.Command(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
inv := cmd.Invoke()
|
||||||
|
inv.Stderr = &buf
|
||||||
|
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
// Provider a version that will not match!
|
||||||
|
codersdk.BuildVersionHeader: []string{"v2.0.0"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(nil),
|
||||||
|
}, nil
|
||||||
|
}), inv, "v2.0.0", nil)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
res, err := rt.RoundTrip(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
|
require.Equal(t, "", buf.String())
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("CustomUpgradeMessage", func(t *testing.T) {
|
t.Run("CustomUpgradeMessage", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
expectedUpgradeMessage := "My custom upgrade message"
|
|
||||||
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
||||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
|
|
||||||
ExternalURL: buildinfo.ExternalURL(),
|
|
||||||
// Provide a version that will not match
|
|
||||||
Version: "v1.0.0",
|
|
||||||
AgentAPIVersion: coderd.AgentAPIVersionREST,
|
|
||||||
// does not matter what the url is
|
|
||||||
DashboardURL: "https://example.com",
|
|
||||||
WorkspaceProxy: false,
|
|
||||||
UpgradeMessage: expectedUpgradeMessage,
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
surl, err := url.Parse(srv.URL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
client := codersdk.New(surl)
|
|
||||||
|
|
||||||
r := &RootCmd{}
|
r := &RootCmd{}
|
||||||
|
|
||||||
cmd, err := r.Command(nil)
|
cmd, err := r.Command(nil)
|
||||||
|
@ -131,50 +138,85 @@ func Test_checkVersions(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
inv := cmd.Invoke()
|
inv := cmd.Invoke()
|
||||||
inv.Stderr = &buf
|
inv.Stderr = &buf
|
||||||
|
expectedUpgradeMessage := "My custom upgrade message"
|
||||||
err = r.checkVersions(inv, client, "v2.0.0")
|
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
// Provider a version that will not match!
|
||||||
|
codersdk.BuildVersionHeader: []string{"v1.0.0"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(nil),
|
||||||
|
}, nil
|
||||||
|
}), inv, "v2.0.0", func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||||
|
return codersdk.BuildInfoResponse{
|
||||||
|
UpgradeMessage: expectedUpgradeMessage,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
res, err := rt.RoundTrip(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
// Run this twice to ensure the upgrade message is only printed once.
|
||||||
|
res, err = rt.RoundTrip(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", expectedUpgradeMessage)
|
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", expectedUpgradeMessage)
|
||||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||||
require.Equal(t, expectedOutput, buf.String())
|
require.Equal(t, expectedOutput, buf.String())
|
||||||
})
|
})
|
||||||
|
}
|
||||||
t.Run("DefaultUpgradeMessage", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
||||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
|
rt := wrapTransportWithTelemetryHeader(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
ExternalURL: buildinfo.ExternalURL(),
|
return &http.Response{
|
||||||
// Provide a version that will not match
|
Body: io.NopCloser(nil),
|
||||||
Version: "v1.0.0",
|
}, nil
|
||||||
AgentAPIVersion: coderd.AgentAPIVersionREST,
|
}), &serpent.Invocation{
|
||||||
// does not matter what the url is
|
Command: &serpent.Command{
|
||||||
DashboardURL: "https://example.com",
|
Use: "test",
|
||||||
WorkspaceProxy: false,
|
Options: serpent.OptionSet{{
|
||||||
UpgradeMessage: "",
|
Name: "bananas",
|
||||||
})
|
Description: "hey",
|
||||||
}))
|
}},
|
||||||
defer srv.Close()
|
},
|
||||||
surl, err := url.Parse(srv.URL)
|
})
|
||||||
require.NoError(t, err)
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
res, err := rt.RoundTrip(req)
|
||||||
client := codersdk.New(surl)
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
r := &RootCmd{}
|
resp := req.Header.Get(codersdk.CLITelemetryHeader)
|
||||||
|
require.NotEmpty(t, resp)
|
||||||
cmd, err := r.Command(nil)
|
data, err := base64.StdEncoding.DecodeString(resp)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
var ti telemetry.Invocation
|
||||||
var buf bytes.Buffer
|
err = json.Unmarshal(data, &ti)
|
||||||
inv := cmd.Invoke()
|
require.NoError(t, err)
|
||||||
inv.Stderr = &buf
|
require.Equal(t, ti.Command, "test")
|
||||||
|
}
|
||||||
err = r.checkVersions(inv, client, "v2.0.0")
|
|
||||||
require.NoError(t, err)
|
func Test_wrapTransportWithEntitlementsCheck(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", defaultUpgradeMessage("v1.0.0"))
|
|
||||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
lines := []string{"First Warning", "Second Warning"}
|
||||||
require.Equal(t, expectedOutput, buf.String())
|
var buf bytes.Buffer
|
||||||
})
|
rt := wrapTransportWithEntitlementsCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
codersdk.EntitlementsWarningHeader: lines,
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(nil),
|
||||||
|
}, nil
|
||||||
|
}), &buf)
|
||||||
|
res, err := rt.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
|
expectedOutput := fmt.Sprintf("%s\n%s\n", pretty.Sprint(cliui.DefaultStyles.Warn, lines[0]),
|
||||||
|
pretty.Sprint(cliui.DefaultStyles.Warn, lines[1]))
|
||||||
|
require.Equal(t, expectedOutput, buf.String())
|
||||||
}
|
}
|
||||||
|
|
|
@ -253,7 +253,7 @@ func TestHandlersOK(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var root cli.RootCmd
|
var root cli.RootCmd
|
||||||
cmd, err := root.Command(root.Core())
|
cmd, err := root.Command(root.CoreSubcommands())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
clitest.HandlersOK(t, cmd)
|
clitest.HandlersOK(t, cmd)
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/cli/clitest"
|
"github.com/coder/coder/v2/cli/clitest"
|
||||||
|
@ -197,56 +196,6 @@ func TestTemplateCreate(t *testing.T) {
|
||||||
require.NoError(t, err, "Template must be recreated without error")
|
require.NoError(t, err, "Template must be recreated without error")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("WithVariablesFileWithoutRequiredValue", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
|
||||||
coderdtest.CreateFirstUser(t, client)
|
|
||||||
|
|
||||||
templateVariables := []*proto.TemplateVariable{
|
|
||||||
{
|
|
||||||
Name: "first_variable",
|
|
||||||
Description: "This is the first variable.",
|
|
||||||
Type: "string",
|
|
||||||
Required: true,
|
|
||||||
Sensitive: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "second_variable",
|
|
||||||
Description: "This is the first variable",
|
|
||||||
Type: "string",
|
|
||||||
DefaultValue: "abc",
|
|
||||||
Required: false,
|
|
||||||
Sensitive: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
source := clitest.CreateTemplateVersionSource(t,
|
|
||||||
createEchoResponsesWithTemplateVariables(templateVariables))
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
removeTmpDirUntilSuccessAfterTest(t, tempDir)
|
|
||||||
variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml")
|
|
||||||
_, _ = variablesFile.WriteString(`second_variable: foobar`)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name())
|
|
||||||
clitest.SetupConfig(t, client, root)
|
|
||||||
inv = inv.WithContext(ctx)
|
|
||||||
pty := ptytest.New(t).Attach(inv)
|
|
||||||
|
|
||||||
// We expect the cli to return an error, so we have to handle it
|
|
||||||
// ourselves.
|
|
||||||
go func() {
|
|
||||||
cancel()
|
|
||||||
err := inv.Run()
|
|
||||||
assert.Error(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
pty.ExpectMatch("context canceled")
|
|
||||||
<-ctx.Done()
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("WithVariablesFileWithTheRequiredValue", func(t *testing.T) {
|
t.Run("WithVariablesFileWithTheRequiredValue", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ func TestUserList(t *testing.T) {
|
||||||
|
|
||||||
var apiErr *codersdk.Error
|
var apiErr *codersdk.Error
|
||||||
require.ErrorAs(t, err, &apiErr)
|
require.ErrorAs(t, err, &apiErr)
|
||||||
require.Contains(t, err.Error(), "Try logging in using 'coder login <url>'.")
|
require.Contains(t, err.Error(), "Try logging in using 'coder login'.")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ func (r *RootCmd) vscodeSSH() *serpent.Command {
|
||||||
client.SetSessionToken(string(sessionToken))
|
client.SetSessionToken(string(sessionToken))
|
||||||
|
|
||||||
// This adds custom headers to the request!
|
// This adds custom headers to the request!
|
||||||
err = r.setClient(ctx, client, serverURL)
|
err = r.configureClient(ctx, client, serverURL, inv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("set client: %w", err)
|
return xerrors.Errorf("set client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,5 +8,5 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var rootCmd cli.RootCmd
|
var rootCmd cli.RootCmd
|
||||||
rootCmd.RunMain(rootCmd.AGPL())
|
rootCmd.RunWithSubcommands(rootCmd.AGPL())
|
||||||
}
|
}
|
||||||
|
|
|
@ -129,6 +129,12 @@ type Options struct {
|
||||||
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
|
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
|
||||||
// RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license.
|
// RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license.
|
||||||
RefreshEntitlements func(ctx context.Context) error
|
RefreshEntitlements func(ctx context.Context) error
|
||||||
|
// PostAuthAdditionalHeadersFunc is used to add additional headers to the response
|
||||||
|
// after a successful authentication.
|
||||||
|
// This is somewhat janky, but seemingly the only reasonable way to add a header
|
||||||
|
// for all authenticated users under a condition, only in Enterprise.
|
||||||
|
PostAuthAdditionalHeadersFunc func(auth httpmw.Authorization, header http.Header)
|
||||||
|
|
||||||
// TLSCertificates is used to mesh DERP servers securely.
|
// TLSCertificates is used to mesh DERP servers securely.
|
||||||
TLSCertificates []tls.Certificate
|
TLSCertificates []tls.Certificate
|
||||||
TailnetCoordinator tailnet.Coordinator
|
TailnetCoordinator tailnet.Coordinator
|
||||||
|
@ -557,30 +563,33 @@ func New(options *Options) *API {
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||||
DB: options.Database,
|
DB: options.Database,
|
||||||
OAuth2Configs: oauthConfigs,
|
OAuth2Configs: oauthConfigs,
|
||||||
RedirectToLogin: false,
|
RedirectToLogin: false,
|
||||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||||
Optional: false,
|
Optional: false,
|
||||||
SessionTokenFunc: nil, // Default behavior
|
SessionTokenFunc: nil, // Default behavior
|
||||||
|
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||||
})
|
})
|
||||||
// Same as above but it redirects to the login page.
|
// Same as above but it redirects to the login page.
|
||||||
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||||
DB: options.Database,
|
DB: options.Database,
|
||||||
OAuth2Configs: oauthConfigs,
|
OAuth2Configs: oauthConfigs,
|
||||||
RedirectToLogin: true,
|
RedirectToLogin: true,
|
||||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||||
Optional: false,
|
Optional: false,
|
||||||
SessionTokenFunc: nil, // Default behavior
|
SessionTokenFunc: nil, // Default behavior
|
||||||
|
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||||
})
|
})
|
||||||
// Same as the first but it's optional.
|
// Same as the first but it's optional.
|
||||||
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||||
DB: options.Database,
|
DB: options.Database,
|
||||||
OAuth2Configs: oauthConfigs,
|
OAuth2Configs: oauthConfigs,
|
||||||
RedirectToLogin: false,
|
RedirectToLogin: false,
|
||||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||||
Optional: true,
|
Optional: true,
|
||||||
SessionTokenFunc: nil, // Default behavior
|
SessionTokenFunc: nil, // Default behavior
|
||||||
|
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||||
})
|
})
|
||||||
|
|
||||||
// API rate limit middleware. The counter is local and not shared between
|
// API rate limit middleware. The counter is local and not shared between
|
||||||
|
|
|
@ -113,6 +113,13 @@ type ExtractAPIKeyConfig struct {
|
||||||
// SessionTokenFunc is a custom function that can be used to extract the API
|
// SessionTokenFunc is a custom function that can be used to extract the API
|
||||||
// key. If nil, the default behavior is used.
|
// key. If nil, the default behavior is used.
|
||||||
SessionTokenFunc func(r *http.Request) string
|
SessionTokenFunc func(r *http.Request) string
|
||||||
|
|
||||||
|
// PostAuthAdditionalHeadersFunc is a function that can be used to add
|
||||||
|
// headers to the response after the user has been authenticated.
|
||||||
|
//
|
||||||
|
// This is originally implemented to send entitlement warning headers after
|
||||||
|
// a user is authenticated to prevent additional CLI invocations.
|
||||||
|
PostAuthAdditionalHeadersFunc func(a Authorization, header http.Header)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtractAPIKeyMW calls ExtractAPIKey with the given config on each request,
|
// ExtractAPIKeyMW calls ExtractAPIKey with the given config on each request,
|
||||||
|
@ -454,6 +461,10 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||||
}.WithCachedASTValue(),
|
}.WithCachedASTValue(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.PostAuthAdditionalHeadersFunc != nil {
|
||||||
|
cfg.PostAuthAdditionalHeadersFunc(authz, rw.Header())
|
||||||
|
}
|
||||||
|
|
||||||
return key, &authz, true
|
return key, &authz, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -171,7 +171,7 @@ func TestPostTemplateByOrganization(t *testing.T) {
|
||||||
var apiErr *codersdk.Error
|
var apiErr *codersdk.Error
|
||||||
require.ErrorAs(t, err, &apiErr)
|
require.ErrorAs(t, err, &apiErr)
|
||||||
require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode())
|
require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode())
|
||||||
require.Contains(t, err.Error(), "Try logging in using 'coder login <url>'.")
|
require.Contains(t, err.Error(), "Try logging in using 'coder login'.")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AllowUserScheduling", func(t *testing.T) {
|
t.Run("AllowUserScheduling", func(t *testing.T) {
|
||||||
|
|
|
@ -81,6 +81,9 @@ const (
|
||||||
|
|
||||||
// BuildVersionHeader contains build information of Coder.
|
// BuildVersionHeader contains build information of Coder.
|
||||||
BuildVersionHeader = "X-Coder-Build-Version"
|
BuildVersionHeader = "X-Coder-Build-Version"
|
||||||
|
|
||||||
|
// EntitlementsWarnings contains active warnings for the user's entitlements.
|
||||||
|
EntitlementsWarningHeader = "X-Coder-Entitlements-Warning"
|
||||||
)
|
)
|
||||||
|
|
||||||
// loggableMimeTypes is a list of MIME types that are safe to log
|
// loggableMimeTypes is a list of MIME types that are safe to log
|
||||||
|
@ -358,7 +361,7 @@ func ReadBodyAsError(res *http.Response) error {
|
||||||
if res.StatusCode == http.StatusUnauthorized {
|
if res.StatusCode == http.StatusUnauthorized {
|
||||||
// 401 means the user is not logged in
|
// 401 means the user is not logged in
|
||||||
// 403 would mean that the user is not authorized
|
// 403 would mean that the user is not authorized
|
||||||
helpMessage = "Try logging in using 'coder login <url>'."
|
helpMessage = "Try logging in using 'coder login'."
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := io.ReadAll(res.Body)
|
resp, err := io.ReadAll(res.Body)
|
||||||
|
|
|
@ -82,7 +82,7 @@ func (r *RootCmd) provisionerDaemonStart() *serpent.Command {
|
||||||
// disable checks and warnings because this command starts a daemon; it is
|
// disable checks and warnings because this command starts a daemon; it is
|
||||||
// not meant for humans typing commands. Furthermore, the checks are
|
// not meant for humans typing commands. Furthermore, the checks are
|
||||||
// incompatible with PSK auth that this command uses
|
// incompatible with PSK auth that this command uses
|
||||||
r.InitClientMissingTokenOK(client),
|
r.InitClient(client),
|
||||||
),
|
),
|
||||||
Handler: func(inv *serpent.Invocation) error {
|
Handler: func(inv *serpent.Invocation) error {
|
||||||
ctx, cancel := context.WithCancel(inv.Context())
|
ctx, cancel := context.WithCancel(inv.Context())
|
||||||
|
|
|
@ -21,6 +21,6 @@ func (r *RootCmd) enterpriseOnly() []*serpent.Command {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RootCmd) EnterpriseSubcommands() []*serpent.Command {
|
func (r *RootCmd) EnterpriseSubcommands() []*serpent.Command {
|
||||||
all := append(r.Core(), r.enterpriseOnly()...)
|
all := append(r.CoreSubcommands(), r.enterpriseOnly()...)
|
||||||
return all
|
return all
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,5 +8,5 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var rootCmd entcli.RootCmd
|
var rootCmd entcli.RootCmd
|
||||||
rootCmd.RunMain(rootCmd.EnterpriseSubcommands())
|
rootCmd.RunWithSubcommands(rootCmd.EnterpriseSubcommands())
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,17 +94,18 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||||
return nil, xerrors.Errorf("init database encryption: %w", err)
|
return nil, xerrors.Errorf("init database encryption: %w", err)
|
||||||
}
|
}
|
||||||
options.Database = cryptDB
|
options.Database = cryptDB
|
||||||
|
|
||||||
api := &API{
|
api := &API{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancelFunc,
|
cancel: cancelFunc,
|
||||||
AGPL: coderd.New(options.Options),
|
|
||||||
Options: options,
|
Options: options,
|
||||||
provisionerDaemonAuth: &provisionerDaemonAuth{
|
provisionerDaemonAuth: &provisionerDaemonAuth{
|
||||||
psk: options.ProvisionerDaemonPSK,
|
psk: options.ProvisionerDaemonPSK,
|
||||||
authorizer: options.Authorizer,
|
authorizer: options.Authorizer,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
// This must happen before coderd initialization!
|
||||||
|
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
|
||||||
|
api.AGPL = coderd.New(options.Options)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = api.Close()
|
_ = api.Close()
|
||||||
|
@ -144,29 +145,32 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||||
OIDC: options.OIDCConfig,
|
OIDC: options.OIDCConfig,
|
||||||
}
|
}
|
||||||
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||||
DB: options.Database,
|
DB: options.Database,
|
||||||
OAuth2Configs: oauthConfigs,
|
OAuth2Configs: oauthConfigs,
|
||||||
RedirectToLogin: false,
|
RedirectToLogin: false,
|
||||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||||
Optional: false,
|
Optional: false,
|
||||||
SessionTokenFunc: nil, // Default behavior
|
SessionTokenFunc: nil, // Default behavior
|
||||||
|
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||||
})
|
})
|
||||||
// Same as above but it redirects to the login page.
|
// Same as above but it redirects to the login page.
|
||||||
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||||
DB: options.Database,
|
DB: options.Database,
|
||||||
OAuth2Configs: oauthConfigs,
|
OAuth2Configs: oauthConfigs,
|
||||||
RedirectToLogin: true,
|
RedirectToLogin: true,
|
||||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||||
Optional: false,
|
Optional: false,
|
||||||
SessionTokenFunc: nil, // Default behavior
|
SessionTokenFunc: nil, // Default behavior
|
||||||
|
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||||
})
|
})
|
||||||
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||||
DB: options.Database,
|
DB: options.Database,
|
||||||
OAuth2Configs: oauthConfigs,
|
OAuth2Configs: oauthConfigs,
|
||||||
RedirectToLogin: false,
|
RedirectToLogin: false,
|
||||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||||
Optional: true,
|
Optional: true,
|
||||||
SessionTokenFunc: nil, // Default behavior
|
SessionTokenFunc: nil, // Default behavior
|
||||||
|
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||||
})
|
})
|
||||||
|
|
||||||
deploymentID, err := options.Database.GetDeploymentID(ctx)
|
deploymentID, err := options.Database.GetDeploymentID(ctx)
|
||||||
|
@ -531,6 +535,38 @@ type API struct {
|
||||||
tailnetService *tailnet.ClientService
|
tailnetService *tailnet.ClientService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeEntitlementWarningsHeader writes the entitlement warnings to the response header
|
||||||
|
// for all authenticated users with roles. If there are no warnings, this header will not be written.
|
||||||
|
//
|
||||||
|
// This header is used by the CLI to display warnings to the user without having
|
||||||
|
// to make additional requests!
|
||||||
|
func (api *API) writeEntitlementWarningsHeader(a httpmw.Authorization, header http.Header) {
|
||||||
|
roles, err := a.Actor.Roles.Expand()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nonMemberRoles := 0
|
||||||
|
for _, role := range roles {
|
||||||
|
// The member role is implied, and not assignable.
|
||||||
|
// If there is no display name, then the role is also unassigned.
|
||||||
|
// This is not the ideal logic, but works for now.
|
||||||
|
if role.Name == rbac.RoleMember() || (role.DisplayName == "") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nonMemberRoles++
|
||||||
|
}
|
||||||
|
if nonMemberRoles == 0 {
|
||||||
|
// Don't show entitlement warnings if the user
|
||||||
|
// has no roles. This is a normal user!
|
||||||
|
return
|
||||||
|
}
|
||||||
|
api.entitlementsMu.RLock()
|
||||||
|
defer api.entitlementsMu.RUnlock()
|
||||||
|
for _, warning := range api.entitlements.Warnings {
|
||||||
|
header.Add(codersdk.EntitlementsWarningHeader, warning)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (api *API) Close() error {
|
func (api *API) Close() error {
|
||||||
// Replica manager should be closed first. This is because the replica
|
// Replica manager should be closed first. This is because the replica
|
||||||
// manager updates the replica's table in the database when it closes.
|
// manager updates the replica's table in the database when it closes.
|
||||||
|
|
|
@ -3,6 +3,7 @@ package coderd_test
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -197,6 +198,40 @@ func TestEntitlements(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEntitlements_HeaderWarnings(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
t.Run("ExistForAdmin", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
adminClient, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||||
|
AuditLogging: true,
|
||||||
|
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||||
|
AllFeatures: false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
//nolint:gocritic // This isn't actually bypassing any RBAC checks
|
||||||
|
res, err := adminClient.Request(context.Background(), http.MethodGet, "/api/v2/users/me", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||||
|
require.NotEmpty(t, res.Header.Values(codersdk.EntitlementsWarningHeader))
|
||||||
|
})
|
||||||
|
t.Run("NoneForNormalUser", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
adminClient, adminUser := coderdenttest.New(t, &coderdenttest.Options{
|
||||||
|
AuditLogging: true,
|
||||||
|
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||||
|
AllFeatures: false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
anotherClient, _ := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID)
|
||||||
|
res, err := anotherClient.Request(context.Background(), http.MethodGet, "/api/v2/users/me", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||||
|
require.Empty(t, res.Header.Values(codersdk.EntitlementsWarningHeader))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuditLogging(t *testing.T) {
|
func TestAuditLogging(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("Enabled", func(t *testing.T) {
|
t.Run("Enabled", func(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue