mirror of https://github.com/coder/coder.git
chore: remove middleware to request version and entitlement warnings (#12750)
This cleans up `root.go` a bit, adds tests for middleware HTTP transport functions, and removes two HTTP requests we always always performed previously when executing *any* client command. It should improve CLI performance (especially for users with higher latency).
This commit is contained in:
parent
ba3879ac47
commit
03ab37b343
13
cli/login.go
13
cli/login.go
|
@ -18,7 +18,6 @@ import (
|
|||
|
||||
"github.com/coder/pretty"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/userpassword"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
@ -180,21 +179,11 @@ func (r *RootCmd) login() *serpent.Command {
|
|||
serverURL.Scheme = "https"
|
||||
}
|
||||
|
||||
client, err := r.createUnauthenticatedClient(ctx, serverURL)
|
||||
client, err := r.createUnauthenticatedClient(ctx, serverURL, inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to check the version of the server prior to logging in.
|
||||
// It may be useful to warn the user if they are trying to login
|
||||
// on a very old client.
|
||||
err = r.checkVersions(inv, client, buildinfo.Version())
|
||||
if err != nil {
|
||||
// Checking versions isn't a fatal error so we print a warning
|
||||
// and proceed.
|
||||
_, _ = fmt.Fprintln(inv.Stderr, pretty.Sprint(cliui.DefaultStyles.Warn, err.Error()))
|
||||
}
|
||||
|
||||
hasFirstUser, err := client.HasFirstUser(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err)
|
||||
|
|
|
@ -97,33 +97,6 @@ func TestLogout(t *testing.T) {
|
|||
|
||||
<-logoutChan
|
||||
})
|
||||
t.Run("NoSessionFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pty := ptytest.New(t)
|
||||
config := login(t, pty)
|
||||
|
||||
// Ensure session files exist.
|
||||
require.FileExists(t, string(config.URL()))
|
||||
require.FileExists(t, string(config.Session()))
|
||||
|
||||
err := os.Remove(string(config.Session()))
|
||||
require.NoError(t, err)
|
||||
|
||||
logoutChan := make(chan struct{})
|
||||
logout, _ := clitest.New(t, "logout", "--global-config", string(config))
|
||||
|
||||
logout.Stdin = pty.Input()
|
||||
logout.Stdout = pty.Output()
|
||||
|
||||
go func() {
|
||||
defer close(logoutChan)
|
||||
err = logout.Run()
|
||||
assert.ErrorContains(t, err, "You are not logged in. Try logging in using 'coder login'.")
|
||||
}()
|
||||
|
||||
<-logoutChan
|
||||
})
|
||||
t.Run("CannotDeleteFiles", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
362
cli/root.go
362
cli/root.go
|
@ -9,8 +9,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
@ -20,6 +18,7 @@ import (
|
|||
"runtime"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
@ -27,6 +26,7 @@ import (
|
|||
"github.com/mattn/go-isatty"
|
||||
"github.com/mitchellh/go-wordwrap"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
|
@ -67,8 +67,7 @@ const (
|
|||
varOrganizationSelect = "organization"
|
||||
varDisableDirect = "disable-direct-connections"
|
||||
|
||||
notLoggedInMessage = "You are not logged in. Try logging in using 'coder login <url>'."
|
||||
notLoggedInURLSavedMessage = "You are not logged in. Try logging in using 'coder login'."
|
||||
notLoggedInMessage = "You are not logged in. Try logging in using 'coder login <url>'."
|
||||
|
||||
envNoVersionCheck = "CODER_NO_VERSION_WARNING"
|
||||
envNoFeatureWarning = "CODER_NO_FEATURE_WARNING"
|
||||
|
@ -80,12 +79,7 @@ const (
|
|||
envURL = "CODER_URL"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnauthenticated = xerrors.New(notLoggedInMessage)
|
||||
errUnauthenticatedURLSaved = xerrors.New(notLoggedInURLSavedMessage)
|
||||
)
|
||||
|
||||
func (r *RootCmd) Core() []*serpent.Command {
|
||||
func (r *RootCmd) CoreSubcommands() []*serpent.Command {
|
||||
// Please re-sort this list alphabetically if you change it!
|
||||
return []*serpent.Command{
|
||||
r.dotfiles(),
|
||||
|
@ -134,12 +128,13 @@ func (r *RootCmd) Core() []*serpent.Command {
|
|||
}
|
||||
|
||||
func (r *RootCmd) AGPL() []*serpent.Command {
|
||||
all := append(r.Core(), r.Server( /* Do not import coderd here. */ nil))
|
||||
all := append(r.CoreSubcommands(), r.Server( /* Do not import coderd here. */ nil))
|
||||
return all
|
||||
}
|
||||
|
||||
// Main is the entrypoint for the Coder CLI.
|
||||
func (r *RootCmd) RunMain(subcommands []*serpent.Command) {
|
||||
// RunWithSubcommands runs the root command with the given subcommands.
|
||||
// It is abstracted to enable the Enterprise code to add commands.
|
||||
func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) {
|
||||
// This configuration is not available as a standard option because we
|
||||
// want to trace the entire program, including Options parsing.
|
||||
goTraceFilePath, ok := os.LookupEnv("CODER_GO_TRACE")
|
||||
|
@ -156,8 +151,6 @@ func (r *RootCmd) RunMain(subcommands []*serpent.Command) {
|
|||
defer trace.Stop()
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixMicro())
|
||||
|
||||
cmd, err := r.Command(subcommands)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -503,79 +496,20 @@ type RootCmd struct {
|
|||
noFeatureWarning bool
|
||||
}
|
||||
|
||||
func addTelemetryHeader(client *codersdk.Client, inv *serpent.Invocation) {
|
||||
transport, ok := client.HTTPClient.Transport.(*codersdk.HeaderTransport)
|
||||
if !ok {
|
||||
transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: http.Header{},
|
||||
}
|
||||
client.HTTPClient.Transport = transport
|
||||
}
|
||||
|
||||
var topts []telemetry.Option
|
||||
for _, opt := range inv.Command.FullOptions() {
|
||||
if opt.ValueSource == serpent.ValueSourceNone || opt.ValueSource == serpent.ValueSourceDefault {
|
||||
continue
|
||||
}
|
||||
topts = append(topts, telemetry.Option{
|
||||
Name: opt.Name,
|
||||
ValueSource: string(opt.ValueSource),
|
||||
})
|
||||
}
|
||||
ti := telemetry.Invocation{
|
||||
Command: inv.Command.FullName(),
|
||||
Options: topts,
|
||||
InvokedAt: time.Now(),
|
||||
}
|
||||
|
||||
byt, err := json.Marshal(ti)
|
||||
if err != nil {
|
||||
// Should be impossible
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Per https://stackoverflow.com/questions/686217/maximum-on-http-header-values,
|
||||
// we don't want to send headers that are too long.
|
||||
s := base64.StdEncoding.EncodeToString(byt)
|
||||
if len(s) > 4096 {
|
||||
return
|
||||
}
|
||||
|
||||
transport.Header.Add(codersdk.CLITelemetryHeader, s)
|
||||
}
|
||||
|
||||
// InitClient sets client to a new client.
|
||||
// It reads from global configuration files if flags are not set.
|
||||
// InitClient authenticates the client with files from disk
|
||||
// and injects header middlewares for telemetry, authentication,
|
||||
// and version checks.
|
||||
func (r *RootCmd) InitClient(client *codersdk.Client) serpent.MiddlewareFunc {
|
||||
return serpent.Chain(
|
||||
r.initClientInternal(client, false),
|
||||
// By default, we should print warnings in addition to initializing the client
|
||||
r.PrintWarnings(client),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *RootCmd) InitClientMissingTokenOK(client *codersdk.Client) serpent.MiddlewareFunc {
|
||||
return r.initClientInternal(client, true)
|
||||
}
|
||||
|
||||
// nolint: revive
|
||||
func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing bool) serpent.MiddlewareFunc {
|
||||
if client == nil {
|
||||
panic("client is nil")
|
||||
}
|
||||
if r == nil {
|
||||
panic("root is nil")
|
||||
}
|
||||
return func(next serpent.HandlerFunc) serpent.HandlerFunc {
|
||||
return func(inv *serpent.Invocation) error {
|
||||
conf := r.createConfig()
|
||||
var err error
|
||||
// Read the client URL stored on disk.
|
||||
if r.clientURL == nil || r.clientURL.String() == "" {
|
||||
rawURL, err := conf.URL().Read()
|
||||
// If the configuration files are absent, the user is logged out
|
||||
if os.IsNotExist(err) {
|
||||
return errUnauthenticated
|
||||
return xerrors.New(notLoggedInMessage)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -586,25 +520,20 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing
|
|||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Read the token stored on disk.
|
||||
if r.token == "" {
|
||||
r.token, err = conf.Session().Read()
|
||||
// If the configuration files are absent, the user is logged out
|
||||
if os.IsNotExist(err) {
|
||||
if !allowTokenMissing {
|
||||
return errUnauthenticatedURLSaved
|
||||
}
|
||||
} else if err != nil {
|
||||
// Even if there isn't a token, we don't care.
|
||||
// Some API routes can be unauthenticated.
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = r.setClient(inv.Context(), client, r.clientURL)
|
||||
|
||||
err = r.configureClient(inv.Context(), client, r.clientURL, inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addTelemetryHeader(client, inv)
|
||||
|
||||
client.SetSessionToken(r.token)
|
||||
|
||||
if r.debugHTTP {
|
||||
|
@ -617,48 +546,8 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing
|
|||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) PrintWarnings(client *codersdk.Client) serpent.MiddlewareFunc {
|
||||
if client == nil {
|
||||
panic("client is nil")
|
||||
}
|
||||
if r == nil {
|
||||
panic("root is nil")
|
||||
}
|
||||
return func(next serpent.HandlerFunc) serpent.HandlerFunc {
|
||||
return func(inv *serpent.Invocation) error {
|
||||
// We send these requests in parallel to minimize latency.
|
||||
var (
|
||||
versionErr = make(chan error)
|
||||
warningErr = make(chan error)
|
||||
)
|
||||
go func() {
|
||||
versionErr <- r.checkVersions(inv, client, buildinfo.Version())
|
||||
close(versionErr)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
warningErr <- r.checkWarnings(inv, client)
|
||||
close(warningErr)
|
||||
}()
|
||||
|
||||
if err := <-versionErr; err != nil {
|
||||
// Just log the error here. We never want to fail a command
|
||||
// due to a pre-run.
|
||||
pretty.Fprintf(inv.Stderr, cliui.DefaultStyles.Warn, "check versions error: %s", err)
|
||||
_, _ = fmt.Fprintln(inv.Stderr)
|
||||
}
|
||||
|
||||
if err := <-warningErr; err != nil {
|
||||
// Same as above
|
||||
pretty.Fprintf(inv.Stderr, cliui.DefaultStyles.Warn, "check entitlement warnings error: %s", err)
|
||||
_, _ = fmt.Fprintln(inv.Stderr)
|
||||
}
|
||||
|
||||
return next(inv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HeaderTransport creates a new transport that executes `--header-command`
|
||||
// if it is set to add headers for all outbound requests.
|
||||
func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) {
|
||||
transport := &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
|
@ -700,22 +589,38 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
|
|||
return transport, nil
|
||||
}
|
||||
|
||||
func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL) error {
|
||||
transport, err := r.HeaderTransport(ctx, serverURL)
|
||||
func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error {
|
||||
transport := http.DefaultTransport
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
// Create a new client without any wrapped transport
|
||||
// otherwise it creates an infinite loop!
|
||||
basicClient := codersdk.New(serverURL)
|
||||
return basicClient.BuildInfo(ctx)
|
||||
})
|
||||
}
|
||||
if !r.noFeatureWarning {
|
||||
transport = wrapTransportWithEntitlementsCheck(transport, inv.Stderr)
|
||||
}
|
||||
headerTransport, err := r.HeaderTransport(ctx, serverURL)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create header transport: %w", err)
|
||||
}
|
||||
|
||||
client.URL = serverURL
|
||||
// The header transport has to come last.
|
||||
// codersdk checks for the header transport to get headers
|
||||
// to clone on the DERP client.
|
||||
headerTransport.Transport = transport
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: transport,
|
||||
Transport: headerTransport,
|
||||
}
|
||||
client.URL = serverURL
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *url.URL) (*codersdk.Client, error) {
|
||||
func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *url.URL, inv *serpent.Invocation) (*codersdk.Client, error) {
|
||||
var client codersdk.Client
|
||||
err := r.setClient(ctx, &client, serverURL)
|
||||
err := r.configureClient(ctx, &client, serverURL, inv)
|
||||
return &client, err
|
||||
}
|
||||
|
||||
|
@ -879,70 +784,6 @@ func formatExamples(examples ...example) string {
|
|||
return sb.String()
|
||||
}
|
||||
|
||||
// checkVersions checks to see if there's a version mismatch between the client
|
||||
// and server and prints a message nudging the user to upgrade if a mismatch
|
||||
// is detected. forceCheck is a test flag and should always be false in production.
|
||||
//
|
||||
//nolint:revive
|
||||
func (r *RootCmd) checkVersions(i *serpent.Invocation, client *codersdk.Client, clientVersion string) error {
|
||||
if r.noVersionCheck {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
serverInfo, err := client.BuildInfo(ctx)
|
||||
// Avoid printing errors that are connection-related.
|
||||
if isConnectionError(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return xerrors.Errorf("build info: %w", err)
|
||||
}
|
||||
|
||||
if !buildinfo.VersionsMatch(clientVersion, serverInfo.Version) {
|
||||
upgradeMessage := defaultUpgradeMessage(serverInfo.CanonicalVersion())
|
||||
if serverInfo.UpgradeMessage != "" {
|
||||
upgradeMessage = serverInfo.UpgradeMessage
|
||||
}
|
||||
|
||||
fmtWarningText := "version mismatch: client %s, server %s\n%s"
|
||||
fmtWarn := pretty.Sprint(cliui.DefaultStyles.Warn, fmtWarningText)
|
||||
warning := fmt.Sprintf(fmtWarn, clientVersion, serverInfo.Version, upgradeMessage)
|
||||
|
||||
_, _ = fmt.Fprint(i.Stderr, warning)
|
||||
_, _ = fmt.Fprintln(i.Stderr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RootCmd) checkWarnings(i *serpent.Invocation, client *codersdk.Client) error {
|
||||
if r.noFeatureWarning {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
user, err := client.User(ctx, codersdk.Me)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user me: %w", err)
|
||||
}
|
||||
|
||||
entitlements, err := client.Entitlements(ctx)
|
||||
if err == nil {
|
||||
// Don't show warning to regular users.
|
||||
if len(user.Roles) > 0 {
|
||||
for _, w := range entitlements.Warnings {
|
||||
_, _ = fmt.Fprintln(i.Stderr, pretty.Sprint(cliui.DefaultStyles.Warn, w))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verbosef logs a message if verbose mode is enabled.
|
||||
func (r *RootCmd) Verbosef(inv *serpent.Invocation, fmtStr string, args ...interface{}) {
|
||||
if r.verbose {
|
||||
|
@ -1068,19 +909,6 @@ func ExitError(code int, err error) error {
|
|||
return &exitError{code: code, err: err}
|
||||
}
|
||||
|
||||
// IiConnectionErr is a convenience function for checking if the source of an
|
||||
// error is due to a 'connection refused', 'no such host', etc.
|
||||
func isConnectionError(err error) bool {
|
||||
var (
|
||||
// E.g. no such host
|
||||
dnsErr *net.DNSError
|
||||
// Eg. connection refused
|
||||
opErr *net.OpError
|
||||
)
|
||||
|
||||
return xerrors.As(err, &dnsErr) || xerrors.As(err, &opErr)
|
||||
}
|
||||
|
||||
type prettyErrorFormatter struct {
|
||||
w io.Writer
|
||||
// verbose turns on more detailed error logs, such as stack traces.
|
||||
|
@ -1305,3 +1133,105 @@ func defaultUpgradeMessage(version string) string {
|
|||
}
|
||||
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
||||
}
|
||||
|
||||
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
|
||||
// that checks for entitlement warnings and prints them to the user.
|
||||
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
|
||||
var once sync.Once
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
once.Do(func() {
|
||||
for _, warning := range res.Header.Values(codersdk.EntitlementsWarningHeader) {
|
||||
_, _ = fmt.Fprintln(w, pretty.Sprint(cliui.DefaultStyles.Warn, warning))
|
||||
}
|
||||
})
|
||||
return res, err
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
|
||||
// that checks for version mismatches between the client and server. If a mismatch
|
||||
// is detected, a warning is printed to the user.
|
||||
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
var once sync.Once
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
once.Do(func() {
|
||||
serverVersion := res.Header.Get(codersdk.BuildVersionHeader)
|
||||
if serverVersion == "" {
|
||||
return
|
||||
}
|
||||
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
|
||||
return
|
||||
}
|
||||
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
|
||||
serverInfo, err := getBuildInfo(inv.Context())
|
||||
if err == nil && serverInfo.UpgradeMessage != "" {
|
||||
upgradeMessage = serverInfo.UpgradeMessage
|
||||
}
|
||||
fmtWarningText := "version mismatch: client %s, server %s\n%s"
|
||||
fmtWarn := pretty.Sprint(cliui.DefaultStyles.Warn, fmtWarningText)
|
||||
warning := fmt.Sprintf(fmtWarn, clientVersion, serverVersion, upgradeMessage)
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, warning)
|
||||
})
|
||||
return res, err
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithTelemetryHeader adds telemetry headers to report command usage
|
||||
// to an HTTP transport.
|
||||
func wrapTransportWithTelemetryHeader(transport http.RoundTripper, inv *serpent.Invocation) http.RoundTripper {
|
||||
var (
|
||||
value string
|
||||
once sync.Once
|
||||
)
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
once.Do(func() {
|
||||
// We only want to compute this header once when a request
|
||||
// first goes out, hence the complexity with locking here.
|
||||
var topts []telemetry.Option
|
||||
for _, opt := range inv.Command.FullOptions() {
|
||||
if opt.ValueSource == serpent.ValueSourceNone || opt.ValueSource == serpent.ValueSourceDefault {
|
||||
continue
|
||||
}
|
||||
topts = append(topts, telemetry.Option{
|
||||
Name: opt.Name,
|
||||
ValueSource: string(opt.ValueSource),
|
||||
})
|
||||
}
|
||||
ti := telemetry.Invocation{
|
||||
Command: inv.Command.FullName(),
|
||||
Options: topts,
|
||||
InvokedAt: time.Now(),
|
||||
}
|
||||
|
||||
byt, err := json.Marshal(ti)
|
||||
if err != nil {
|
||||
// Should be impossible
|
||||
panic(err)
|
||||
}
|
||||
s := base64.StdEncoding.EncodeToString(byt)
|
||||
// Don't send the header if it's too long!
|
||||
if len(s) <= 4096 {
|
||||
value = s
|
||||
}
|
||||
})
|
||||
if value != "" {
|
||||
req.Header.Add(codersdk.CLITelemetryHeader, value)
|
||||
}
|
||||
return transport.RoundTrip(req)
|
||||
})
|
||||
}
|
||||
|
||||
type roundTripper func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return r(req)
|
||||
}
|
||||
|
|
|
@ -2,10 +2,13 @@ package cli
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
@ -13,14 +16,30 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/cli/telemetry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Don't run goleak on windows tests, they're super flaky right now.
|
||||
// See: https://github.com/coder/coder/issues/8954
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
goleak.VerifyTestMain(m,
|
||||
// The lumberjack library is used by by agent and seems to leave
|
||||
// goroutines after Close(), fails TestGitSSH tests.
|
||||
// https://github.com/natefinch/lumberjack/pull/100
|
||||
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
|
||||
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"),
|
||||
// The pq library appears to leave around a goroutine after Close().
|
||||
goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"),
|
||||
)
|
||||
}
|
||||
|
||||
func Test_formatExamples(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -80,49 +99,37 @@ func Test_formatExamples(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Don't run goleak on windows tests, they're super flaky right now.
|
||||
// See: https://github.com/coder/coder/issues/8954
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
goleak.VerifyTestMain(m,
|
||||
// The lumberjack library is used by by agent and seems to leave
|
||||
// goroutines after Close(), fails TestGitSSH tests.
|
||||
// https://github.com/natefinch/lumberjack/pull/100
|
||||
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
|
||||
goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"),
|
||||
// The pq library appears to leave around a goroutine after Close().
|
||||
goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"),
|
||||
)
|
||||
}
|
||||
|
||||
func Test_checkVersions(t *testing.T) {
|
||||
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &RootCmd{}
|
||||
cmd, err := r.Command(nil)
|
||||
require.NoError(t, err)
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
// Provider a version that will not match!
|
||||
codersdk.BuildVersionHeader: []string{"v2.0.0"},
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), inv, "v2.0.0", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, "", buf.String())
|
||||
})
|
||||
|
||||
t.Run("CustomUpgradeMessage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectedUpgradeMessage := "My custom upgrade message"
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
|
||||
ExternalURL: buildinfo.ExternalURL(),
|
||||
// Provide a version that will not match
|
||||
Version: "v1.0.0",
|
||||
AgentAPIVersion: coderd.AgentAPIVersionREST,
|
||||
// does not matter what the url is
|
||||
DashboardURL: "https://example.com",
|
||||
WorkspaceProxy: false,
|
||||
UpgradeMessage: expectedUpgradeMessage,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
surl, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := codersdk.New(surl)
|
||||
|
||||
r := &RootCmd{}
|
||||
|
||||
cmd, err := r.Command(nil)
|
||||
|
@ -131,50 +138,85 @@ func Test_checkVersions(t *testing.T) {
|
|||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
|
||||
err = r.checkVersions(inv, client, "v2.0.0")
|
||||
expectedUpgradeMessage := "My custom upgrade message"
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
// Provider a version that will not match!
|
||||
codersdk.BuildVersionHeader: []string{"v1.0.0"},
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), inv, "v2.0.0", func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
return codersdk.BuildInfoResponse{
|
||||
UpgradeMessage: expectedUpgradeMessage,
|
||||
}, nil
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
// Run this twice to ensure the upgrade message is only printed once.
|
||||
res, err = rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", expectedUpgradeMessage)
|
||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
})
|
||||
|
||||
t.Run("DefaultUpgradeMessage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
|
||||
ExternalURL: buildinfo.ExternalURL(),
|
||||
// Provide a version that will not match
|
||||
Version: "v1.0.0",
|
||||
AgentAPIVersion: coderd.AgentAPIVersionREST,
|
||||
// does not matter what the url is
|
||||
DashboardURL: "https://example.com",
|
||||
WorkspaceProxy: false,
|
||||
UpgradeMessage: "",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
surl, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := codersdk.New(surl)
|
||||
|
||||
r := &RootCmd{}
|
||||
|
||||
cmd, err := r.Command(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
|
||||
err = r.checkVersions(inv, client, "v2.0.0")
|
||||
require.NoError(t, err)
|
||||
|
||||
fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", defaultUpgradeMessage("v1.0.0"))
|
||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rt := wrapTransportWithTelemetryHeader(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), &serpent.Invocation{
|
||||
Command: &serpent.Command{
|
||||
Use: "test",
|
||||
Options: serpent.OptionSet{{
|
||||
Name: "bananas",
|
||||
Description: "hey",
|
||||
}},
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
resp := req.Header.Get(codersdk.CLITelemetryHeader)
|
||||
require.NotEmpty(t, resp)
|
||||
data, err := base64.StdEncoding.DecodeString(resp)
|
||||
require.NoError(t, err)
|
||||
var ti telemetry.Invocation
|
||||
err = json.Unmarshal(data, &ti)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ti.Command, "test")
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithEntitlementsCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lines := []string{"First Warning", "Second Warning"}
|
||||
var buf bytes.Buffer
|
||||
rt := wrapTransportWithEntitlementsCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
codersdk.EntitlementsWarningHeader: lines,
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), &buf)
|
||||
res, err := rt.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil))
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
expectedOutput := fmt.Sprintf("%s\n%s\n", pretty.Sprint(cliui.DefaultStyles.Warn, lines[0]),
|
||||
pretty.Sprint(cliui.DefaultStyles.Warn, lines[1]))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
}
|
||||
|
|
|
@ -253,7 +253,7 @@ func TestHandlersOK(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
var root cli.RootCmd
|
||||
cmd, err := root.Command(root.Core())
|
||||
cmd, err := root.Command(root.CoreSubcommands())
|
||||
require.NoError(t, err)
|
||||
|
||||
clitest.HandlersOK(t, cmd)
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
|
@ -197,56 +196,6 @@ func TestTemplateCreate(t *testing.T) {
|
|||
require.NoError(t, err, "Template must be recreated without error")
|
||||
})
|
||||
|
||||
t.Run("WithVariablesFileWithoutRequiredValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
templateVariables := []*proto.TemplateVariable{
|
||||
{
|
||||
Name: "first_variable",
|
||||
Description: "This is the first variable.",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
Sensitive: true,
|
||||
},
|
||||
{
|
||||
Name: "second_variable",
|
||||
Description: "This is the first variable",
|
||||
Type: "string",
|
||||
DefaultValue: "abc",
|
||||
Required: false,
|
||||
Sensitive: true,
|
||||
},
|
||||
}
|
||||
source := clitest.CreateTemplateVersionSource(t,
|
||||
createEchoResponsesWithTemplateVariables(templateVariables))
|
||||
tempDir := t.TempDir()
|
||||
removeTmpDirUntilSuccessAfterTest(t, tempDir)
|
||||
variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml")
|
||||
_, _ = variablesFile.WriteString(`second_variable: foobar`)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name())
|
||||
clitest.SetupConfig(t, client, root)
|
||||
inv = inv.WithContext(ctx)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
|
||||
// We expect the cli to return an error, so we have to handle it
|
||||
// ourselves.
|
||||
go func() {
|
||||
cancel()
|
||||
err := inv.Run()
|
||||
assert.Error(t, err)
|
||||
}()
|
||||
|
||||
pty.ExpectMatch("context canceled")
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
t.Run("WithVariablesFileWithTheRequiredValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ func TestUserList(t *testing.T) {
|
|||
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Contains(t, err.Error(), "Try logging in using 'coder login <url>'.")
|
||||
require.Contains(t, err.Error(), "Try logging in using 'coder login'.")
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ func (r *RootCmd) vscodeSSH() *serpent.Command {
|
|||
client.SetSessionToken(string(sessionToken))
|
||||
|
||||
// This adds custom headers to the request!
|
||||
err = r.setClient(ctx, client, serverURL)
|
||||
err = r.configureClient(ctx, client, serverURL, inv)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set client: %w", err)
|
||||
}
|
||||
|
|
|
@ -8,5 +8,5 @@ import (
|
|||
|
||||
func main() {
|
||||
var rootCmd cli.RootCmd
|
||||
rootCmd.RunMain(rootCmd.AGPL())
|
||||
rootCmd.RunWithSubcommands(rootCmd.AGPL())
|
||||
}
|
||||
|
|
|
@ -129,6 +129,12 @@ type Options struct {
|
|||
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
|
||||
// RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license.
|
||||
RefreshEntitlements func(ctx context.Context) error
|
||||
// PostAuthAdditionalHeadersFunc is used to add additional headers to the response
|
||||
// after a successful authentication.
|
||||
// This is somewhat janky, but seemingly the only reasonable way to add a header
|
||||
// for all authenticated users under a condition, only in Enterprise.
|
||||
PostAuthAdditionalHeadersFunc func(auth httpmw.Authorization, header http.Header)
|
||||
|
||||
// TLSCertificates is used to mesh DERP servers securely.
|
||||
TLSCertificates []tls.Certificate
|
||||
TailnetCoordinator tailnet.Coordinator
|
||||
|
@ -557,30 +563,33 @@ func New(options *Options) *API {
|
|||
}
|
||||
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: false,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: false,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||
})
|
||||
// Same as above but it redirects to the login page.
|
||||
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: true,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: false,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: true,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: false,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||
})
|
||||
// Same as the first but it's optional.
|
||||
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: true,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: true,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||
})
|
||||
|
||||
// API rate limit middleware. The counter is local and not shared between
|
||||
|
|
|
@ -113,6 +113,13 @@ type ExtractAPIKeyConfig struct {
|
|||
// SessionTokenFunc is a custom function that can be used to extract the API
|
||||
// key. If nil, the default behavior is used.
|
||||
SessionTokenFunc func(r *http.Request) string
|
||||
|
||||
// PostAuthAdditionalHeadersFunc is a function that can be used to add
|
||||
// headers to the response after the user has been authenticated.
|
||||
//
|
||||
// This is originally implemented to send entitlement warning headers after
|
||||
// a user is authenticated to prevent additional CLI invocations.
|
||||
PostAuthAdditionalHeadersFunc func(a Authorization, header http.Header)
|
||||
}
|
||||
|
||||
// ExtractAPIKeyMW calls ExtractAPIKey with the given config on each request,
|
||||
|
@ -454,6 +461,10 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
|||
}.WithCachedASTValue(),
|
||||
}
|
||||
|
||||
if cfg.PostAuthAdditionalHeadersFunc != nil {
|
||||
cfg.PostAuthAdditionalHeadersFunc(authz, rw.Header())
|
||||
}
|
||||
|
||||
return key, &authz, true
|
||||
}
|
||||
|
||||
|
|
|
@ -171,7 +171,7 @@ func TestPostTemplateByOrganization(t *testing.T) {
|
|||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode())
|
||||
require.Contains(t, err.Error(), "Try logging in using 'coder login <url>'.")
|
||||
require.Contains(t, err.Error(), "Try logging in using 'coder login'.")
|
||||
})
|
||||
|
||||
t.Run("AllowUserScheduling", func(t *testing.T) {
|
||||
|
|
|
@ -81,6 +81,9 @@ const (
|
|||
|
||||
// BuildVersionHeader contains build information of Coder.
|
||||
BuildVersionHeader = "X-Coder-Build-Version"
|
||||
|
||||
// EntitlementsWarnings contains active warnings for the user's entitlements.
|
||||
EntitlementsWarningHeader = "X-Coder-Entitlements-Warning"
|
||||
)
|
||||
|
||||
// loggableMimeTypes is a list of MIME types that are safe to log
|
||||
|
@ -358,7 +361,7 @@ func ReadBodyAsError(res *http.Response) error {
|
|||
if res.StatusCode == http.StatusUnauthorized {
|
||||
// 401 means the user is not logged in
|
||||
// 403 would mean that the user is not authorized
|
||||
helpMessage = "Try logging in using 'coder login <url>'."
|
||||
helpMessage = "Try logging in using 'coder login'."
|
||||
}
|
||||
|
||||
resp, err := io.ReadAll(res.Body)
|
||||
|
|
|
@ -82,7 +82,7 @@ func (r *RootCmd) provisionerDaemonStart() *serpent.Command {
|
|||
// disable checks and warnings because this command starts a daemon; it is
|
||||
// not meant for humans typing commands. Furthermore, the checks are
|
||||
// incompatible with PSK auth that this command uses
|
||||
r.InitClientMissingTokenOK(client),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx, cancel := context.WithCancel(inv.Context())
|
||||
|
|
|
@ -21,6 +21,6 @@ func (r *RootCmd) enterpriseOnly() []*serpent.Command {
|
|||
}
|
||||
|
||||
func (r *RootCmd) EnterpriseSubcommands() []*serpent.Command {
|
||||
all := append(r.Core(), r.enterpriseOnly()...)
|
||||
all := append(r.CoreSubcommands(), r.enterpriseOnly()...)
|
||||
return all
|
||||
}
|
||||
|
|
|
@ -8,5 +8,5 @@ import (
|
|||
|
||||
func main() {
|
||||
var rootCmd entcli.RootCmd
|
||||
rootCmd.RunMain(rootCmd.EnterpriseSubcommands())
|
||||
rootCmd.RunWithSubcommands(rootCmd.EnterpriseSubcommands())
|
||||
}
|
||||
|
|
|
@ -94,17 +94,18 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
|||
return nil, xerrors.Errorf("init database encryption: %w", err)
|
||||
}
|
||||
options.Database = cryptDB
|
||||
|
||||
api := &API{
|
||||
ctx: ctx,
|
||||
cancel: cancelFunc,
|
||||
AGPL: coderd.New(options.Options),
|
||||
Options: options,
|
||||
provisionerDaemonAuth: &provisionerDaemonAuth{
|
||||
psk: options.ProvisionerDaemonPSK,
|
||||
authorizer: options.Authorizer,
|
||||
},
|
||||
}
|
||||
// This must happen before coderd initialization!
|
||||
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
|
||||
api.AGPL = coderd.New(options.Options)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = api.Close()
|
||||
|
@ -144,29 +145,32 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
|||
OIDC: options.OIDCConfig,
|
||||
}
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: false,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: false,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||
})
|
||||
// Same as above but it redirects to the login page.
|
||||
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: true,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: false,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: true,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: false,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||
})
|
||||
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: true,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(),
|
||||
Optional: true,
|
||||
SessionTokenFunc: nil, // Default behavior
|
||||
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
|
||||
})
|
||||
|
||||
deploymentID, err := options.Database.GetDeploymentID(ctx)
|
||||
|
@ -531,6 +535,38 @@ type API struct {
|
|||
tailnetService *tailnet.ClientService
|
||||
}
|
||||
|
||||
// writeEntitlementWarningsHeader writes the entitlement warnings to the response header
|
||||
// for all authenticated users with roles. If there are no warnings, this header will not be written.
|
||||
//
|
||||
// This header is used by the CLI to display warnings to the user without having
|
||||
// to make additional requests!
|
||||
func (api *API) writeEntitlementWarningsHeader(a httpmw.Authorization, header http.Header) {
|
||||
roles, err := a.Actor.Roles.Expand()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
nonMemberRoles := 0
|
||||
for _, role := range roles {
|
||||
// The member role is implied, and not assignable.
|
||||
// If there is no display name, then the role is also unassigned.
|
||||
// This is not the ideal logic, but works for now.
|
||||
if role.Name == rbac.RoleMember() || (role.DisplayName == "") {
|
||||
continue
|
||||
}
|
||||
nonMemberRoles++
|
||||
}
|
||||
if nonMemberRoles == 0 {
|
||||
// Don't show entitlement warnings if the user
|
||||
// has no roles. This is a normal user!
|
||||
return
|
||||
}
|
||||
api.entitlementsMu.RLock()
|
||||
defer api.entitlementsMu.RUnlock()
|
||||
for _, warning := range api.entitlements.Warnings {
|
||||
header.Add(codersdk.EntitlementsWarningHeader, warning)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) Close() error {
|
||||
// Replica manager should be closed first. This is because the replica
|
||||
// manager updates the replica's table in the database when it closes.
|
||||
|
|
|
@ -3,6 +3,7 @@ package coderd_test
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -197,6 +198,40 @@ func TestEntitlements(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestEntitlements_HeaderWarnings(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("ExistForAdmin", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
adminClient, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
AuditLogging: true,
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
AllFeatures: false,
|
||||
},
|
||||
})
|
||||
//nolint:gocritic // This isn't actually bypassing any RBAC checks
|
||||
res, err := adminClient.Request(context.Background(), http.MethodGet, "/api/v2/users/me", nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
require.NotEmpty(t, res.Header.Values(codersdk.EntitlementsWarningHeader))
|
||||
})
|
||||
t.Run("NoneForNormalUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
adminClient, adminUser := coderdenttest.New(t, &coderdenttest.Options{
|
||||
AuditLogging: true,
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
AllFeatures: false,
|
||||
},
|
||||
})
|
||||
anotherClient, _ := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID)
|
||||
res, err := anotherClient.Request(context.Background(), http.MethodGet, "/api/v2/users/me", nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
require.Empty(t, res.Header.Values(codersdk.EntitlementsWarningHeader))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuditLogging(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Enabled", func(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue