chore: remove middleware to request version and entitlement warnings (#12750)

This cleans up `root.go` a bit, adds tests for middleware HTTP transport
functions, and removes two HTTP requests we always always performed previously
when executing *any* client command.

It should improve CLI performance (especially for users with higher latency).
This commit is contained in:
Kyle Carberry 2024-03-25 20:01:42 +01:00 committed by GitHub
parent ba3879ac47
commit 03ab37b343
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 412 additions and 435 deletions

View File

@ -18,7 +18,6 @@ import (
"github.com/coder/pretty"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/userpassword"
"github.com/coder/coder/v2/codersdk"
@ -180,21 +179,11 @@ func (r *RootCmd) login() *serpent.Command {
serverURL.Scheme = "https"
}
client, err := r.createUnauthenticatedClient(ctx, serverURL)
client, err := r.createUnauthenticatedClient(ctx, serverURL, inv)
if err != nil {
return err
}
// Try to check the version of the server prior to logging in.
// It may be useful to warn the user if they are trying to login
// on a very old client.
err = r.checkVersions(inv, client, buildinfo.Version())
if err != nil {
// Checking versions isn't a fatal error so we print a warning
// and proceed.
_, _ = fmt.Fprintln(inv.Stderr, pretty.Sprint(cliui.DefaultStyles.Warn, err.Error()))
}
hasFirstUser, err := client.HasFirstUser(ctx)
if err != nil {
return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err)

View File

@ -97,33 +97,6 @@ func TestLogout(t *testing.T) {
<-logoutChan
})
t.Run("NoSessionFile", func(t *testing.T) {
t.Parallel()
pty := ptytest.New(t)
config := login(t, pty)
// Ensure session files exist.
require.FileExists(t, string(config.URL()))
require.FileExists(t, string(config.Session()))
err := os.Remove(string(config.Session()))
require.NoError(t, err)
logoutChan := make(chan struct{})
logout, _ := clitest.New(t, "logout", "--global-config", string(config))
logout.Stdin = pty.Input()
logout.Stdout = pty.Output()
go func() {
defer close(logoutChan)
err = logout.Run()
assert.ErrorContains(t, err, "You are not logged in. Try logging in using 'coder login'.")
}()
<-logoutChan
})
t.Run("CannotDeleteFiles", func(t *testing.T) {
t.Parallel()

View File

@ -9,8 +9,6 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/url"
"os"
@ -20,6 +18,7 @@ import (
"runtime"
"runtime/trace"
"strings"
"sync"
"syscall"
"text/tabwriter"
"time"
@ -27,6 +26,7 @@ import (
"github.com/mattn/go-isatty"
"github.com/mitchellh/go-wordwrap"
"golang.org/x/exp/slices"
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
"github.com/coder/pretty"
@ -67,8 +67,7 @@ const (
varOrganizationSelect = "organization"
varDisableDirect = "disable-direct-connections"
notLoggedInMessage = "You are not logged in. Try logging in using 'coder login <url>'."
notLoggedInURLSavedMessage = "You are not logged in. Try logging in using 'coder login'."
notLoggedInMessage = "You are not logged in. Try logging in using 'coder login <url>'."
envNoVersionCheck = "CODER_NO_VERSION_WARNING"
envNoFeatureWarning = "CODER_NO_FEATURE_WARNING"
@ -80,12 +79,7 @@ const (
envURL = "CODER_URL"
)
var (
errUnauthenticated = xerrors.New(notLoggedInMessage)
errUnauthenticatedURLSaved = xerrors.New(notLoggedInURLSavedMessage)
)
func (r *RootCmd) Core() []*serpent.Command {
func (r *RootCmd) CoreSubcommands() []*serpent.Command {
// Please re-sort this list alphabetically if you change it!
return []*serpent.Command{
r.dotfiles(),
@ -134,12 +128,13 @@ func (r *RootCmd) Core() []*serpent.Command {
}
func (r *RootCmd) AGPL() []*serpent.Command {
all := append(r.Core(), r.Server( /* Do not import coderd here. */ nil))
all := append(r.CoreSubcommands(), r.Server( /* Do not import coderd here. */ nil))
return all
}
// Main is the entrypoint for the Coder CLI.
func (r *RootCmd) RunMain(subcommands []*serpent.Command) {
// RunWithSubcommands runs the root command with the given subcommands.
// It is abstracted to enable the Enterprise code to add commands.
func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) {
// This configuration is not available as a standard option because we
// want to trace the entire program, including Options parsing.
goTraceFilePath, ok := os.LookupEnv("CODER_GO_TRACE")
@ -156,8 +151,6 @@ func (r *RootCmd) RunMain(subcommands []*serpent.Command) {
defer trace.Stop()
}
rand.Seed(time.Now().UnixMicro())
cmd, err := r.Command(subcommands)
if err != nil {
panic(err)
@ -503,79 +496,20 @@ type RootCmd struct {
noFeatureWarning bool
}
func addTelemetryHeader(client *codersdk.Client, inv *serpent.Invocation) {
transport, ok := client.HTTPClient.Transport.(*codersdk.HeaderTransport)
if !ok {
transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{},
}
client.HTTPClient.Transport = transport
}
var topts []telemetry.Option
for _, opt := range inv.Command.FullOptions() {
if opt.ValueSource == serpent.ValueSourceNone || opt.ValueSource == serpent.ValueSourceDefault {
continue
}
topts = append(topts, telemetry.Option{
Name: opt.Name,
ValueSource: string(opt.ValueSource),
})
}
ti := telemetry.Invocation{
Command: inv.Command.FullName(),
Options: topts,
InvokedAt: time.Now(),
}
byt, err := json.Marshal(ti)
if err != nil {
// Should be impossible
panic(err)
}
// Per https://stackoverflow.com/questions/686217/maximum-on-http-header-values,
// we don't want to send headers that are too long.
s := base64.StdEncoding.EncodeToString(byt)
if len(s) > 4096 {
return
}
transport.Header.Add(codersdk.CLITelemetryHeader, s)
}
// InitClient sets client to a new client.
// It reads from global configuration files if flags are not set.
// InitClient authenticates the client with files from disk
// and injects header middlewares for telemetry, authentication,
// and version checks.
func (r *RootCmd) InitClient(client *codersdk.Client) serpent.MiddlewareFunc {
return serpent.Chain(
r.initClientInternal(client, false),
// By default, we should print warnings in addition to initializing the client
r.PrintWarnings(client),
)
}
func (r *RootCmd) InitClientMissingTokenOK(client *codersdk.Client) serpent.MiddlewareFunc {
return r.initClientInternal(client, true)
}
// nolint: revive
func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing bool) serpent.MiddlewareFunc {
if client == nil {
panic("client is nil")
}
if r == nil {
panic("root is nil")
}
return func(next serpent.HandlerFunc) serpent.HandlerFunc {
return func(inv *serpent.Invocation) error {
conf := r.createConfig()
var err error
// Read the client URL stored on disk.
if r.clientURL == nil || r.clientURL.String() == "" {
rawURL, err := conf.URL().Read()
// If the configuration files are absent, the user is logged out
if os.IsNotExist(err) {
return errUnauthenticated
return xerrors.New(notLoggedInMessage)
}
if err != nil {
return err
@ -586,25 +520,20 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing
return err
}
}
// Read the token stored on disk.
if r.token == "" {
r.token, err = conf.Session().Read()
// If the configuration files are absent, the user is logged out
if os.IsNotExist(err) {
if !allowTokenMissing {
return errUnauthenticatedURLSaved
}
} else if err != nil {
// Even if there isn't a token, we don't care.
// Some API routes can be unauthenticated.
if err != nil && !os.IsNotExist(err) {
return err
}
}
err = r.setClient(inv.Context(), client, r.clientURL)
err = r.configureClient(inv.Context(), client, r.clientURL, inv)
if err != nil {
return err
}
addTelemetryHeader(client, inv)
client.SetSessionToken(r.token)
if r.debugHTTP {
@ -617,48 +546,8 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing
}
}
func (r *RootCmd) PrintWarnings(client *codersdk.Client) serpent.MiddlewareFunc {
if client == nil {
panic("client is nil")
}
if r == nil {
panic("root is nil")
}
return func(next serpent.HandlerFunc) serpent.HandlerFunc {
return func(inv *serpent.Invocation) error {
// We send these requests in parallel to minimize latency.
var (
versionErr = make(chan error)
warningErr = make(chan error)
)
go func() {
versionErr <- r.checkVersions(inv, client, buildinfo.Version())
close(versionErr)
}()
go func() {
warningErr <- r.checkWarnings(inv, client)
close(warningErr)
}()
if err := <-versionErr; err != nil {
// Just log the error here. We never want to fail a command
// due to a pre-run.
pretty.Fprintf(inv.Stderr, cliui.DefaultStyles.Warn, "check versions error: %s", err)
_, _ = fmt.Fprintln(inv.Stderr)
}
if err := <-warningErr; err != nil {
// Same as above
pretty.Fprintf(inv.Stderr, cliui.DefaultStyles.Warn, "check entitlement warnings error: %s", err)
_, _ = fmt.Fprintln(inv.Stderr)
}
return next(inv)
}
}
}
// HeaderTransport creates a new transport that executes `--header-command`
// if it is set to add headers for all outbound requests.
func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) {
transport := &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
@ -700,22 +589,38 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
return transport, nil
}
func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL) error {
transport, err := r.HeaderTransport(ctx, serverURL)
func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error {
transport := http.DefaultTransport
transport = wrapTransportWithTelemetryHeader(transport, inv)
if !r.noVersionCheck {
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
// Create a new client without any wrapped transport
// otherwise it creates an infinite loop!
basicClient := codersdk.New(serverURL)
return basicClient.BuildInfo(ctx)
})
}
if !r.noFeatureWarning {
transport = wrapTransportWithEntitlementsCheck(transport, inv.Stderr)
}
headerTransport, err := r.HeaderTransport(ctx, serverURL)
if err != nil {
return xerrors.Errorf("create header transport: %w", err)
}
client.URL = serverURL
// The header transport has to come last.
// codersdk checks for the header transport to get headers
// to clone on the DERP client.
headerTransport.Transport = transport
client.HTTPClient = &http.Client{
Transport: transport,
Transport: headerTransport,
}
client.URL = serverURL
return nil
}
func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *url.URL) (*codersdk.Client, error) {
func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *url.URL, inv *serpent.Invocation) (*codersdk.Client, error) {
var client codersdk.Client
err := r.setClient(ctx, &client, serverURL)
err := r.configureClient(ctx, &client, serverURL, inv)
return &client, err
}
@ -879,70 +784,6 @@ func formatExamples(examples ...example) string {
return sb.String()
}
// checkVersions checks to see if there's a version mismatch between the client
// and server and prints a message nudging the user to upgrade if a mismatch
// is detected. forceCheck is a test flag and should always be false in production.
//
//nolint:revive
func (r *RootCmd) checkVersions(i *serpent.Invocation, client *codersdk.Client, clientVersion string) error {
if r.noVersionCheck {
return nil
}
ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second)
defer cancel()
serverInfo, err := client.BuildInfo(ctx)
// Avoid printing errors that are connection-related.
if isConnectionError(err) {
return nil
}
if err != nil {
return xerrors.Errorf("build info: %w", err)
}
if !buildinfo.VersionsMatch(clientVersion, serverInfo.Version) {
upgradeMessage := defaultUpgradeMessage(serverInfo.CanonicalVersion())
if serverInfo.UpgradeMessage != "" {
upgradeMessage = serverInfo.UpgradeMessage
}
fmtWarningText := "version mismatch: client %s, server %s\n%s"
fmtWarn := pretty.Sprint(cliui.DefaultStyles.Warn, fmtWarningText)
warning := fmt.Sprintf(fmtWarn, clientVersion, serverInfo.Version, upgradeMessage)
_, _ = fmt.Fprint(i.Stderr, warning)
_, _ = fmt.Fprintln(i.Stderr)
}
return nil
}
func (r *RootCmd) checkWarnings(i *serpent.Invocation, client *codersdk.Client) error {
if r.noFeatureWarning {
return nil
}
ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second)
defer cancel()
user, err := client.User(ctx, codersdk.Me)
if err != nil {
return xerrors.Errorf("get user me: %w", err)
}
entitlements, err := client.Entitlements(ctx)
if err == nil {
// Don't show warning to regular users.
if len(user.Roles) > 0 {
for _, w := range entitlements.Warnings {
_, _ = fmt.Fprintln(i.Stderr, pretty.Sprint(cliui.DefaultStyles.Warn, w))
}
}
}
return nil
}
// Verbosef logs a message if verbose mode is enabled.
func (r *RootCmd) Verbosef(inv *serpent.Invocation, fmtStr string, args ...interface{}) {
if r.verbose {
@ -1068,19 +909,6 @@ func ExitError(code int, err error) error {
return &exitError{code: code, err: err}
}
// IiConnectionErr is a convenience function for checking if the source of an
// error is due to a 'connection refused', 'no such host', etc.
func isConnectionError(err error) bool {
var (
// E.g. no such host
dnsErr *net.DNSError
// Eg. connection refused
opErr *net.OpError
)
return xerrors.As(err, &dnsErr) || xerrors.As(err, &opErr)
}
type prettyErrorFormatter struct {
w io.Writer
// verbose turns on more detailed error logs, such as stack traces.
@ -1305,3 +1133,105 @@ func defaultUpgradeMessage(version string) string {
}
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
}
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
// that checks for entitlement warnings and prints them to the user.
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
var once sync.Once
return roundTripper(func(req *http.Request) (*http.Response, error) {
res, err := rt.RoundTrip(req)
if err != nil {
return res, err
}
once.Do(func() {
for _, warning := range res.Header.Values(codersdk.EntitlementsWarningHeader) {
_, _ = fmt.Fprintln(w, pretty.Sprint(cliui.DefaultStyles.Warn, warning))
}
})
return res, err
})
}
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
// that checks for version mismatches between the client and server. If a mismatch
// is detected, a warning is printed to the user.
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
var once sync.Once
return roundTripper(func(req *http.Request) (*http.Response, error) {
res, err := rt.RoundTrip(req)
if err != nil {
return res, err
}
once.Do(func() {
serverVersion := res.Header.Get(codersdk.BuildVersionHeader)
if serverVersion == "" {
return
}
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
return
}
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
serverInfo, err := getBuildInfo(inv.Context())
if err == nil && serverInfo.UpgradeMessage != "" {
upgradeMessage = serverInfo.UpgradeMessage
}
fmtWarningText := "version mismatch: client %s, server %s\n%s"
fmtWarn := pretty.Sprint(cliui.DefaultStyles.Warn, fmtWarningText)
warning := fmt.Sprintf(fmtWarn, clientVersion, serverVersion, upgradeMessage)
_, _ = fmt.Fprintln(inv.Stderr, warning)
})
return res, err
})
}
// wrapTransportWithTelemetryHeader adds telemetry headers to report command usage
// to an HTTP transport.
func wrapTransportWithTelemetryHeader(transport http.RoundTripper, inv *serpent.Invocation) http.RoundTripper {
var (
value string
once sync.Once
)
return roundTripper(func(req *http.Request) (*http.Response, error) {
once.Do(func() {
// We only want to compute this header once when a request
// first goes out, hence the complexity with locking here.
var topts []telemetry.Option
for _, opt := range inv.Command.FullOptions() {
if opt.ValueSource == serpent.ValueSourceNone || opt.ValueSource == serpent.ValueSourceDefault {
continue
}
topts = append(topts, telemetry.Option{
Name: opt.Name,
ValueSource: string(opt.ValueSource),
})
}
ti := telemetry.Invocation{
Command: inv.Command.FullName(),
Options: topts,
InvokedAt: time.Now(),
}
byt, err := json.Marshal(ti)
if err != nil {
// Should be impossible
panic(err)
}
s := base64.StdEncoding.EncodeToString(byt)
// Don't send the header if it's too long!
if len(s) <= 4096 {
value = s
}
})
if value != "" {
req.Header.Add(codersdk.CLITelemetryHeader, value)
}
return transport.RoundTrip(req)
})
}
type roundTripper func(req *http.Request) (*http.Response, error)
func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r(req)
}

View File

@ -2,10 +2,13 @@ package cli
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"runtime"
"testing"
@ -13,14 +16,30 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/cli/telemetry"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func TestMain(m *testing.M) {
if runtime.GOOS == "windows" {
// Don't run goleak on windows tests, they're super flaky right now.
// See: https://github.com/coder/coder/issues/8954
os.Exit(m.Run())
}
goleak.VerifyTestMain(m,
// The lumberjack library is used by by agent and seems to leave
// goroutines after Close(), fails TestGitSSH tests.
// https://github.com/natefinch/lumberjack/pull/100
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"),
// The pq library appears to leave around a goroutine after Close().
goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"),
)
}
func Test_formatExamples(t *testing.T) {
t.Parallel()
@ -80,49 +99,37 @@ func Test_formatExamples(t *testing.T) {
}
}
func TestMain(m *testing.M) {
if runtime.GOOS == "windows" {
// Don't run goleak on windows tests, they're super flaky right now.
// See: https://github.com/coder/coder/issues/8954
os.Exit(m.Run())
}
goleak.VerifyTestMain(m,
// The lumberjack library is used by by agent and seems to leave
// goroutines after Close(), fails TestGitSSH tests.
// https://github.com/natefinch/lumberjack/pull/100
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"),
// The pq library appears to leave around a goroutine after Close().
goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"),
)
}
func Test_checkVersions(t *testing.T) {
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
t.Parallel()
t.Run("NoOutput", func(t *testing.T) {
t.Parallel()
r := &RootCmd{}
cmd, err := r.Command(nil)
require.NoError(t, err)
var buf bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &buf
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
// Provider a version that will not match!
codersdk.BuildVersionHeader: []string{"v2.0.0"},
},
Body: io.NopCloser(nil),
}, nil
}), inv, "v2.0.0", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, "", buf.String())
})
t.Run("CustomUpgradeMessage", func(t *testing.T) {
t.Parallel()
expectedUpgradeMessage := "My custom upgrade message"
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
ExternalURL: buildinfo.ExternalURL(),
// Provide a version that will not match
Version: "v1.0.0",
AgentAPIVersion: coderd.AgentAPIVersionREST,
// does not matter what the url is
DashboardURL: "https://example.com",
WorkspaceProxy: false,
UpgradeMessage: expectedUpgradeMessage,
})
}))
defer srv.Close()
surl, err := url.Parse(srv.URL)
require.NoError(t, err)
client := codersdk.New(surl)
r := &RootCmd{}
cmd, err := r.Command(nil)
@ -131,50 +138,85 @@ func Test_checkVersions(t *testing.T) {
var buf bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &buf
err = r.checkVersions(inv, client, "v2.0.0")
expectedUpgradeMessage := "My custom upgrade message"
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
// Provider a version that will not match!
codersdk.BuildVersionHeader: []string{"v1.0.0"},
},
Body: io.NopCloser(nil),
}, nil
}), inv, "v2.0.0", func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
return codersdk.BuildInfoResponse{
UpgradeMessage: expectedUpgradeMessage,
}, nil
})
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
// Run this twice to ensure the upgrade message is only printed once.
res, err = rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", expectedUpgradeMessage)
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
require.Equal(t, expectedOutput, buf.String())
})
t.Run("DefaultUpgradeMessage", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
ExternalURL: buildinfo.ExternalURL(),
// Provide a version that will not match
Version: "v1.0.0",
AgentAPIVersion: coderd.AgentAPIVersionREST,
// does not matter what the url is
DashboardURL: "https://example.com",
WorkspaceProxy: false,
UpgradeMessage: "",
})
}))
defer srv.Close()
surl, err := url.Parse(srv.URL)
require.NoError(t, err)
client := codersdk.New(surl)
r := &RootCmd{}
cmd, err := r.Command(nil)
require.NoError(t, err)
var buf bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &buf
err = r.checkVersions(inv, client, "v2.0.0")
require.NoError(t, err)
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", defaultUpgradeMessage("v1.0.0"))
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
require.Equal(t, expectedOutput, buf.String())
})
}
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
t.Parallel()
rt := wrapTransportWithTelemetryHeader(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(nil),
}, nil
}), &serpent.Invocation{
Command: &serpent.Command{
Use: "test",
Options: serpent.OptionSet{{
Name: "bananas",
Description: "hey",
}},
},
})
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
resp := req.Header.Get(codersdk.CLITelemetryHeader)
require.NotEmpty(t, resp)
data, err := base64.StdEncoding.DecodeString(resp)
require.NoError(t, err)
var ti telemetry.Invocation
err = json.Unmarshal(data, &ti)
require.NoError(t, err)
require.Equal(t, ti.Command, "test")
}
func Test_wrapTransportWithEntitlementsCheck(t *testing.T) {
t.Parallel()
lines := []string{"First Warning", "Second Warning"}
var buf bytes.Buffer
rt := wrapTransportWithEntitlementsCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
codersdk.EntitlementsWarningHeader: lines,
},
Body: io.NopCloser(nil),
}, nil
}), &buf)
res, err := rt.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil))
require.NoError(t, err)
defer res.Body.Close()
expectedOutput := fmt.Sprintf("%s\n%s\n", pretty.Sprint(cliui.DefaultStyles.Warn, lines[0]),
pretty.Sprint(cliui.DefaultStyles.Warn, lines[1]))
require.Equal(t, expectedOutput, buf.String())
}

View File

@ -253,7 +253,7 @@ func TestHandlersOK(t *testing.T) {
t.Parallel()
var root cli.RootCmd
cmd, err := root.Command(root.Core())
cmd, err := root.Command(root.CoreSubcommands())
require.NoError(t, err)
clitest.HandlersOK(t, cmd)

View File

@ -7,7 +7,6 @@ import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
@ -197,56 +196,6 @@ func TestTemplateCreate(t *testing.T) {
require.NoError(t, err, "Template must be recreated without error")
})
t.Run("WithVariablesFileWithoutRequiredValue", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
coderdtest.CreateFirstUser(t, client)
templateVariables := []*proto.TemplateVariable{
{
Name: "first_variable",
Description: "This is the first variable.",
Type: "string",
Required: true,
Sensitive: true,
},
{
Name: "second_variable",
Description: "This is the first variable",
Type: "string",
DefaultValue: "abc",
Required: false,
Sensitive: true,
},
}
source := clitest.CreateTemplateVersionSource(t,
createEchoResponsesWithTemplateVariables(templateVariables))
tempDir := t.TempDir()
removeTmpDirUntilSuccessAfterTest(t, tempDir)
variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml")
_, _ = variablesFile.WriteString(`second_variable: foobar`)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name())
clitest.SetupConfig(t, client, root)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
// We expect the cli to return an error, so we have to handle it
// ourselves.
go func() {
cancel()
err := inv.Run()
assert.Error(t, err)
}()
pty.ExpectMatch("context canceled")
<-ctx.Done()
})
t.Run("WithVariablesFileWithTheRequiredValue", func(t *testing.T) {
t.Parallel()

View File

@ -77,7 +77,7 @@ func TestUserList(t *testing.T) {
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Contains(t, err.Error(), "Try logging in using 'coder login <url>'.")
require.Contains(t, err.Error(), "Try logging in using 'coder login'.")
})
}

View File

@ -90,7 +90,7 @@ func (r *RootCmd) vscodeSSH() *serpent.Command {
client.SetSessionToken(string(sessionToken))
// This adds custom headers to the request!
err = r.setClient(ctx, client, serverURL)
err = r.configureClient(ctx, client, serverURL, inv)
if err != nil {
return xerrors.Errorf("set client: %w", err)
}

View File

@ -8,5 +8,5 @@ import (
func main() {
var rootCmd cli.RootCmd
rootCmd.RunMain(rootCmd.AGPL())
rootCmd.RunWithSubcommands(rootCmd.AGPL())
}

View File

@ -129,6 +129,12 @@ type Options struct {
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
// RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license.
RefreshEntitlements func(ctx context.Context) error
// PostAuthAdditionalHeadersFunc is used to add additional headers to the response
// after a successful authentication.
// This is somewhat janky, but seemingly the only reasonable way to add a header
// for all authenticated users under a condition, only in Enterprise.
PostAuthAdditionalHeadersFunc func(auth httpmw.Authorization, header http.Header)
// TLSCertificates is used to mesh DERP servers securely.
TLSCertificates []tls.Certificate
TailnetCoordinator tailnet.Coordinator
@ -557,30 +563,33 @@ func New(options *Options) *API {
}
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: false,
SessionTokenFunc: nil, // Default behavior
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: false,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
})
// Same as above but it redirects to the login page.
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: true,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: false,
SessionTokenFunc: nil, // Default behavior
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: true,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: false,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
})
// Same as the first but it's optional.
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: true,
SessionTokenFunc: nil, // Default behavior
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: true,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
})
// API rate limit middleware. The counter is local and not shared between

View File

@ -113,6 +113,13 @@ type ExtractAPIKeyConfig struct {
// SessionTokenFunc is a custom function that can be used to extract the API
// key. If nil, the default behavior is used.
SessionTokenFunc func(r *http.Request) string
// PostAuthAdditionalHeadersFunc is a function that can be used to add
// headers to the response after the user has been authenticated.
//
// This is originally implemented to send entitlement warning headers after
// a user is authenticated to prevent additional CLI invocations.
PostAuthAdditionalHeadersFunc func(a Authorization, header http.Header)
}
// ExtractAPIKeyMW calls ExtractAPIKey with the given config on each request,
@ -454,6 +461,10 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
}.WithCachedASTValue(),
}
if cfg.PostAuthAdditionalHeadersFunc != nil {
cfg.PostAuthAdditionalHeadersFunc(authz, rw.Header())
}
return key, &authz, true
}

View File

@ -171,7 +171,7 @@ func TestPostTemplateByOrganization(t *testing.T) {
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode())
require.Contains(t, err.Error(), "Try logging in using 'coder login <url>'.")
require.Contains(t, err.Error(), "Try logging in using 'coder login'.")
})
t.Run("AllowUserScheduling", func(t *testing.T) {

View File

@ -81,6 +81,9 @@ const (
// BuildVersionHeader contains build information of Coder.
BuildVersionHeader = "X-Coder-Build-Version"
// EntitlementsWarnings contains active warnings for the user's entitlements.
EntitlementsWarningHeader = "X-Coder-Entitlements-Warning"
)
// loggableMimeTypes is a list of MIME types that are safe to log
@ -358,7 +361,7 @@ func ReadBodyAsError(res *http.Response) error {
if res.StatusCode == http.StatusUnauthorized {
// 401 means the user is not logged in
// 403 would mean that the user is not authorized
helpMessage = "Try logging in using 'coder login <url>'."
helpMessage = "Try logging in using 'coder login'."
}
resp, err := io.ReadAll(res.Body)

View File

@ -82,7 +82,7 @@ func (r *RootCmd) provisionerDaemonStart() *serpent.Command {
// disable checks and warnings because this command starts a daemon; it is
// not meant for humans typing commands. Furthermore, the checks are
// incompatible with PSK auth that this command uses
r.InitClientMissingTokenOK(client),
r.InitClient(client),
),
Handler: func(inv *serpent.Invocation) error {
ctx, cancel := context.WithCancel(inv.Context())

View File

@ -21,6 +21,6 @@ func (r *RootCmd) enterpriseOnly() []*serpent.Command {
}
func (r *RootCmd) EnterpriseSubcommands() []*serpent.Command {
all := append(r.Core(), r.enterpriseOnly()...)
all := append(r.CoreSubcommands(), r.enterpriseOnly()...)
return all
}

View File

@ -8,5 +8,5 @@ import (
func main() {
var rootCmd entcli.RootCmd
rootCmd.RunMain(rootCmd.EnterpriseSubcommands())
rootCmd.RunWithSubcommands(rootCmd.EnterpriseSubcommands())
}

View File

@ -94,17 +94,18 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
return nil, xerrors.Errorf("init database encryption: %w", err)
}
options.Database = cryptDB
api := &API{
ctx: ctx,
cancel: cancelFunc,
AGPL: coderd.New(options.Options),
Options: options,
provisionerDaemonAuth: &provisionerDaemonAuth{
psk: options.ProvisionerDaemonPSK,
authorizer: options.Authorizer,
},
}
// This must happen before coderd initialization!
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
api.AGPL = coderd.New(options.Options)
defer func() {
if err != nil {
_ = api.Close()
@ -144,29 +145,32 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
OIDC: options.OIDCConfig,
}
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: false,
SessionTokenFunc: nil, // Default behavior
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: false,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
})
// Same as above but it redirects to the login page.
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: true,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: false,
SessionTokenFunc: nil, // Default behavior
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: true,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: false,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
})
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: true,
SessionTokenFunc: nil, // Default behavior
DB: options.Database,
OAuth2Configs: oauthConfigs,
RedirectToLogin: false,
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
Optional: true,
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
})
deploymentID, err := options.Database.GetDeploymentID(ctx)
@ -531,6 +535,38 @@ type API struct {
tailnetService *tailnet.ClientService
}
// writeEntitlementWarningsHeader writes the entitlement warnings to the response header
// for all authenticated users with roles. If there are no warnings, this header will not be written.
//
// This header is used by the CLI to display warnings to the user without having
// to make additional requests!
func (api *API) writeEntitlementWarningsHeader(a httpmw.Authorization, header http.Header) {
roles, err := a.Actor.Roles.Expand()
if err != nil {
return
}
nonMemberRoles := 0
for _, role := range roles {
// The member role is implied, and not assignable.
// If there is no display name, then the role is also unassigned.
// This is not the ideal logic, but works for now.
if role.Name == rbac.RoleMember() || (role.DisplayName == "") {
continue
}
nonMemberRoles++
}
if nonMemberRoles == 0 {
// Don't show entitlement warnings if the user
// has no roles. This is a normal user!
return
}
api.entitlementsMu.RLock()
defer api.entitlementsMu.RUnlock()
for _, warning := range api.entitlements.Warnings {
header.Add(codersdk.EntitlementsWarningHeader, warning)
}
}
func (api *API) Close() error {
// Replica manager should be closed first. This is because the replica
// manager updates the replica's table in the database when it closes.

View File

@ -3,6 +3,7 @@ package coderd_test
import (
"bytes"
"context"
"net/http"
"reflect"
"strings"
"testing"
@ -197,6 +198,40 @@ func TestEntitlements(t *testing.T) {
})
}
func TestEntitlements_HeaderWarnings(t *testing.T) {
t.Parallel()
t.Run("ExistForAdmin", func(t *testing.T) {
t.Parallel()
adminClient, _ := coderdenttest.New(t, &coderdenttest.Options{
AuditLogging: true,
LicenseOptions: &coderdenttest.LicenseOptions{
AllFeatures: false,
},
})
//nolint:gocritic // This isn't actually bypassing any RBAC checks
res, err := adminClient.Request(context.Background(), http.MethodGet, "/api/v2/users/me", nil)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
require.NotEmpty(t, res.Header.Values(codersdk.EntitlementsWarningHeader))
})
t.Run("NoneForNormalUser", func(t *testing.T) {
t.Parallel()
adminClient, adminUser := coderdenttest.New(t, &coderdenttest.Options{
AuditLogging: true,
LicenseOptions: &coderdenttest.LicenseOptions{
AllFeatures: false,
},
})
anotherClient, _ := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID)
res, err := anotherClient.Request(context.Background(), http.MethodGet, "/api/v2/users/me", nil)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
require.Empty(t, res.Header.Values(codersdk.EntitlementsWarningHeader))
})
}
func TestAuditLogging(t *testing.T) {
t.Parallel()
t.Run("Enabled", func(t *testing.T) {