diff --git a/cli/login.go b/cli/login.go index 7539497853..dfba4f3b4c 100644 --- a/cli/login.go +++ b/cli/login.go @@ -18,7 +18,6 @@ import ( "github.com/coder/pretty" - "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/coderd/userpassword" "github.com/coder/coder/v2/codersdk" @@ -180,21 +179,11 @@ func (r *RootCmd) login() *serpent.Command { serverURL.Scheme = "https" } - client, err := r.createUnauthenticatedClient(ctx, serverURL) + client, err := r.createUnauthenticatedClient(ctx, serverURL, inv) if err != nil { return err } - // Try to check the version of the server prior to logging in. - // It may be useful to warn the user if they are trying to login - // on a very old client. - err = r.checkVersions(inv, client, buildinfo.Version()) - if err != nil { - // Checking versions isn't a fatal error so we print a warning - // and proceed. - _, _ = fmt.Fprintln(inv.Stderr, pretty.Sprint(cliui.DefaultStyles.Warn, err.Error())) - } - hasFirstUser, err := client.HasFirstUser(ctx) if err != nil { return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err) diff --git a/cli/logout_test.go b/cli/logout_test.go index 493ccb5dd2..62c93c2d6f 100644 --- a/cli/logout_test.go +++ b/cli/logout_test.go @@ -97,33 +97,6 @@ func TestLogout(t *testing.T) { <-logoutChan }) - t.Run("NoSessionFile", func(t *testing.T) { - t.Parallel() - - pty := ptytest.New(t) - config := login(t, pty) - - // Ensure session files exist. - require.FileExists(t, string(config.URL())) - require.FileExists(t, string(config.Session())) - - err := os.Remove(string(config.Session())) - require.NoError(t, err) - - logoutChan := make(chan struct{}) - logout, _ := clitest.New(t, "logout", "--global-config", string(config)) - - logout.Stdin = pty.Input() - logout.Stdout = pty.Output() - - go func() { - defer close(logoutChan) - err = logout.Run() - assert.ErrorContains(t, err, "You are not logged in. Try logging in using 'coder login'.") - }() - - <-logoutChan - }) t.Run("CannotDeleteFiles", func(t *testing.T) { t.Parallel() diff --git a/cli/root.go b/cli/root.go index 88af2860ed..bbc59a8d94 100644 --- a/cli/root.go +++ b/cli/root.go @@ -9,8 +9,6 @@ import ( "errors" "fmt" "io" - "math/rand" - "net" "net/http" "net/url" "os" @@ -20,6 +18,7 @@ import ( "runtime" "runtime/trace" "strings" + "sync" "syscall" "text/tabwriter" "time" @@ -27,6 +26,7 @@ import ( "github.com/mattn/go-isatty" "github.com/mitchellh/go-wordwrap" "golang.org/x/exp/slices" + "golang.org/x/mod/semver" "golang.org/x/xerrors" "github.com/coder/pretty" @@ -67,8 +67,7 @@ const ( varOrganizationSelect = "organization" varDisableDirect = "disable-direct-connections" - notLoggedInMessage = "You are not logged in. Try logging in using 'coder login '." - notLoggedInURLSavedMessage = "You are not logged in. Try logging in using 'coder login'." + notLoggedInMessage = "You are not logged in. Try logging in using 'coder login '." envNoVersionCheck = "CODER_NO_VERSION_WARNING" envNoFeatureWarning = "CODER_NO_FEATURE_WARNING" @@ -80,12 +79,7 @@ const ( envURL = "CODER_URL" ) -var ( - errUnauthenticated = xerrors.New(notLoggedInMessage) - errUnauthenticatedURLSaved = xerrors.New(notLoggedInURLSavedMessage) -) - -func (r *RootCmd) Core() []*serpent.Command { +func (r *RootCmd) CoreSubcommands() []*serpent.Command { // Please re-sort this list alphabetically if you change it! return []*serpent.Command{ r.dotfiles(), @@ -134,12 +128,13 @@ func (r *RootCmd) Core() []*serpent.Command { } func (r *RootCmd) AGPL() []*serpent.Command { - all := append(r.Core(), r.Server( /* Do not import coderd here. */ nil)) + all := append(r.CoreSubcommands(), r.Server( /* Do not import coderd here. */ nil)) return all } -// Main is the entrypoint for the Coder CLI. -func (r *RootCmd) RunMain(subcommands []*serpent.Command) { +// RunWithSubcommands runs the root command with the given subcommands. +// It is abstracted to enable the Enterprise code to add commands. +func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) { // This configuration is not available as a standard option because we // want to trace the entire program, including Options parsing. goTraceFilePath, ok := os.LookupEnv("CODER_GO_TRACE") @@ -156,8 +151,6 @@ func (r *RootCmd) RunMain(subcommands []*serpent.Command) { defer trace.Stop() } - rand.Seed(time.Now().UnixMicro()) - cmd, err := r.Command(subcommands) if err != nil { panic(err) @@ -503,79 +496,20 @@ type RootCmd struct { noFeatureWarning bool } -func addTelemetryHeader(client *codersdk.Client, inv *serpent.Invocation) { - transport, ok := client.HTTPClient.Transport.(*codersdk.HeaderTransport) - if !ok { - transport = &codersdk.HeaderTransport{ - Transport: client.HTTPClient.Transport, - Header: http.Header{}, - } - client.HTTPClient.Transport = transport - } - - var topts []telemetry.Option - for _, opt := range inv.Command.FullOptions() { - if opt.ValueSource == serpent.ValueSourceNone || opt.ValueSource == serpent.ValueSourceDefault { - continue - } - topts = append(topts, telemetry.Option{ - Name: opt.Name, - ValueSource: string(opt.ValueSource), - }) - } - ti := telemetry.Invocation{ - Command: inv.Command.FullName(), - Options: topts, - InvokedAt: time.Now(), - } - - byt, err := json.Marshal(ti) - if err != nil { - // Should be impossible - panic(err) - } - - // Per https://stackoverflow.com/questions/686217/maximum-on-http-header-values, - // we don't want to send headers that are too long. - s := base64.StdEncoding.EncodeToString(byt) - if len(s) > 4096 { - return - } - - transport.Header.Add(codersdk.CLITelemetryHeader, s) -} - -// InitClient sets client to a new client. -// It reads from global configuration files if flags are not set. +// InitClient authenticates the client with files from disk +// and injects header middlewares for telemetry, authentication, +// and version checks. func (r *RootCmd) InitClient(client *codersdk.Client) serpent.MiddlewareFunc { - return serpent.Chain( - r.initClientInternal(client, false), - // By default, we should print warnings in addition to initializing the client - r.PrintWarnings(client), - ) -} - -func (r *RootCmd) InitClientMissingTokenOK(client *codersdk.Client) serpent.MiddlewareFunc { - return r.initClientInternal(client, true) -} - -// nolint: revive -func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing bool) serpent.MiddlewareFunc { - if client == nil { - panic("client is nil") - } - if r == nil { - panic("root is nil") - } return func(next serpent.HandlerFunc) serpent.HandlerFunc { return func(inv *serpent.Invocation) error { conf := r.createConfig() var err error + // Read the client URL stored on disk. if r.clientURL == nil || r.clientURL.String() == "" { rawURL, err := conf.URL().Read() // If the configuration files are absent, the user is logged out if os.IsNotExist(err) { - return errUnauthenticated + return xerrors.New(notLoggedInMessage) } if err != nil { return err @@ -586,25 +520,20 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing return err } } - + // Read the token stored on disk. if r.token == "" { r.token, err = conf.Session().Read() - // If the configuration files are absent, the user is logged out - if os.IsNotExist(err) { - if !allowTokenMissing { - return errUnauthenticatedURLSaved - } - } else if err != nil { + // Even if there isn't a token, we don't care. + // Some API routes can be unauthenticated. + if err != nil && !os.IsNotExist(err) { return err } } - err = r.setClient(inv.Context(), client, r.clientURL) + + err = r.configureClient(inv.Context(), client, r.clientURL, inv) if err != nil { return err } - - addTelemetryHeader(client, inv) - client.SetSessionToken(r.token) if r.debugHTTP { @@ -617,48 +546,8 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing } } -func (r *RootCmd) PrintWarnings(client *codersdk.Client) serpent.MiddlewareFunc { - if client == nil { - panic("client is nil") - } - if r == nil { - panic("root is nil") - } - return func(next serpent.HandlerFunc) serpent.HandlerFunc { - return func(inv *serpent.Invocation) error { - // We send these requests in parallel to minimize latency. - var ( - versionErr = make(chan error) - warningErr = make(chan error) - ) - go func() { - versionErr <- r.checkVersions(inv, client, buildinfo.Version()) - close(versionErr) - }() - - go func() { - warningErr <- r.checkWarnings(inv, client) - close(warningErr) - }() - - if err := <-versionErr; err != nil { - // Just log the error here. We never want to fail a command - // due to a pre-run. - pretty.Fprintf(inv.Stderr, cliui.DefaultStyles.Warn, "check versions error: %s", err) - _, _ = fmt.Fprintln(inv.Stderr) - } - - if err := <-warningErr; err != nil { - // Same as above - pretty.Fprintf(inv.Stderr, cliui.DefaultStyles.Warn, "check entitlement warnings error: %s", err) - _, _ = fmt.Fprintln(inv.Stderr) - } - - return next(inv) - } - } -} - +// HeaderTransport creates a new transport that executes `--header-command` +// if it is set to add headers for all outbound requests. func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) { transport := &codersdk.HeaderTransport{ Transport: http.DefaultTransport, @@ -700,22 +589,38 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod return transport, nil } -func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL) error { - transport, err := r.HeaderTransport(ctx, serverURL) +func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error { + transport := http.DefaultTransport + transport = wrapTransportWithTelemetryHeader(transport, inv) + if !r.noVersionCheck { + transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) { + // Create a new client without any wrapped transport + // otherwise it creates an infinite loop! + basicClient := codersdk.New(serverURL) + return basicClient.BuildInfo(ctx) + }) + } + if !r.noFeatureWarning { + transport = wrapTransportWithEntitlementsCheck(transport, inv.Stderr) + } + headerTransport, err := r.HeaderTransport(ctx, serverURL) if err != nil { return xerrors.Errorf("create header transport: %w", err) } - - client.URL = serverURL + // The header transport has to come last. + // codersdk checks for the header transport to get headers + // to clone on the DERP client. + headerTransport.Transport = transport client.HTTPClient = &http.Client{ - Transport: transport, + Transport: headerTransport, } + client.URL = serverURL return nil } -func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *url.URL) (*codersdk.Client, error) { +func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *url.URL, inv *serpent.Invocation) (*codersdk.Client, error) { var client codersdk.Client - err := r.setClient(ctx, &client, serverURL) + err := r.configureClient(ctx, &client, serverURL, inv) return &client, err } @@ -879,70 +784,6 @@ func formatExamples(examples ...example) string { return sb.String() } -// checkVersions checks to see if there's a version mismatch between the client -// and server and prints a message nudging the user to upgrade if a mismatch -// is detected. forceCheck is a test flag and should always be false in production. -// -//nolint:revive -func (r *RootCmd) checkVersions(i *serpent.Invocation, client *codersdk.Client, clientVersion string) error { - if r.noVersionCheck { - return nil - } - - ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second) - defer cancel() - - serverInfo, err := client.BuildInfo(ctx) - // Avoid printing errors that are connection-related. - if isConnectionError(err) { - return nil - } - if err != nil { - return xerrors.Errorf("build info: %w", err) - } - - if !buildinfo.VersionsMatch(clientVersion, serverInfo.Version) { - upgradeMessage := defaultUpgradeMessage(serverInfo.CanonicalVersion()) - if serverInfo.UpgradeMessage != "" { - upgradeMessage = serverInfo.UpgradeMessage - } - - fmtWarningText := "version mismatch: client %s, server %s\n%s" - fmtWarn := pretty.Sprint(cliui.DefaultStyles.Warn, fmtWarningText) - warning := fmt.Sprintf(fmtWarn, clientVersion, serverInfo.Version, upgradeMessage) - - _, _ = fmt.Fprint(i.Stderr, warning) - _, _ = fmt.Fprintln(i.Stderr) - } - - return nil -} - -func (r *RootCmd) checkWarnings(i *serpent.Invocation, client *codersdk.Client) error { - if r.noFeatureWarning { - return nil - } - - ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second) - defer cancel() - - user, err := client.User(ctx, codersdk.Me) - if err != nil { - return xerrors.Errorf("get user me: %w", err) - } - - entitlements, err := client.Entitlements(ctx) - if err == nil { - // Don't show warning to regular users. - if len(user.Roles) > 0 { - for _, w := range entitlements.Warnings { - _, _ = fmt.Fprintln(i.Stderr, pretty.Sprint(cliui.DefaultStyles.Warn, w)) - } - } - } - return nil -} - // Verbosef logs a message if verbose mode is enabled. func (r *RootCmd) Verbosef(inv *serpent.Invocation, fmtStr string, args ...interface{}) { if r.verbose { @@ -1068,19 +909,6 @@ func ExitError(code int, err error) error { return &exitError{code: code, err: err} } -// IiConnectionErr is a convenience function for checking if the source of an -// error is due to a 'connection refused', 'no such host', etc. -func isConnectionError(err error) bool { - var ( - // E.g. no such host - dnsErr *net.DNSError - // Eg. connection refused - opErr *net.OpError - ) - - return xerrors.As(err, &dnsErr) || xerrors.As(err, &opErr) -} - type prettyErrorFormatter struct { w io.Writer // verbose turns on more detailed error logs, such as stack traces. @@ -1305,3 +1133,105 @@ func defaultUpgradeMessage(version string) string { } return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version) } + +// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport +// that checks for entitlement warnings and prints them to the user. +func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper { + var once sync.Once + return roundTripper(func(req *http.Request) (*http.Response, error) { + res, err := rt.RoundTrip(req) + if err != nil { + return res, err + } + once.Do(func() { + for _, warning := range res.Header.Values(codersdk.EntitlementsWarningHeader) { + _, _ = fmt.Fprintln(w, pretty.Sprint(cliui.DefaultStyles.Warn, warning)) + } + }) + return res, err + }) +} + +// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport +// that checks for version mismatches between the client and server. If a mismatch +// is detected, a warning is printed to the user. +func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper { + var once sync.Once + return roundTripper(func(req *http.Request) (*http.Response, error) { + res, err := rt.RoundTrip(req) + if err != nil { + return res, err + } + once.Do(func() { + serverVersion := res.Header.Get(codersdk.BuildVersionHeader) + if serverVersion == "" { + return + } + if buildinfo.VersionsMatch(clientVersion, serverVersion) { + return + } + upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion)) + serverInfo, err := getBuildInfo(inv.Context()) + if err == nil && serverInfo.UpgradeMessage != "" { + upgradeMessage = serverInfo.UpgradeMessage + } + fmtWarningText := "version mismatch: client %s, server %s\n%s" + fmtWarn := pretty.Sprint(cliui.DefaultStyles.Warn, fmtWarningText) + warning := fmt.Sprintf(fmtWarn, clientVersion, serverVersion, upgradeMessage) + + _, _ = fmt.Fprintln(inv.Stderr, warning) + }) + return res, err + }) +} + +// wrapTransportWithTelemetryHeader adds telemetry headers to report command usage +// to an HTTP transport. +func wrapTransportWithTelemetryHeader(transport http.RoundTripper, inv *serpent.Invocation) http.RoundTripper { + var ( + value string + once sync.Once + ) + return roundTripper(func(req *http.Request) (*http.Response, error) { + once.Do(func() { + // We only want to compute this header once when a request + // first goes out, hence the complexity with locking here. + var topts []telemetry.Option + for _, opt := range inv.Command.FullOptions() { + if opt.ValueSource == serpent.ValueSourceNone || opt.ValueSource == serpent.ValueSourceDefault { + continue + } + topts = append(topts, telemetry.Option{ + Name: opt.Name, + ValueSource: string(opt.ValueSource), + }) + } + ti := telemetry.Invocation{ + Command: inv.Command.FullName(), + Options: topts, + InvokedAt: time.Now(), + } + + byt, err := json.Marshal(ti) + if err != nil { + // Should be impossible + panic(err) + } + s := base64.StdEncoding.EncodeToString(byt) + // Don't send the header if it's too long! + if len(s) <= 4096 { + value = s + } + }) + if value != "" { + req.Header.Add(codersdk.CLITelemetryHeader, value) + } + return transport.RoundTrip(req) + }) +} + +type roundTripper func(req *http.Request) (*http.Response, error) + +func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return r(req) +} diff --git a/cli/root_internal_test.go b/cli/root_internal_test.go index 6d108ee554..9bb05a33b1 100644 --- a/cli/root_internal_test.go +++ b/cli/root_internal_test.go @@ -2,10 +2,13 @@ package cli import ( "bytes" + "context" + "encoding/base64" + "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" - "net/url" "os" "runtime" "testing" @@ -13,14 +16,30 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/cli/cliui" - "github.com/coder/coder/v2/coderd" - "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/cli/telemetry" "github.com/coder/coder/v2/codersdk" "github.com/coder/pretty" + "github.com/coder/serpent" ) +func TestMain(m *testing.M) { + if runtime.GOOS == "windows" { + // Don't run goleak on windows tests, they're super flaky right now. + // See: https://github.com/coder/coder/issues/8954 + os.Exit(m.Run()) + } + goleak.VerifyTestMain(m, + // The lumberjack library is used by by agent and seems to leave + // goroutines after Close(), fails TestGitSSH tests. + // https://github.com/natefinch/lumberjack/pull/100 + goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"), + goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"), + // The pq library appears to leave around a goroutine after Close(). + goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"), + ) +} + func Test_formatExamples(t *testing.T) { t.Parallel() @@ -80,49 +99,37 @@ func Test_formatExamples(t *testing.T) { } } -func TestMain(m *testing.M) { - if runtime.GOOS == "windows" { - // Don't run goleak on windows tests, they're super flaky right now. - // See: https://github.com/coder/coder/issues/8954 - os.Exit(m.Run()) - } - goleak.VerifyTestMain(m, - // The lumberjack library is used by by agent and seems to leave - // goroutines after Close(), fails TestGitSSH tests. - // https://github.com/natefinch/lumberjack/pull/100 - goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"), - goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"), - // The pq library appears to leave around a goroutine after Close(). - goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"), - ) -} - -func Test_checkVersions(t *testing.T) { +func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) { t.Parallel() + t.Run("NoOutput", func(t *testing.T) { + t.Parallel() + r := &RootCmd{} + cmd, err := r.Command(nil) + require.NoError(t, err) + var buf bytes.Buffer + inv := cmd.Invoke() + inv.Stderr = &buf + rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + // Provider a version that will not match! + codersdk.BuildVersionHeader: []string{"v2.0.0"}, + }, + Body: io.NopCloser(nil), + }, nil + }), inv, "v2.0.0", nil) + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, "", buf.String()) + }) + t.Run("CustomUpgradeMessage", func(t *testing.T) { t.Parallel() - expectedUpgradeMessage := "My custom upgrade message" - - srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{ - ExternalURL: buildinfo.ExternalURL(), - // Provide a version that will not match - Version: "v1.0.0", - AgentAPIVersion: coderd.AgentAPIVersionREST, - // does not matter what the url is - DashboardURL: "https://example.com", - WorkspaceProxy: false, - UpgradeMessage: expectedUpgradeMessage, - }) - })) - defer srv.Close() - surl, err := url.Parse(srv.URL) - require.NoError(t, err) - - client := codersdk.New(surl) - r := &RootCmd{} cmd, err := r.Command(nil) @@ -131,50 +138,85 @@ func Test_checkVersions(t *testing.T) { var buf bytes.Buffer inv := cmd.Invoke() inv.Stderr = &buf - - err = r.checkVersions(inv, client, "v2.0.0") + expectedUpgradeMessage := "My custom upgrade message" + rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + // Provider a version that will not match! + codersdk.BuildVersionHeader: []string{"v1.0.0"}, + }, + Body: io.NopCloser(nil), + }, nil + }), inv, "v2.0.0", func(ctx context.Context) (codersdk.BuildInfoResponse, error) { + return codersdk.BuildInfoResponse{ + UpgradeMessage: expectedUpgradeMessage, + }, nil + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + res, err := rt.RoundTrip(req) require.NoError(t, err) + defer res.Body.Close() + + // Run this twice to ensure the upgrade message is only printed once. + res, err = rt.RoundTrip(req) + require.NoError(t, err) + defer res.Body.Close() fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", expectedUpgradeMessage) expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput)) require.Equal(t, expectedOutput, buf.String()) }) - - t.Run("DefaultUpgradeMessage", func(t *testing.T) { - t.Parallel() - - srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{ - ExternalURL: buildinfo.ExternalURL(), - // Provide a version that will not match - Version: "v1.0.0", - AgentAPIVersion: coderd.AgentAPIVersionREST, - // does not matter what the url is - DashboardURL: "https://example.com", - WorkspaceProxy: false, - UpgradeMessage: "", - }) - })) - defer srv.Close() - surl, err := url.Parse(srv.URL) - require.NoError(t, err) - - client := codersdk.New(surl) - - r := &RootCmd{} - - cmd, err := r.Command(nil) - require.NoError(t, err) - - var buf bytes.Buffer - inv := cmd.Invoke() - inv.Stderr = &buf - - err = r.checkVersions(inv, client, "v2.0.0") - require.NoError(t, err) - - fmtOutput := fmt.Sprintf("version mismatch: client v2.0.0, server v1.0.0\n%s", defaultUpgradeMessage("v1.0.0")) - expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput)) - require.Equal(t, expectedOutput, buf.String()) - }) +} + +func Test_wrapTransportWithTelemetryHeader(t *testing.T) { + t.Parallel() + + rt := wrapTransportWithTelemetryHeader(roundTripper(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + Body: io.NopCloser(nil), + }, nil + }), &serpent.Invocation{ + Command: &serpent.Command{ + Use: "test", + Options: serpent.OptionSet{{ + Name: "bananas", + Description: "hey", + }}, + }, + }) + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + defer res.Body.Close() + resp := req.Header.Get(codersdk.CLITelemetryHeader) + require.NotEmpty(t, resp) + data, err := base64.StdEncoding.DecodeString(resp) + require.NoError(t, err) + var ti telemetry.Invocation + err = json.Unmarshal(data, &ti) + require.NoError(t, err) + require.Equal(t, ti.Command, "test") +} + +func Test_wrapTransportWithEntitlementsCheck(t *testing.T) { + t.Parallel() + + lines := []string{"First Warning", "Second Warning"} + var buf bytes.Buffer + rt := wrapTransportWithEntitlementsCheck(roundTripper(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + codersdk.EntitlementsWarningHeader: lines, + }, + Body: io.NopCloser(nil), + }, nil + }), &buf) + res, err := rt.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil)) + require.NoError(t, err) + defer res.Body.Close() + expectedOutput := fmt.Sprintf("%s\n%s\n", pretty.Sprint(cliui.DefaultStyles.Warn, lines[0]), + pretty.Sprint(cliui.DefaultStyles.Warn, lines[1])) + require.Equal(t, expectedOutput, buf.String()) } diff --git a/cli/root_test.go b/cli/root_test.go index bc10cfceec..897aea18fe 100644 --- a/cli/root_test.go +++ b/cli/root_test.go @@ -253,7 +253,7 @@ func TestHandlersOK(t *testing.T) { t.Parallel() var root cli.RootCmd - cmd, err := root.Command(root.Core()) + cmd, err := root.Command(root.CoreSubcommands()) require.NoError(t, err) clitest.HandlersOK(t, cmd) diff --git a/cli/templatecreate_test.go b/cli/templatecreate_test.go index 9710a86a88..42ef60946b 100644 --- a/cli/templatecreate_test.go +++ b/cli/templatecreate_test.go @@ -7,7 +7,6 @@ import ( "path/filepath" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/cli/clitest" @@ -197,56 +196,6 @@ func TestTemplateCreate(t *testing.T) { require.NoError(t, err, "Template must be recreated without error") }) - t.Run("WithVariablesFileWithoutRequiredValue", func(t *testing.T) { - t.Parallel() - - client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) - coderdtest.CreateFirstUser(t, client) - - templateVariables := []*proto.TemplateVariable{ - { - Name: "first_variable", - Description: "This is the first variable.", - Type: "string", - Required: true, - Sensitive: true, - }, - { - Name: "second_variable", - Description: "This is the first variable", - Type: "string", - DefaultValue: "abc", - Required: false, - Sensitive: true, - }, - } - source := clitest.CreateTemplateVersionSource(t, - createEchoResponsesWithTemplateVariables(templateVariables)) - tempDir := t.TempDir() - removeTmpDirUntilSuccessAfterTest(t, tempDir) - variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml") - _, _ = variablesFile.WriteString(`second_variable: foobar`) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name()) - clitest.SetupConfig(t, client, root) - inv = inv.WithContext(ctx) - pty := ptytest.New(t).Attach(inv) - - // We expect the cli to return an error, so we have to handle it - // ourselves. - go func() { - cancel() - err := inv.Run() - assert.Error(t, err) - }() - - pty.ExpectMatch("context canceled") - <-ctx.Done() - }) - t.Run("WithVariablesFileWithTheRequiredValue", func(t *testing.T) { t.Parallel() diff --git a/cli/userlist_test.go b/cli/userlist_test.go index 64565e1dde..feca8746df 100644 --- a/cli/userlist_test.go +++ b/cli/userlist_test.go @@ -77,7 +77,7 @@ func TestUserList(t *testing.T) { var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) - require.Contains(t, err.Error(), "Try logging in using 'coder login '.") + require.Contains(t, err.Error(), "Try logging in using 'coder login'.") }) } diff --git a/cli/vscodessh.go b/cli/vscodessh.go index b5c2c7d352..eec8ba775b 100644 --- a/cli/vscodessh.go +++ b/cli/vscodessh.go @@ -90,7 +90,7 @@ func (r *RootCmd) vscodeSSH() *serpent.Command { client.SetSessionToken(string(sessionToken)) // This adds custom headers to the request! - err = r.setClient(ctx, client, serverURL) + err = r.configureClient(ctx, client, serverURL, inv) if err != nil { return xerrors.Errorf("set client: %w", err) } diff --git a/cmd/coder/main.go b/cmd/coder/main.go index 5d1cea2f80..7d41563c18 100644 --- a/cmd/coder/main.go +++ b/cmd/coder/main.go @@ -8,5 +8,5 @@ import ( func main() { var rootCmd cli.RootCmd - rootCmd.RunMain(rootCmd.AGPL()) + rootCmd.RunWithSubcommands(rootCmd.AGPL()) } diff --git a/coderd/coderd.go b/coderd/coderd.go index 7c62a62bc2..2c6367c58a 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -129,6 +129,12 @@ type Options struct { TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error // RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license. RefreshEntitlements func(ctx context.Context) error + // PostAuthAdditionalHeadersFunc is used to add additional headers to the response + // after a successful authentication. + // This is somewhat janky, but seemingly the only reasonable way to add a header + // for all authenticated users under a condition, only in Enterprise. + PostAuthAdditionalHeadersFunc func(auth httpmw.Authorization, header http.Header) + // TLSCertificates is used to mesh DERP servers securely. TLSCertificates []tls.Certificate TailnetCoordinator tailnet.Coordinator @@ -557,30 +563,33 @@ func New(options *Options) *API { } apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ - DB: options.Database, - OAuth2Configs: oauthConfigs, - RedirectToLogin: false, - DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), - Optional: false, - SessionTokenFunc: nil, // Default behavior + DB: options.Database, + OAuth2Configs: oauthConfigs, + RedirectToLogin: false, + DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), + Optional: false, + SessionTokenFunc: nil, // Default behavior + PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, }) // Same as above but it redirects to the login page. apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ - DB: options.Database, - OAuth2Configs: oauthConfigs, - RedirectToLogin: true, - DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), - Optional: false, - SessionTokenFunc: nil, // Default behavior + DB: options.Database, + OAuth2Configs: oauthConfigs, + RedirectToLogin: true, + DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), + Optional: false, + SessionTokenFunc: nil, // Default behavior + PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, }) // Same as the first but it's optional. apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ - DB: options.Database, - OAuth2Configs: oauthConfigs, - RedirectToLogin: false, - DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), - Optional: true, - SessionTokenFunc: nil, // Default behavior + DB: options.Database, + OAuth2Configs: oauthConfigs, + RedirectToLogin: false, + DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), + Optional: true, + SessionTokenFunc: nil, // Default behavior + PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, }) // API rate limit middleware. The counter is local and not shared between diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 46d8c97014..733e722e04 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -113,6 +113,13 @@ type ExtractAPIKeyConfig struct { // SessionTokenFunc is a custom function that can be used to extract the API // key. If nil, the default behavior is used. SessionTokenFunc func(r *http.Request) string + + // PostAuthAdditionalHeadersFunc is a function that can be used to add + // headers to the response after the user has been authenticated. + // + // This is originally implemented to send entitlement warning headers after + // a user is authenticated to prevent additional CLI invocations. + PostAuthAdditionalHeadersFunc func(a Authorization, header http.Header) } // ExtractAPIKeyMW calls ExtractAPIKey with the given config on each request, @@ -454,6 +461,10 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon }.WithCachedASTValue(), } + if cfg.PostAuthAdditionalHeadersFunc != nil { + cfg.PostAuthAdditionalHeadersFunc(authz, rw.Header()) + } + return key, &authz, true } diff --git a/coderd/templates_test.go b/coderd/templates_test.go index 13501274f5..087044c293 100644 --- a/coderd/templates_test.go +++ b/coderd/templates_test.go @@ -171,7 +171,7 @@ func TestPostTemplateByOrganization(t *testing.T) { var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode()) - require.Contains(t, err.Error(), "Try logging in using 'coder login '.") + require.Contains(t, err.Error(), "Try logging in using 'coder login'.") }) t.Run("AllowUserScheduling", func(t *testing.T) { diff --git a/codersdk/client.go b/codersdk/client.go index b6a1b1dc11..f1ac879817 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -81,6 +81,9 @@ const ( // BuildVersionHeader contains build information of Coder. BuildVersionHeader = "X-Coder-Build-Version" + + // EntitlementsWarnings contains active warnings for the user's entitlements. + EntitlementsWarningHeader = "X-Coder-Entitlements-Warning" ) // loggableMimeTypes is a list of MIME types that are safe to log @@ -358,7 +361,7 @@ func ReadBodyAsError(res *http.Response) error { if res.StatusCode == http.StatusUnauthorized { // 401 means the user is not logged in // 403 would mean that the user is not authorized - helpMessage = "Try logging in using 'coder login '." + helpMessage = "Try logging in using 'coder login'." } resp, err := io.ReadAll(res.Body) diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index bbcb7af678..0b0548cfd0 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -82,7 +82,7 @@ func (r *RootCmd) provisionerDaemonStart() *serpent.Command { // disable checks and warnings because this command starts a daemon; it is // not meant for humans typing commands. Furthermore, the checks are // incompatible with PSK auth that this command uses - r.InitClientMissingTokenOK(client), + r.InitClient(client), ), Handler: func(inv *serpent.Invocation) error { ctx, cancel := context.WithCancel(inv.Context()) diff --git a/enterprise/cli/root.go b/enterprise/cli/root.go index f510369ec0..74615ff0e9 100644 --- a/enterprise/cli/root.go +++ b/enterprise/cli/root.go @@ -21,6 +21,6 @@ func (r *RootCmd) enterpriseOnly() []*serpent.Command { } func (r *RootCmd) EnterpriseSubcommands() []*serpent.Command { - all := append(r.Core(), r.enterpriseOnly()...) + all := append(r.CoreSubcommands(), r.enterpriseOnly()...) return all } diff --git a/enterprise/cmd/coder/main.go b/enterprise/cmd/coder/main.go index 0aa1400c5c..c7e19dfab9 100644 --- a/enterprise/cmd/coder/main.go +++ b/enterprise/cmd/coder/main.go @@ -8,5 +8,5 @@ import ( func main() { var rootCmd entcli.RootCmd - rootCmd.RunMain(rootCmd.EnterpriseSubcommands()) + rootCmd.RunWithSubcommands(rootCmd.EnterpriseSubcommands()) } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 4a31e13362..4ce4414bdc 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -94,17 +94,18 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { return nil, xerrors.Errorf("init database encryption: %w", err) } options.Database = cryptDB - api := &API{ ctx: ctx, cancel: cancelFunc, - AGPL: coderd.New(options.Options), Options: options, provisionerDaemonAuth: &provisionerDaemonAuth{ psk: options.ProvisionerDaemonPSK, authorizer: options.Authorizer, }, } + // This must happen before coderd initialization! + options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader + api.AGPL = coderd.New(options.Options) defer func() { if err != nil { _ = api.Close() @@ -144,29 +145,32 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { OIDC: options.OIDCConfig, } apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ - DB: options.Database, - OAuth2Configs: oauthConfigs, - RedirectToLogin: false, - DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), - Optional: false, - SessionTokenFunc: nil, // Default behavior + DB: options.Database, + OAuth2Configs: oauthConfigs, + RedirectToLogin: false, + DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), + Optional: false, + SessionTokenFunc: nil, // Default behavior + PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, }) // Same as above but it redirects to the login page. apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ - DB: options.Database, - OAuth2Configs: oauthConfigs, - RedirectToLogin: true, - DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), - Optional: false, - SessionTokenFunc: nil, // Default behavior + DB: options.Database, + OAuth2Configs: oauthConfigs, + RedirectToLogin: true, + DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), + Optional: false, + SessionTokenFunc: nil, // Default behavior + PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, }) apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ - DB: options.Database, - OAuth2Configs: oauthConfigs, - RedirectToLogin: false, - DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), - Optional: true, - SessionTokenFunc: nil, // Default behavior + DB: options.Database, + OAuth2Configs: oauthConfigs, + RedirectToLogin: false, + DisableSessionExpiryRefresh: options.DeploymentValues.DisableSessionExpiryRefresh.Value(), + Optional: true, + SessionTokenFunc: nil, // Default behavior + PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc, }) deploymentID, err := options.Database.GetDeploymentID(ctx) @@ -531,6 +535,38 @@ type API struct { tailnetService *tailnet.ClientService } +// writeEntitlementWarningsHeader writes the entitlement warnings to the response header +// for all authenticated users with roles. If there are no warnings, this header will not be written. +// +// This header is used by the CLI to display warnings to the user without having +// to make additional requests! +func (api *API) writeEntitlementWarningsHeader(a httpmw.Authorization, header http.Header) { + roles, err := a.Actor.Roles.Expand() + if err != nil { + return + } + nonMemberRoles := 0 + for _, role := range roles { + // The member role is implied, and not assignable. + // If there is no display name, then the role is also unassigned. + // This is not the ideal logic, but works for now. + if role.Name == rbac.RoleMember() || (role.DisplayName == "") { + continue + } + nonMemberRoles++ + } + if nonMemberRoles == 0 { + // Don't show entitlement warnings if the user + // has no roles. This is a normal user! + return + } + api.entitlementsMu.RLock() + defer api.entitlementsMu.RUnlock() + for _, warning := range api.entitlements.Warnings { + header.Add(codersdk.EntitlementsWarningHeader, warning) + } +} + func (api *API) Close() error { // Replica manager should be closed first. This is because the replica // manager updates the replica's table in the database when it closes. diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 73b8cca5ac..e2074dd43b 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -3,6 +3,7 @@ package coderd_test import ( "bytes" "context" + "net/http" "reflect" "strings" "testing" @@ -197,6 +198,40 @@ func TestEntitlements(t *testing.T) { }) } +func TestEntitlements_HeaderWarnings(t *testing.T) { + t.Parallel() + t.Run("ExistForAdmin", func(t *testing.T) { + t.Parallel() + adminClient, _ := coderdenttest.New(t, &coderdenttest.Options{ + AuditLogging: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AllFeatures: false, + }, + }) + //nolint:gocritic // This isn't actually bypassing any RBAC checks + res, err := adminClient.Request(context.Background(), http.MethodGet, "/api/v2/users/me", nil) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.NotEmpty(t, res.Header.Values(codersdk.EntitlementsWarningHeader)) + }) + t.Run("NoneForNormalUser", func(t *testing.T) { + t.Parallel() + adminClient, adminUser := coderdenttest.New(t, &coderdenttest.Options{ + AuditLogging: true, + LicenseOptions: &coderdenttest.LicenseOptions{ + AllFeatures: false, + }, + }) + anotherClient, _ := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID) + res, err := anotherClient.Request(context.Background(), http.MethodGet, "/api/v2/users/me", nil) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Empty(t, res.Header.Values(codersdk.EntitlementsWarningHeader)) + }) +} + func TestAuditLogging(t *testing.T) { t.Parallel() t.Run("Enabled", func(t *testing.T) {