From 915bb41ea22687bebcb159e29ab0ba995c148a90 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 6 Oct 2022 19:28:22 -0500 Subject: [PATCH] feat: Add trial property to licenses (#4372) * feat: Add trial property to licenses This allows the frontend to display whether the user is on a trial license of Coder. This is useful for advertising Enterprise functionality. * Improve tests for license enablement code * Add all features property --- .vscode/settings.json | 3 +- codersdk/features.go | 1 + enterprise/cli/features_test.go | 4 +- enterprise/coderd/coderd.go | 213 +++------------ enterprise/coderd/coderd_test.go | 42 --- .../coderd/coderdenttest/coderdenttest.go | 24 +- enterprise/coderd/license/license.go | 243 ++++++++++++++++++ enterprise/coderd/license/license_test.go | 189 ++++++++++++++ enterprise/coderd/licenses.go | 100 +------ enterprise/coderd/licenses_test.go | 6 +- enterprise/coderd/scim.go | 4 +- enterprise/coderd/workspaceagents.go | 4 +- site/src/api/api.ts | 1 + site/src/api/typesGenerated.ts | 1 + site/src/testHelpers/entities.ts | 3 + .../entitlements/entitlementsXService.ts | 1 + 16 files changed, 507 insertions(+), 332 deletions(-) create mode 100644 enterprise/coderd/license/license.go create mode 100644 enterprise/coderd/license/license_test.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 726e2a0121..e9a32e850c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,8 +9,8 @@ "cliflag", "cliui", "codecov", - "Codespaces", "coderd", + "coderdenttest", "coderdtest", "codersdk", "cronstrue", @@ -24,6 +24,7 @@ "drpcmux", "drpcserver", "Dsts", + "enablements", "fatih", "Formik", "gitsshkey", diff --git a/codersdk/features.go b/codersdk/features.go index 3b57d6eeb3..fe8673ef02 100644 --- a/codersdk/features.go +++ b/codersdk/features.go @@ -42,6 +42,7 @@ type Entitlements struct { Warnings []string `json:"warnings"` HasLicense bool `json:"has_license"` Experimental bool `json:"experimental"` + Trial bool `json:"trial"` } func (c *Client) Entitlements(ctx context.Context) (Entitlements, error) { diff --git a/enterprise/cli/features_test.go b/enterprise/cli/features_test.go index da2425634c..f5e7b1ff35 100644 --- a/enterprise/cli/features_test.go +++ b/enterprise/cli/features_test.go @@ -57,7 +57,7 @@ func TestFeaturesList(t *testing.T) { var entitlements codersdk.Entitlements err := json.Unmarshal(buf.Bytes(), &entitlements) require.NoError(t, err, "unmarshal JSON output") - assert.Len(t, entitlements.Features, 4) + assert.Len(t, entitlements.Features, 5) assert.Empty(t, entitlements.Warnings) assert.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[codersdk.FeatureUserLimit].Entitlement) @@ -67,6 +67,8 @@ func TestFeaturesList(t *testing.T) { entitlements.Features[codersdk.FeatureBrowserOnly].Entitlement) assert.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[codersdk.FeatureWorkspaceQuota].Entitlement) + assert.Equal(t, codersdk.EntitlementNotEntitled, + entitlements.Features[codersdk.FeatureSCIM].Entitlement) assert.False(t, entitlements.HasLicense) assert.False(t, entitlements.Experimental) }) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 324d458ba4..11cceef98f 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -3,7 +3,6 @@ package coderd import ( "context" "crypto/ed25519" - "fmt" "net/http" "sync" "time" @@ -15,11 +14,14 @@ import ( "cdr.dev/slog" "github.com/coder/coder/coderd" + agplaudit "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/coderd/workspacequota" "github.com/coder/coder/codersdk" "github.com/coder/coder/enterprise/audit" "github.com/coder/coder/enterprise/audit/backends" + "github.com/coder/coder/enterprise/coderd/license" ) // New constructs an Enterprise coderd API instance. @@ -34,19 +36,8 @@ func New(ctx context.Context, options *Options) (*API, error) { } ctx, cancelFunc := context.WithCancel(ctx) api := &API{ - AGPL: coderd.New(options.Options), - Options: options, - - entitlements: entitlements{ - activeUsers: codersdk.Feature{ - Entitlement: codersdk.EntitlementNotEntitled, - Enabled: false, - }, - auditLogs: codersdk.EntitlementNotEntitled, - browserOnly: codersdk.EntitlementNotEntitled, - scim: codersdk.EntitlementNotEntitled, - workspaceQuota: codersdk.EntitlementNotEntitled, - }, + AGPL: coderd.New(options.Options), + Options: options, cancelEntitlementsLoop: cancelFunc, } oauthConfigs := &httpmw.OAuth2Configs{ @@ -117,16 +108,7 @@ type API struct { cancelEntitlementsLoop func() entitlementsMu sync.RWMutex - entitlements entitlements -} - -type entitlements struct { - hasLicense bool - activeUsers codersdk.Feature - auditLogs codersdk.Entitlement - browserOnly codersdk.Entitlement - scim codersdk.Entitlement - workspaceQuota codersdk.Entitlement + entitlements codersdk.Entitlements } func (api *API) Close() error { @@ -135,94 +117,57 @@ func (api *API) Close() error { } func (api *API) updateEntitlements(ctx context.Context) error { - licenses, err := api.Database.GetUnexpiredLicenses(ctx) + api.entitlementsMu.Lock() + defer api.entitlementsMu.Unlock() + + entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, api.Keys, map[string]bool{ + codersdk.FeatureAuditLog: api.AuditLogging, + codersdk.FeatureBrowserOnly: api.BrowserOnly, + codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0, + codersdk.FeatureWorkspaceQuota: api.UserWorkspaceQuota != 0, + }) if err != nil { return err } - api.entitlementsMu.Lock() - defer api.entitlementsMu.Unlock() - now := time.Now() - // Default all entitlements to be disabled. - entitlements := entitlements{ - hasLicense: false, - activeUsers: codersdk.Feature{ - Enabled: false, - Entitlement: codersdk.EntitlementNotEntitled, - }, - auditLogs: codersdk.EntitlementNotEntitled, - scim: codersdk.EntitlementNotEntitled, - browserOnly: codersdk.EntitlementNotEntitled, - workspaceQuota: codersdk.EntitlementNotEntitled, + featureChanged := func(featureName string) (changed bool, enabled bool) { + if api.entitlements.Features == nil { + return true, entitlements.Features[featureName].Enabled + } + oldFeature := api.entitlements.Features[featureName] + newFeature := entitlements.Features[featureName] + if oldFeature.Enabled != newFeature.Enabled { + return true, newFeature.Enabled + } + return false, newFeature.Enabled } - // Here we loop through licenses to detect enabled features. - for _, l := range licenses { - claims, err := validateDBLicense(l, api.Keys) - if err != nil { - api.Logger.Debug(ctx, "skipping invalid license", - slog.F("id", l.ID), slog.Error(err)) - continue - } - entitlements.hasLicense = true - entitlement := codersdk.EntitlementEntitled - if now.After(claims.LicenseExpires.Time) { - // if the grace period were over, the validation fails, so if we are after - // LicenseExpires we must be in grace period. - entitlement = codersdk.EntitlementGracePeriod - } - if claims.Features.UserLimit > 0 { - entitlements.activeUsers = codersdk.Feature{ - Enabled: true, - Entitlement: entitlement, - } - currentLimit := int64(0) - if entitlements.activeUsers.Limit != nil { - currentLimit = *entitlements.activeUsers.Limit - } - limit := max(currentLimit, claims.Features.UserLimit) - entitlements.activeUsers.Limit = &limit - } - if claims.Features.AuditLog > 0 { - entitlements.auditLogs = entitlement - } - if claims.Features.BrowserOnly > 0 { - entitlements.browserOnly = entitlement - } - if claims.Features.SCIM > 0 { - entitlements.scim = entitlement - } - if claims.Features.WorkspaceQuota > 0 { - entitlements.workspaceQuota = entitlement - } - } - - if entitlements.auditLogs != api.entitlements.auditLogs { - // A flag could be added to the options that would allow disabling - // enhanced audit logging here! - if entitlements.auditLogs != codersdk.EntitlementNotEntitled && api.AuditLogging { - auditor := audit.NewAuditor( + if changed, enabled := featureChanged(codersdk.FeatureAuditLog); changed { + auditor := agplaudit.NewNop() + if enabled { + auditor = audit.NewAuditor( audit.DefaultFilter, backends.NewPostgres(api.Database, true), backends.NewSlog(api.Logger), ) - api.AGPL.Auditor.Store(&auditor) } + api.AGPL.Auditor.Store(&auditor) } - if entitlements.browserOnly != api.entitlements.browserOnly { + if changed, enabled := featureChanged(codersdk.FeatureBrowserOnly); changed { var handler func(rw http.ResponseWriter) bool - if entitlements.browserOnly != codersdk.EntitlementNotEntitled && api.BrowserOnly { + if enabled { handler = api.shouldBlockNonBrowserConnections } api.AGPL.WorkspaceClientCoordinateOverride.Store(&handler) } - if entitlements.workspaceQuota != api.entitlements.workspaceQuota { - if entitlements.workspaceQuota != codersdk.EntitlementNotEntitled && api.UserWorkspaceQuota > 0 { - enforcer := NewEnforcer(api.Options.UserWorkspaceQuota) - api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer) + if changed, enabled := featureChanged(codersdk.FeatureWorkspaceQuota); changed { + enforcer := workspacequota.NewNop() + if enabled { + enforcer = NewEnforcer(api.Options.UserWorkspaceQuota) } + api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer) } api.entitlements = entitlements @@ -235,82 +180,7 @@ func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) { api.entitlementsMu.RLock() entitlements := api.entitlements api.entitlementsMu.RUnlock() - - resp := codersdk.Entitlements{ - Features: make(map[string]codersdk.Feature), - Warnings: make([]string, 0), - HasLicense: entitlements.hasLicense, - Experimental: api.Experimental, - } - - if entitlements.activeUsers.Limit != nil { - activeUserCount, err := api.Database.GetActiveUserCount(ctx) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Unable to query database", - Detail: err.Error(), - }) - return - } - entitlements.activeUsers.Actual = &activeUserCount - if activeUserCount > *entitlements.activeUsers.Limit { - resp.Warnings = append(resp.Warnings, - fmt.Sprintf( - "Your deployment has %d active users but is only licensed for %d.", - activeUserCount, *entitlements.activeUsers.Limit)) - } - } - resp.Features[codersdk.FeatureUserLimit] = entitlements.activeUsers - - // Audit logs - resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ - Entitlement: entitlements.auditLogs, - Enabled: api.AuditLogging, - } - // Audit logging is enabled by default. We don't want to display - // a warning if they don't have a license. - if entitlements.hasLicense && api.AuditLogging { - if entitlements.auditLogs == codersdk.EntitlementNotEntitled { - resp.Warnings = append(resp.Warnings, - "Audit logging is enabled but your license is not entitled to this feature.") - } - if entitlements.auditLogs == codersdk.EntitlementGracePeriod { - resp.Warnings = append(resp.Warnings, - "Audit logging is enabled but your license for this feature is expired.") - } - } - - resp.Features[codersdk.FeatureBrowserOnly] = codersdk.Feature{ - Entitlement: entitlements.browserOnly, - Enabled: api.BrowserOnly, - } - if api.BrowserOnly { - if entitlements.browserOnly == codersdk.EntitlementNotEntitled { - resp.Warnings = append(resp.Warnings, - "Browser only connections are enabled but your license is not entitled to this feature.") - } - if entitlements.browserOnly == codersdk.EntitlementGracePeriod { - resp.Warnings = append(resp.Warnings, - "Browser only connections are enabled but your license for this feature is expired.") - } - } - - resp.Features[codersdk.FeatureWorkspaceQuota] = codersdk.Feature{ - Entitlement: entitlements.workspaceQuota, - Enabled: api.UserWorkspaceQuota > 0, - } - if api.UserWorkspaceQuota > 0 { - if entitlements.workspaceQuota == codersdk.EntitlementNotEntitled { - resp.Warnings = append(resp.Warnings, - "Workspace quotas are enabled but your license is not entitled to this feature.") - } - if entitlements.workspaceQuota == codersdk.EntitlementGracePeriod { - resp.Warnings = append(resp.Warnings, - "Workspace quotas are enabled but your license for this feature is expired.") - } - } - - httpapi.Write(ctx, rw, http.StatusOK, resp) + httpapi.Write(ctx, rw, http.StatusOK, entitlements) } func (api *API) runEntitlementsLoop(ctx context.Context) { @@ -374,10 +244,3 @@ func (api *API) runEntitlementsLoop(ctx context.Context) { } } } - -func max(a, b int64) int64 { - if a > b { - return a - } - return b -} diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 1dfbb247f4..da397bc39d 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -86,48 +86,6 @@ func TestEntitlements(t *testing.T) { assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement) assert.True(t, al.Enabled) }) - t.Run("Warnings", func(t *testing.T) { - t.Parallel() - client := coderdenttest.New(t, &coderdenttest.Options{ - AuditLogging: true, - BrowserOnly: true, - }) - first := coderdtest.CreateFirstUser(t, client) - for i := 0; i < 4; i++ { - coderdtest.CreateAnotherUser(t, client, first.OrganizationID) - } - coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - UserLimit: 4, - AuditLog: true, - BrowserOnly: true, - GraceAt: time.Now().Add(-time.Second), - }) - res, err := client.Entitlements(context.Background()) - require.NoError(t, err) - assert.True(t, res.HasLicense) - ul := res.Features[codersdk.FeatureUserLimit] - assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement) - assert.Equal(t, int64(4), *ul.Limit) - assert.Equal(t, int64(5), *ul.Actual) - assert.True(t, ul.Enabled) - al := res.Features[codersdk.FeatureAuditLog] - assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement) - assert.True(t, al.Enabled) - assert.Nil(t, al.Limit) - assert.Nil(t, al.Actual) - bo := res.Features[codersdk.FeatureBrowserOnly] - assert.Equal(t, codersdk.EntitlementGracePeriod, bo.Entitlement) - assert.True(t, bo.Enabled) - assert.Nil(t, bo.Limit) - assert.Nil(t, bo.Actual) - assert.Len(t, res.Warnings, 3) - assert.Contains(t, res.Warnings, - "Your deployment has 5 active users but is only licensed for 4.") - assert.Contains(t, res.Warnings, - "Audit logging is enabled but your license for this feature is expired.") - assert.Contains(t, res.Warnings, - "Browser only connections are enabled but your license for this feature is expired.") - }) t.Run("Pubsub", func(t *testing.T) { t.Parallel() client, _, api := coderdenttest.NewWithAPI(t, nil) diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index acf3a206ef..90d09fd5c9 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -15,6 +15,7 @@ import ( "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" "github.com/coder/coder/enterprise/coderd" + "github.com/coder/coder/enterprise/coderd/license" ) const ( @@ -24,6 +25,8 @@ const ( var ( testPrivateKey ed25519.PrivateKey testPublicKey ed25519.PublicKey + + Keys = map[string]ed25519.PublicKey{} ) func init() { @@ -32,6 +35,7 @@ func init() { if err != nil { panic(err) } + Keys[testKeyID] = testPublicKey } type Options struct { @@ -64,9 +68,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c UserWorkspaceQuota: options.UserWorkspaceQuota, Options: oop, EntitlementsUpdateInterval: options.EntitlementsUpdateInterval, - Keys: map[string]ed25519.PublicKey{ - testKeyID: testPublicKey, - }, + Keys: Keys, }) assert.NoError(t, err) srv.Config.Handler = coderAPI.AGPL.RootHandler @@ -85,6 +87,8 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c type LicenseOptions struct { AccountType string AccountID string + Trial bool + AllFeatures bool GraceAt time.Time ExpiresAt time.Time UserLimit int64 @@ -96,11 +100,11 @@ type LicenseOptions struct { // AddLicense generates a new license with the options provided and inserts it. func AddLicense(t *testing.T, client *codersdk.Client, options LicenseOptions) codersdk.License { - license, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ + l, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{ License: GenerateLicense(t, options), }) require.NoError(t, err) - return license + return l } // GenerateLicense returns a signed JWT using the test key. @@ -128,7 +132,7 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string { workspaceQuota = 1 } - c := &coderd.Claims{ + c := &license.Claims{ RegisteredClaims: jwt.RegisteredClaims{ Issuer: "test@testing.test", ExpiresAt: jwt.NewNumericDate(options.ExpiresAt), @@ -138,8 +142,10 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string { LicenseExpires: jwt.NewNumericDate(options.GraceAt), AccountType: options.AccountType, AccountID: options.AccountID, - Version: coderd.CurrentVersion, - Features: coderd.Features{ + Trial: options.Trial, + Version: license.CurrentVersion, + AllFeatures: options.AllFeatures, + Features: license.Features{ UserLimit: options.UserLimit, AuditLog: auditLog, BrowserOnly: browserOnly, @@ -148,7 +154,7 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string { }, } tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c) - tok.Header[coderd.HeaderKeyID] = testKeyID + tok.Header[license.HeaderKeyID] = testKeyID signedTok, err := tok.SignedString(testPrivateKey) require.NoError(t, err) return signedTok diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go new file mode 100644 index 0000000000..213d5dafb8 --- /dev/null +++ b/enterprise/coderd/license/license.go @@ -0,0 +1,243 @@ +package license + +import ( + "context" + "crypto/ed25519" + "fmt" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/codersdk" +) + +// Entitlements processes licenses to return whether features are enabled or not. +func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) { + now := time.Now() + // Default all entitlements to be disabled. + entitlements := codersdk.Entitlements{ + Features: map[string]codersdk.Feature{}, + Warnings: []string{}, + } + for _, featureName := range codersdk.FeatureNames { + entitlements.Features[featureName] = codersdk.Feature{ + Entitlement: codersdk.EntitlementNotEntitled, + Enabled: enablements[featureName], + } + } + + licenses, err := db.GetUnexpiredLicenses(ctx) + if err != nil { + return entitlements, err + } + + activeUserCount, err := db.GetActiveUserCount(ctx) + if err != nil { + return entitlements, xerrors.Errorf("query active user count: %w", err) + } + + allFeatures := false + + // Here we loop through licenses to detect enabled features. + for _, l := range licenses { + claims, err := validateDBLicense(l, keys) + if err != nil { + logger.Debug(ctx, "skipping invalid license", + slog.F("id", l.ID), slog.Error(err)) + continue + } + entitlements.HasLicense = true + entitlement := codersdk.EntitlementEntitled + entitlements.Trial = claims.Trial + if now.After(claims.LicenseExpires.Time) { + // if the grace period were over, the validation fails, so if we are after + // LicenseExpires we must be in grace period. + entitlement = codersdk.EntitlementGracePeriod + } + if claims.Features.UserLimit > 0 { + entitlements.Features[codersdk.FeatureUserLimit] = codersdk.Feature{ + Enabled: true, + Entitlement: entitlement, + Limit: &claims.Features.UserLimit, + Actual: &activeUserCount, + } + if activeUserCount > claims.Features.UserLimit { + entitlements.Warnings = append(entitlements.Warnings, fmt.Sprintf( + "Your deployment has %d active users but is only licensed for %d.", + activeUserCount, claims.Features.UserLimit)) + } + } + if claims.Features.AuditLog > 0 { + entitlements.Features[codersdk.FeatureAuditLog] = codersdk.Feature{ + Entitlement: entitlement, + Enabled: enablements[codersdk.FeatureAuditLog], + } + } + if claims.Features.BrowserOnly > 0 { + entitlements.Features[codersdk.FeatureBrowserOnly] = codersdk.Feature{ + Entitlement: entitlement, + Enabled: enablements[codersdk.FeatureBrowserOnly], + } + } + if claims.Features.SCIM > 0 { + entitlements.Features[codersdk.FeatureSCIM] = codersdk.Feature{ + Entitlement: entitlement, + Enabled: enablements[codersdk.FeatureSCIM], + } + } + if claims.Features.WorkspaceQuota > 0 { + entitlements.Features[codersdk.FeatureWorkspaceQuota] = codersdk.Feature{ + Entitlement: entitlement, + Enabled: enablements[codersdk.FeatureWorkspaceQuota], + } + } + if claims.AllFeatures { + allFeatures = true + } + } + + if allFeatures { + for _, featureName := range codersdk.FeatureNames { + // No user limit! + if featureName == codersdk.FeatureUserLimit { + continue + } + feature := entitlements.Features[featureName] + feature.Entitlement = codersdk.EntitlementEntitled + entitlements.Features[featureName] = feature + } + } + + if entitlements.HasLicense { + for _, featureName := range codersdk.FeatureNames { + // The user limit has it's own warnings! + if featureName == codersdk.FeatureUserLimit { + continue + } + feature := entitlements.Features[featureName] + if !feature.Enabled { + continue + } + niceName := strings.Title(strings.ReplaceAll(featureName, "_", " ")) + switch feature.Entitlement { + case codersdk.EntitlementNotEntitled: + entitlements.Warnings = append(entitlements.Warnings, + fmt.Sprintf("%s is enabled but your license is not entitled to this feature.", niceName)) + // Disable the feature and add a warning... + feature.Enabled = false + entitlements.Features[featureName] = feature + case codersdk.EntitlementGracePeriod: + entitlements.Warnings = append(entitlements.Warnings, + fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName)) + default: + } + } + } + + return entitlements, nil +} + +const ( + CurrentVersion = 3 + HeaderKeyID = "kid" + AccountTypeSalesforce = "salesforce" + VersionClaim = "version" +) + +var ( + ValidMethods = []string{"EdDSA"} + + ErrInvalidVersion = xerrors.New("license must be version 3") + ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) + ErrMissingLicenseExpires = xerrors.New("license missing license_expires") +) + +type Features struct { + UserLimit int64 `json:"user_limit"` + AuditLog int64 `json:"audit_log"` + BrowserOnly int64 `json:"browser_only"` + SCIM int64 `json:"scim"` + WorkspaceQuota int64 `json:"workspace_quota"` +} + +type Claims struct { + jwt.RegisteredClaims + // LicenseExpires is the end of the legit license term, and the start of the grace period, if + // there is one. The standard JWT claim "exp" (ExpiresAt in jwt.RegisteredClaims, above) is + // the end of the grace period (identical to LicenseExpires if there is no grace period). + // The reason we use the standard claim for the end of the grace period is that we want JWT + // processing libraries to consider the token "valid" until then. + LicenseExpires *jwt.NumericDate `json:"license_expires,omitempty"` + AccountType string `json:"account_type,omitempty"` + AccountID string `json:"account_id,omitempty"` + Trial bool `json:"trial"` + AllFeatures bool `json:"all_features"` + Version uint64 `json:"version"` + Features Features `json:"features"` +} + +// Parse consumes a license and returns the claims. +func Parse(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) { + tok, err := jwt.Parse( + l, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) + if err != nil { + return nil, err + } + if claims, ok := tok.Claims.(jwt.MapClaims); ok && tok.Valid { + version, ok := claims[VersionClaim].(float64) + if !ok { + return nil, ErrInvalidVersion + } + if int64(version) != CurrentVersion { + return nil, ErrInvalidVersion + } + return claims, nil + } + return nil, xerrors.New("unable to parse Claims") +} + +// validateDBLicense validates a database.License record, and if valid, returns the claims. If +// unparsable or invalid, it returns an error +func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { + tok, err := jwt.ParseWithClaims( + l.JWT, + &Claims{}, + keyFunc(keys), + jwt.WithValidMethods(ValidMethods), + ) + if err != nil { + return nil, err + } + if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { + if claims.Version != uint64(CurrentVersion) { + return nil, ErrInvalidVersion + } + if claims.LicenseExpires == nil { + return nil, ErrMissingLicenseExpires + } + return claims, nil + } + return nil, xerrors.New("unable to parse Claims") +} + +func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { + return func(j *jwt.Token) (interface{}, error) { + keyID, ok := j.Header[HeaderKeyID].(string) + if !ok { + return nil, ErrMissingKeyID + } + k, ok := keys[keyID] + if !ok { + return nil, xerrors.Errorf("no key with ID %s", keyID) + } + return k, nil + } +} diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go new file mode 100644 index 0000000000..04141718e6 --- /dev/null +++ b/enterprise/coderd/license/license_test.go @@ -0,0 +1,189 @@ +package license_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/enterprise/coderd/license" +) + +func TestEntitlements(t *testing.T) { + t.Parallel() + all := map[string]bool{ + codersdk.FeatureAuditLog: true, + codersdk.FeatureBrowserOnly: true, + codersdk.FeatureSCIM: true, + codersdk.FeatureWorkspaceQuota: true, + } + + t.Run("Defaults", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + require.NoError(t, err) + require.False(t, entitlements.HasLicense) + require.False(t, entitlements.Trial) + for _, featureName := range codersdk.FeatureNames { + require.False(t, entitlements.Features[featureName].Enabled) + require.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[featureName].Entitlement) + } + }) + t.Run("SingleLicenseNothing", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}), + Exp: time.Now().Add(time.Hour), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.False(t, entitlements.Trial) + for _, featureName := range codersdk.FeatureNames { + require.False(t, entitlements.Features[featureName].Enabled) + require.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[featureName].Entitlement) + } + }) + t.Run("SingleLicenseAll", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + UserLimit: 100, + AuditLog: true, + BrowserOnly: true, + SCIM: true, + WorkspaceQuota: true, + }), + Exp: time.Now().Add(time.Hour), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.False(t, entitlements.Trial) + for _, featureName := range codersdk.FeatureNames { + require.Equal(t, codersdk.EntitlementEntitled, entitlements.Features[featureName].Entitlement) + } + }) + t.Run("SingleLicenseGrace", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + UserLimit: 100, + AuditLog: true, + BrowserOnly: true, + SCIM: true, + WorkspaceQuota: true, + GraceAt: time.Now().Add(-time.Hour), + ExpiresAt: time.Now().Add(time.Hour), + }), + Exp: time.Now().Add(time.Hour), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.False(t, entitlements.Trial) + for _, featureName := range codersdk.FeatureNames { + if featureName == codersdk.FeatureUserLimit { + continue + } + niceName := strings.Title(strings.ReplaceAll(featureName, "_", " ")) + require.Equal(t, codersdk.EntitlementGracePeriod, entitlements.Features[featureName].Entitlement) + require.Contains(t, entitlements.Warnings, fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName)) + } + }) + t.Run("SingleLicenseNotEntitled", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}), + Exp: time.Now().Add(time.Hour), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.False(t, entitlements.Trial) + for _, featureName := range codersdk.FeatureNames { + if featureName == codersdk.FeatureUserLimit { + continue + } + niceName := strings.Title(strings.ReplaceAll(featureName, "_", " ")) + // Ensures features that are not entitled are properly disabled. + require.False(t, entitlements.Features[featureName].Enabled) + require.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[featureName].Entitlement) + require.Contains(t, entitlements.Warnings, fmt.Sprintf("%s is enabled but your license is not entitled to this feature.", niceName)) + } + }) + t.Run("TooManyUsers", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertUser(context.Background(), database.InsertUserParams{}) + db.InsertUser(context.Background(), database.InsertUserParams{}) + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + UserLimit: 1, + }), + Exp: time.Now().Add(time.Hour), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.Contains(t, entitlements.Warnings, "Your deployment has 2 active users but is only licensed for 1.") + }) + t.Run("MultipleLicenseEnabled", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + // One trial + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + Exp: time.Now().Add(time.Hour), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + Trial: true, + }), + }) + // One not + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + Exp: time.Now().Add(time.Hour), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + Trial: false, + }), + }) + + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.False(t, entitlements.Trial) + }) + + t.Run("AllFeatures", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + Exp: time.Now().Add(time.Hour), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AllFeatures: true, + }), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.False(t, entitlements.Trial) + for _, featureName := range codersdk.FeatureNames { + if featureName == codersdk.FeatureUserLimit { + continue + } + require.True(t, entitlements.Features[featureName].Enabled) + require.Equal(t, codersdk.EntitlementEntitled, entitlements.Features[featureName].Entitlement) + } + }) +} diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 9d43bbe6c2..f56df142cd 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -23,19 +23,13 @@ import ( "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd/license" ) const ( - CurrentVersion = 3 - HeaderKeyID = "kid" - AccountTypeSalesforce = "salesforce" - VersionClaim = "version" - PubsubEventLicenses = "licenses" ) -var ValidMethods = []string{"EdDSA"} - // key20220812 is the Coder license public key with id 2022-08-12 used to validate licenses signed // by our signing infrastructure // @@ -44,34 +38,6 @@ var key20220812 []byte var Keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)} -type Features struct { - UserLimit int64 `json:"user_limit"` - AuditLog int64 `json:"audit_log"` - BrowserOnly int64 `json:"browser_only"` - SCIM int64 `json:"scim"` - WorkspaceQuota int64 `json:"workspace_quota"` -} - -type Claims struct { - jwt.RegisteredClaims - // LicenseExpires is the end of the legit license term, and the start of the grace period, if - // there is one. The standard JWT claim "exp" (ExpiresAt in jwt.RegisteredClaims, above) is - // the end of the grace period (identical to LicenseExpires if there is no grace period). - // The reason we use the standard claim for the end of the grace period is that we want JWT - // processing libraries to consider the token "valid" until then. - LicenseExpires *jwt.NumericDate `json:"license_expires,omitempty"` - AccountType string `json:"account_type,omitempty"` - AccountID string `json:"account_id,omitempty"` - Version uint64 `json:"version"` - Features Features `json:"features"` -} - -var ( - ErrInvalidVersion = xerrors.New("license must be version 3") - ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) - ErrMissingLicenseExpires = xerrors.New("license missing license_expires") -) - // postLicense adds a new Enterprise license to the cluster. We allow multiple different licenses // in the cluster at one time for several reasons: // @@ -93,7 +59,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) { return } - claims, err := parseLicense(addLicense.License, api.Keys) + claims, err := license.Parse(addLicense.License, api.Keys) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid license", @@ -261,65 +227,3 @@ func decodeClaims(l database.License) (jwt.MapClaims, error) { err = d.Decode(&c) return c, err } - -// parseLicense parses the license and returns the claims. If the license's signature is invalid or -// is not parsable, an error is returned. -func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) { - tok, err := jwt.Parse( - l, - keyFunc(keys), - jwt.WithValidMethods(ValidMethods), - ) - if err != nil { - return nil, err - } - if claims, ok := tok.Claims.(jwt.MapClaims); ok && tok.Valid { - version, ok := claims[VersionClaim].(float64) - if !ok { - return nil, ErrInvalidVersion - } - if int64(version) != CurrentVersion { - return nil, ErrInvalidVersion - } - return claims, nil - } - return nil, xerrors.New("unable to parse Claims") -} - -// validateDBLicense validates a database.License record, and if valid, returns the claims. If -// unparsable or invalid, it returns an error -func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) { - tok, err := jwt.ParseWithClaims( - l.JWT, - &Claims{}, - keyFunc(keys), - jwt.WithValidMethods(ValidMethods), - ) - if err != nil { - return nil, err - } - if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { - if claims.Version != uint64(CurrentVersion) { - return nil, ErrInvalidVersion - } - if claims.LicenseExpires == nil { - return nil, ErrMissingLicenseExpires - } - return claims, nil - } - return nil, xerrors.New("unable to parse Claims") -} - -func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { - return func(j *jwt.Token) (interface{}, error) { - keyID, ok := j.Header[HeaderKeyID].(string) - if !ok { - return nil, ErrMissingKeyID - } - k, ok := keys[keyID] - if !ok { - return nil, xerrors.Errorf("no key with ID %s", keyID) - } - return k, nil - } -} diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go index c4b7111597..59d36cc915 100644 --- a/enterprise/coderd/licenses_test.go +++ b/enterprise/coderd/licenses_test.go @@ -12,8 +12,8 @@ import ( "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" - "github.com/coder/coder/enterprise/coderd" "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/enterprise/coderd/license" "github.com/coder/coder/testutil" ) @@ -25,7 +25,7 @@ func TestPostLicense(t *testing.T) { client := coderdenttest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - AccountType: coderd.AccountTypeSalesforce, + AccountType: license.AccountTypeSalesforce, AccountID: "testing", AuditLog: true, }) @@ -89,6 +89,7 @@ func TestGetLicense(t *testing.T) { AuditLog: true, SCIM: true, BrowserOnly: true, + Trial: true, UserLimit: 200, }) @@ -106,6 +107,7 @@ func TestGetLicense(t *testing.T) { }, licenses[0].Claims["features"]) assert.Equal(t, int32(2), licenses[1].ID) assert.Equal(t, "testing2", licenses[1].Claims["account_id"]) + assert.Equal(t, true, licenses[1].Claims["trial"]) assert.Equal(t, map[string]interface{}{ codersdk.FeatureUserLimit: json.Number("200"), codersdk.FeatureAuditLog: json.Number("1"), diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index 1d01a5601d..7ee4a41a79 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -21,10 +21,10 @@ import ( func (api *API) scimEnabledMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { api.entitlementsMu.RLock() - scim := api.entitlements.scim + scim := api.entitlements.Features[codersdk.FeatureSCIM].Enabled api.entitlementsMu.RUnlock() - if scim == codersdk.EntitlementNotEntitled { + if !scim { httpapi.RouteNotFound(rw) return } diff --git a/enterprise/coderd/workspaceagents.go b/enterprise/coderd/workspaceagents.go index def912a90f..d5595e6570 100644 --- a/enterprise/coderd/workspaceagents.go +++ b/enterprise/coderd/workspaceagents.go @@ -10,9 +10,9 @@ import ( func (api *API) shouldBlockNonBrowserConnections(rw http.ResponseWriter) bool { api.entitlementsMu.Lock() - browserOnly := api.entitlements.browserOnly + browserOnly := api.entitlements.Features[codersdk.FeatureBrowserOnly].Enabled api.entitlementsMu.Unlock() - if api.BrowserOnly && browserOnly != codersdk.EntitlementNotEntitled { + if browserOnly { httpapi.Write(context.Background(), rw, http.StatusConflict, codersdk.Response{ Message: "Non-browser connections are disabled for your deployment.", }) diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 2ee6fac685..ed03fb145a 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -30,6 +30,7 @@ export const defaultEntitlements = (): TypesGen.Entitlements => { has_license: false, warnings: [], experimental: false, + trial: false, } } diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 629f1963e1..b14154e096 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -245,6 +245,7 @@ export interface Entitlements { readonly warnings: string[] readonly has_license: boolean readonly experimental: boolean + readonly trial: boolean } // From codersdk/features.go diff --git a/site/src/testHelpers/entities.ts b/site/src/testHelpers/entities.ts index a1d56f1d9a..36302613f3 100644 --- a/site/src/testHelpers/entities.ts +++ b/site/src/testHelpers/entities.ts @@ -777,12 +777,14 @@ export const MockEntitlements: TypesGen.Entitlements = { has_license: false, features: {}, experimental: false, + trial: false, } export const MockEntitlementsWithWarnings: TypesGen.Entitlements = { warnings: ["You are over your active user limit.", "And another thing."], has_license: true, experimental: false, + trial: false, features: { user_limit: { enabled: true, @@ -805,6 +807,7 @@ export const MockEntitlementsWithAuditLog: TypesGen.Entitlements = { warnings: [], has_license: true, experimental: false, + trial: false, features: { audit_log: { enabled: true, diff --git a/site/src/xServices/entitlements/entitlementsXService.ts b/site/src/xServices/entitlements/entitlementsXService.ts index 1e93c7641a..83ed44d120 100644 --- a/site/src/xServices/entitlements/entitlementsXService.ts +++ b/site/src/xServices/entitlements/entitlementsXService.ts @@ -24,6 +24,7 @@ const emptyEntitlements = { features: {}, has_license: false, experimental: false, + trial: false, } export const entitlementsMachine = createMachine(