feat: Add trial property to licenses (#4372)

* feat: Add trial property to licenses

This allows the frontend to display whether the user is on
a trial license of Coder. This is useful for advertising
Enterprise functionality.

* Improve tests for license enablement code

* Add all features property
This commit is contained in:
Kyle Carberry 2022-10-06 19:28:22 -05:00 committed by GitHub
parent 05670d133e
commit 915bb41ea2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 507 additions and 332 deletions

View File

@ -9,8 +9,8 @@
"cliflag",
"cliui",
"codecov",
"Codespaces",
"coderd",
"coderdenttest",
"coderdtest",
"codersdk",
"cronstrue",
@ -24,6 +24,7 @@
"drpcmux",
"drpcserver",
"Dsts",
"enablements",
"fatih",
"Formik",
"gitsshkey",

View File

@ -42,6 +42,7 @@ type Entitlements struct {
Warnings []string `json:"warnings"`
HasLicense bool `json:"has_license"`
Experimental bool `json:"experimental"`
Trial bool `json:"trial"`
}
func (c *Client) Entitlements(ctx context.Context) (Entitlements, error) {

View File

@ -57,7 +57,7 @@ func TestFeaturesList(t *testing.T) {
var entitlements codersdk.Entitlements
err := json.Unmarshal(buf.Bytes(), &entitlements)
require.NoError(t, err, "unmarshal JSON output")
assert.Len(t, entitlements.Features, 4)
assert.Len(t, entitlements.Features, 5)
assert.Empty(t, entitlements.Warnings)
assert.Equal(t, codersdk.EntitlementNotEntitled,
entitlements.Features[codersdk.FeatureUserLimit].Entitlement)
@ -67,6 +67,8 @@ func TestFeaturesList(t *testing.T) {
entitlements.Features[codersdk.FeatureBrowserOnly].Entitlement)
assert.Equal(t, codersdk.EntitlementNotEntitled,
entitlements.Features[codersdk.FeatureWorkspaceQuota].Entitlement)
assert.Equal(t, codersdk.EntitlementNotEntitled,
entitlements.Features[codersdk.FeatureSCIM].Entitlement)
assert.False(t, entitlements.HasLicense)
assert.False(t, entitlements.Experimental)
})

View File

@ -3,7 +3,6 @@ package coderd
import (
"context"
"crypto/ed25519"
"fmt"
"net/http"
"sync"
"time"
@ -15,11 +14,14 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd"
agplaudit "github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/workspacequota"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/audit"
"github.com/coder/coder/enterprise/audit/backends"
"github.com/coder/coder/enterprise/coderd/license"
)
// New constructs an Enterprise coderd API instance.
@ -34,19 +36,8 @@ func New(ctx context.Context, options *Options) (*API, error) {
}
ctx, cancelFunc := context.WithCancel(ctx)
api := &API{
AGPL: coderd.New(options.Options),
Options: options,
entitlements: entitlements{
activeUsers: codersdk.Feature{
Entitlement: codersdk.EntitlementNotEntitled,
Enabled: false,
},
auditLogs: codersdk.EntitlementNotEntitled,
browserOnly: codersdk.EntitlementNotEntitled,
scim: codersdk.EntitlementNotEntitled,
workspaceQuota: codersdk.EntitlementNotEntitled,
},
AGPL: coderd.New(options.Options),
Options: options,
cancelEntitlementsLoop: cancelFunc,
}
oauthConfigs := &httpmw.OAuth2Configs{
@ -117,16 +108,7 @@ type API struct {
cancelEntitlementsLoop func()
entitlementsMu sync.RWMutex
entitlements entitlements
}
type entitlements struct {
hasLicense bool
activeUsers codersdk.Feature
auditLogs codersdk.Entitlement
browserOnly codersdk.Entitlement
scim codersdk.Entitlement
workspaceQuota codersdk.Entitlement
entitlements codersdk.Entitlements
}
func (api *API) Close() error {
@ -135,94 +117,57 @@ func (api *API) Close() error {
}
func (api *API) updateEntitlements(ctx context.Context) error {
licenses, err := api.Database.GetUnexpiredLicenses(ctx)
api.entitlementsMu.Lock()
defer api.entitlementsMu.Unlock()
entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, api.Keys, map[string]bool{
codersdk.FeatureAuditLog: api.AuditLogging,
codersdk.FeatureBrowserOnly: api.BrowserOnly,
codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0,
codersdk.FeatureWorkspaceQuota: api.UserWorkspaceQuota != 0,
})
if err != nil {
return err
}
api.entitlementsMu.Lock()
defer api.entitlementsMu.Unlock()
now := time.Now()
// Default all entitlements to be disabled.
entitlements := entitlements{
hasLicense: false,
activeUsers: codersdk.Feature{
Enabled: false,
Entitlement: codersdk.EntitlementNotEntitled,
},
auditLogs: codersdk.EntitlementNotEntitled,
scim: codersdk.EntitlementNotEntitled,
browserOnly: codersdk.EntitlementNotEntitled,
workspaceQuota: codersdk.EntitlementNotEntitled,
featureChanged := func(featureName string) (changed bool, enabled bool) {
if api.entitlements.Features == nil {
return true, entitlements.Features[featureName].Enabled
}
oldFeature := api.entitlements.Features[featureName]
newFeature := entitlements.Features[featureName]
if oldFeature.Enabled != newFeature.Enabled {
return true, newFeature.Enabled
}
return false, newFeature.Enabled
}
// Here we loop through licenses to detect enabled features.
for _, l := range licenses {
claims, err := validateDBLicense(l, api.Keys)
if err != nil {
api.Logger.Debug(ctx, "skipping invalid license",
slog.F("id", l.ID), slog.Error(err))
continue
}
entitlements.hasLicense = true
entitlement := codersdk.EntitlementEntitled
if now.After(claims.LicenseExpires.Time) {
// if the grace period were over, the validation fails, so if we are after
// LicenseExpires we must be in grace period.
entitlement = codersdk.EntitlementGracePeriod
}
if claims.Features.UserLimit > 0 {
entitlements.activeUsers = codersdk.Feature{
Enabled: true,
Entitlement: entitlement,
}
currentLimit := int64(0)
if entitlements.activeUsers.Limit != nil {
currentLimit = *entitlements.activeUsers.Limit
}
limit := max(currentLimit, claims.Features.UserLimit)
entitlements.activeUsers.Limit = &limit
}
if claims.Features.AuditLog > 0 {
entitlements.auditLogs = entitlement
}
if claims.Features.BrowserOnly > 0 {
entitlements.browserOnly = entitlement
}
if claims.Features.SCIM > 0 {
entitlements.scim = entitlement
}
if claims.Features.WorkspaceQuota > 0 {
entitlements.workspaceQuota = entitlement
}
}
if entitlements.auditLogs != api.entitlements.auditLogs {
// A flag could be added to the options that would allow disabling
// enhanced audit logging here!
if entitlements.auditLogs != codersdk.EntitlementNotEntitled && api.AuditLogging {
auditor := audit.NewAuditor(
if changed, enabled := featureChanged(codersdk.FeatureAuditLog); changed {
auditor := agplaudit.NewNop()
if enabled {
auditor = audit.NewAuditor(
audit.DefaultFilter,
backends.NewPostgres(api.Database, true),
backends.NewSlog(api.Logger),
)
api.AGPL.Auditor.Store(&auditor)
}
api.AGPL.Auditor.Store(&auditor)
}
if entitlements.browserOnly != api.entitlements.browserOnly {
if changed, enabled := featureChanged(codersdk.FeatureBrowserOnly); changed {
var handler func(rw http.ResponseWriter) bool
if entitlements.browserOnly != codersdk.EntitlementNotEntitled && api.BrowserOnly {
if enabled {
handler = api.shouldBlockNonBrowserConnections
}
api.AGPL.WorkspaceClientCoordinateOverride.Store(&handler)
}
if entitlements.workspaceQuota != api.entitlements.workspaceQuota {
if entitlements.workspaceQuota != codersdk.EntitlementNotEntitled && api.UserWorkspaceQuota > 0 {
enforcer := NewEnforcer(api.Options.UserWorkspaceQuota)
api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer)
if changed, enabled := featureChanged(codersdk.FeatureWorkspaceQuota); changed {
enforcer := workspacequota.NewNop()
if enabled {
enforcer = NewEnforcer(api.Options.UserWorkspaceQuota)
}
api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer)
}
api.entitlements = entitlements
@ -235,82 +180,7 @@ func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
api.entitlementsMu.RLock()
entitlements := api.entitlements
api.entitlementsMu.RUnlock()
resp := codersdk.Entitlements{
Features: make(map[string]codersdk.Feature),
Warnings: make([]string, 0),
HasLicense: entitlements.hasLicense,
Experimental: api.Experimental,
}
if entitlements.activeUsers.Limit != nil {
activeUserCount, err := api.Database.GetActiveUserCount(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Unable to query database",
Detail: err.Error(),
})
return
}
entitlements.activeUsers.Actual = &activeUserCount
if activeUserCount > *entitlements.activeUsers.Limit {
resp.Warnings = append(resp.Warnings,
fmt.Sprintf(
"Your deployment has %d active users but is only licensed for %d.",
activeUserCount, *entitlements.activeUsers.Limit))
}
}
resp.Features[codersdk.FeatureUserLimit] = entitlements.activeUsers
// Audit logs
resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{
Entitlement: entitlements.auditLogs,
Enabled: api.AuditLogging,
}
// Audit logging is enabled by default. We don't want to display
// a warning if they don't have a license.
if entitlements.hasLicense && api.AuditLogging {
if entitlements.auditLogs == codersdk.EntitlementNotEntitled {
resp.Warnings = append(resp.Warnings,
"Audit logging is enabled but your license is not entitled to this feature.")
}
if entitlements.auditLogs == codersdk.EntitlementGracePeriod {
resp.Warnings = append(resp.Warnings,
"Audit logging is enabled but your license for this feature is expired.")
}
}
resp.Features[codersdk.FeatureBrowserOnly] = codersdk.Feature{
Entitlement: entitlements.browserOnly,
Enabled: api.BrowserOnly,
}
if api.BrowserOnly {
if entitlements.browserOnly == codersdk.EntitlementNotEntitled {
resp.Warnings = append(resp.Warnings,
"Browser only connections are enabled but your license is not entitled to this feature.")
}
if entitlements.browserOnly == codersdk.EntitlementGracePeriod {
resp.Warnings = append(resp.Warnings,
"Browser only connections are enabled but your license for this feature is expired.")
}
}
resp.Features[codersdk.FeatureWorkspaceQuota] = codersdk.Feature{
Entitlement: entitlements.workspaceQuota,
Enabled: api.UserWorkspaceQuota > 0,
}
if api.UserWorkspaceQuota > 0 {
if entitlements.workspaceQuota == codersdk.EntitlementNotEntitled {
resp.Warnings = append(resp.Warnings,
"Workspace quotas are enabled but your license is not entitled to this feature.")
}
if entitlements.workspaceQuota == codersdk.EntitlementGracePeriod {
resp.Warnings = append(resp.Warnings,
"Workspace quotas are enabled but your license for this feature is expired.")
}
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
httpapi.Write(ctx, rw, http.StatusOK, entitlements)
}
func (api *API) runEntitlementsLoop(ctx context.Context) {
@ -374,10 +244,3 @@ func (api *API) runEntitlementsLoop(ctx context.Context) {
}
}
}
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@ -86,48 +86,6 @@ func TestEntitlements(t *testing.T) {
assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement)
assert.True(t, al.Enabled)
})
t.Run("Warnings", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, &coderdenttest.Options{
AuditLogging: true,
BrowserOnly: true,
})
first := coderdtest.CreateFirstUser(t, client)
for i := 0; i < 4; i++ {
coderdtest.CreateAnotherUser(t, client, first.OrganizationID)
}
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
UserLimit: 4,
AuditLog: true,
BrowserOnly: true,
GraceAt: time.Now().Add(-time.Second),
})
res, err := client.Entitlements(context.Background())
require.NoError(t, err)
assert.True(t, res.HasLicense)
ul := res.Features[codersdk.FeatureUserLimit]
assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement)
assert.Equal(t, int64(4), *ul.Limit)
assert.Equal(t, int64(5), *ul.Actual)
assert.True(t, ul.Enabled)
al := res.Features[codersdk.FeatureAuditLog]
assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement)
assert.True(t, al.Enabled)
assert.Nil(t, al.Limit)
assert.Nil(t, al.Actual)
bo := res.Features[codersdk.FeatureBrowserOnly]
assert.Equal(t, codersdk.EntitlementGracePeriod, bo.Entitlement)
assert.True(t, bo.Enabled)
assert.Nil(t, bo.Limit)
assert.Nil(t, bo.Actual)
assert.Len(t, res.Warnings, 3)
assert.Contains(t, res.Warnings,
"Your deployment has 5 active users but is only licensed for 4.")
assert.Contains(t, res.Warnings,
"Audit logging is enabled but your license for this feature is expired.")
assert.Contains(t, res.Warnings,
"Browser only connections are enabled but your license for this feature is expired.")
})
t.Run("Pubsub", func(t *testing.T) {
t.Parallel()
client, _, api := coderdenttest.NewWithAPI(t, nil)

View File

@ -15,6 +15,7 @@ import (
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd"
"github.com/coder/coder/enterprise/coderd/license"
)
const (
@ -24,6 +25,8 @@ const (
var (
testPrivateKey ed25519.PrivateKey
testPublicKey ed25519.PublicKey
Keys = map[string]ed25519.PublicKey{}
)
func init() {
@ -32,6 +35,7 @@ func init() {
if err != nil {
panic(err)
}
Keys[testKeyID] = testPublicKey
}
type Options struct {
@ -64,9 +68,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
UserWorkspaceQuota: options.UserWorkspaceQuota,
Options: oop,
EntitlementsUpdateInterval: options.EntitlementsUpdateInterval,
Keys: map[string]ed25519.PublicKey{
testKeyID: testPublicKey,
},
Keys: Keys,
})
assert.NoError(t, err)
srv.Config.Handler = coderAPI.AGPL.RootHandler
@ -85,6 +87,8 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
type LicenseOptions struct {
AccountType string
AccountID string
Trial bool
AllFeatures bool
GraceAt time.Time
ExpiresAt time.Time
UserLimit int64
@ -96,11 +100,11 @@ type LicenseOptions struct {
// AddLicense generates a new license with the options provided and inserts it.
func AddLicense(t *testing.T, client *codersdk.Client, options LicenseOptions) codersdk.License {
license, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
l, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
License: GenerateLicense(t, options),
})
require.NoError(t, err)
return license
return l
}
// GenerateLicense returns a signed JWT using the test key.
@ -128,7 +132,7 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
workspaceQuota = 1
}
c := &coderd.Claims{
c := &license.Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@testing.test",
ExpiresAt: jwt.NewNumericDate(options.ExpiresAt),
@ -138,8 +142,10 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
LicenseExpires: jwt.NewNumericDate(options.GraceAt),
AccountType: options.AccountType,
AccountID: options.AccountID,
Version: coderd.CurrentVersion,
Features: coderd.Features{
Trial: options.Trial,
Version: license.CurrentVersion,
AllFeatures: options.AllFeatures,
Features: license.Features{
UserLimit: options.UserLimit,
AuditLog: auditLog,
BrowserOnly: browserOnly,
@ -148,7 +154,7 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
tok.Header[coderd.HeaderKeyID] = testKeyID
tok.Header[license.HeaderKeyID] = testKeyID
signedTok, err := tok.SignedString(testPrivateKey)
require.NoError(t, err)
return signedTok

View File

@ -0,0 +1,243 @@
package license
import (
"context"
"crypto/ed25519"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
)
// Entitlements processes licenses to return whether features are enabled or not.
func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) {
now := time.Now()
// Default all entitlements to be disabled.
entitlements := codersdk.Entitlements{
Features: map[string]codersdk.Feature{},
Warnings: []string{},
}
for _, featureName := range codersdk.FeatureNames {
entitlements.Features[featureName] = codersdk.Feature{
Entitlement: codersdk.EntitlementNotEntitled,
Enabled: enablements[featureName],
}
}
licenses, err := db.GetUnexpiredLicenses(ctx)
if err != nil {
return entitlements, err
}
activeUserCount, err := db.GetActiveUserCount(ctx)
if err != nil {
return entitlements, xerrors.Errorf("query active user count: %w", err)
}
allFeatures := false
// Here we loop through licenses to detect enabled features.
for _, l := range licenses {
claims, err := validateDBLicense(l, keys)
if err != nil {
logger.Debug(ctx, "skipping invalid license",
slog.F("id", l.ID), slog.Error(err))
continue
}
entitlements.HasLicense = true
entitlement := codersdk.EntitlementEntitled
entitlements.Trial = claims.Trial
if now.After(claims.LicenseExpires.Time) {
// if the grace period were over, the validation fails, so if we are after
// LicenseExpires we must be in grace period.
entitlement = codersdk.EntitlementGracePeriod
}
if claims.Features.UserLimit > 0 {
entitlements.Features[codersdk.FeatureUserLimit] = codersdk.Feature{
Enabled: true,
Entitlement: entitlement,
Limit: &claims.Features.UserLimit,
Actual: &activeUserCount,
}
if activeUserCount > claims.Features.UserLimit {
entitlements.Warnings = append(entitlements.Warnings, fmt.Sprintf(
"Your deployment has %d active users but is only licensed for %d.",
activeUserCount, claims.Features.UserLimit))
}
}
if claims.Features.AuditLog > 0 {
entitlements.Features[codersdk.FeatureAuditLog] = codersdk.Feature{
Entitlement: entitlement,
Enabled: enablements[codersdk.FeatureAuditLog],
}
}
if claims.Features.BrowserOnly > 0 {
entitlements.Features[codersdk.FeatureBrowserOnly] = codersdk.Feature{
Entitlement: entitlement,
Enabled: enablements[codersdk.FeatureBrowserOnly],
}
}
if claims.Features.SCIM > 0 {
entitlements.Features[codersdk.FeatureSCIM] = codersdk.Feature{
Entitlement: entitlement,
Enabled: enablements[codersdk.FeatureSCIM],
}
}
if claims.Features.WorkspaceQuota > 0 {
entitlements.Features[codersdk.FeatureWorkspaceQuota] = codersdk.Feature{
Entitlement: entitlement,
Enabled: enablements[codersdk.FeatureWorkspaceQuota],
}
}
if claims.AllFeatures {
allFeatures = true
}
}
if allFeatures {
for _, featureName := range codersdk.FeatureNames {
// No user limit!
if featureName == codersdk.FeatureUserLimit {
continue
}
feature := entitlements.Features[featureName]
feature.Entitlement = codersdk.EntitlementEntitled
entitlements.Features[featureName] = feature
}
}
if entitlements.HasLicense {
for _, featureName := range codersdk.FeatureNames {
// The user limit has it's own warnings!
if featureName == codersdk.FeatureUserLimit {
continue
}
feature := entitlements.Features[featureName]
if !feature.Enabled {
continue
}
niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
switch feature.Entitlement {
case codersdk.EntitlementNotEntitled:
entitlements.Warnings = append(entitlements.Warnings,
fmt.Sprintf("%s is enabled but your license is not entitled to this feature.", niceName))
// Disable the feature and add a warning...
feature.Enabled = false
entitlements.Features[featureName] = feature
case codersdk.EntitlementGracePeriod:
entitlements.Warnings = append(entitlements.Warnings,
fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName))
default:
}
}
}
return entitlements, nil
}
const (
CurrentVersion = 3
HeaderKeyID = "kid"
AccountTypeSalesforce = "salesforce"
VersionClaim = "version"
)
var (
ValidMethods = []string{"EdDSA"}
ErrInvalidVersion = xerrors.New("license must be version 3")
ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID)
ErrMissingLicenseExpires = xerrors.New("license missing license_expires")
)
type Features struct {
UserLimit int64 `json:"user_limit"`
AuditLog int64 `json:"audit_log"`
BrowserOnly int64 `json:"browser_only"`
SCIM int64 `json:"scim"`
WorkspaceQuota int64 `json:"workspace_quota"`
}
type Claims struct {
jwt.RegisteredClaims
// LicenseExpires is the end of the legit license term, and the start of the grace period, if
// there is one. The standard JWT claim "exp" (ExpiresAt in jwt.RegisteredClaims, above) is
// the end of the grace period (identical to LicenseExpires if there is no grace period).
// The reason we use the standard claim for the end of the grace period is that we want JWT
// processing libraries to consider the token "valid" until then.
LicenseExpires *jwt.NumericDate `json:"license_expires,omitempty"`
AccountType string `json:"account_type,omitempty"`
AccountID string `json:"account_id,omitempty"`
Trial bool `json:"trial"`
AllFeatures bool `json:"all_features"`
Version uint64 `json:"version"`
Features Features `json:"features"`
}
// Parse consumes a license and returns the claims.
func Parse(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) {
tok, err := jwt.Parse(
l,
keyFunc(keys),
jwt.WithValidMethods(ValidMethods),
)
if err != nil {
return nil, err
}
if claims, ok := tok.Claims.(jwt.MapClaims); ok && tok.Valid {
version, ok := claims[VersionClaim].(float64)
if !ok {
return nil, ErrInvalidVersion
}
if int64(version) != CurrentVersion {
return nil, ErrInvalidVersion
}
return claims, nil
}
return nil, xerrors.New("unable to parse Claims")
}
// validateDBLicense validates a database.License record, and if valid, returns the claims. If
// unparsable or invalid, it returns an error
func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) {
tok, err := jwt.ParseWithClaims(
l.JWT,
&Claims{},
keyFunc(keys),
jwt.WithValidMethods(ValidMethods),
)
if err != nil {
return nil, err
}
if claims, ok := tok.Claims.(*Claims); ok && tok.Valid {
if claims.Version != uint64(CurrentVersion) {
return nil, ErrInvalidVersion
}
if claims.LicenseExpires == nil {
return nil, ErrMissingLicenseExpires
}
return claims, nil
}
return nil, xerrors.New("unable to parse Claims")
}
func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) {
return func(j *jwt.Token) (interface{}, error) {
keyID, ok := j.Header[HeaderKeyID].(string)
if !ok {
return nil, ErrMissingKeyID
}
k, ok := keys[keyID]
if !ok {
return nil, xerrors.Errorf("no key with ID %s", keyID)
}
return k, nil
}
}

View File

@ -0,0 +1,189 @@
package license_test
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/enterprise/coderd/license"
)
func TestEntitlements(t *testing.T) {
t.Parallel()
all := map[string]bool{
codersdk.FeatureAuditLog: true,
codersdk.FeatureBrowserOnly: true,
codersdk.FeatureSCIM: true,
codersdk.FeatureWorkspaceQuota: true,
}
t.Run("Defaults", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.False(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
for _, featureName := range codersdk.FeatureNames {
require.False(t, entitlements.Features[featureName].Enabled)
require.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[featureName].Entitlement)
}
})
t.Run("SingleLicenseNothing", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
for _, featureName := range codersdk.FeatureNames {
require.False(t, entitlements.Features[featureName].Enabled)
require.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[featureName].Entitlement)
}
})
t.Run("SingleLicenseAll", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
WorkspaceQuota: true,
}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
for _, featureName := range codersdk.FeatureNames {
require.Equal(t, codersdk.EntitlementEntitled, entitlements.Features[featureName].Entitlement)
}
})
t.Run("SingleLicenseGrace", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
WorkspaceQuota: true,
GraceAt: time.Now().Add(-time.Hour),
ExpiresAt: time.Now().Add(time.Hour),
}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
for _, featureName := range codersdk.FeatureNames {
if featureName == codersdk.FeatureUserLimit {
continue
}
niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
require.Equal(t, codersdk.EntitlementGracePeriod, entitlements.Features[featureName].Entitlement)
require.Contains(t, entitlements.Warnings, fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName))
}
})
t.Run("SingleLicenseNotEntitled", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
for _, featureName := range codersdk.FeatureNames {
if featureName == codersdk.FeatureUserLimit {
continue
}
niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
// Ensures features that are not entitled are properly disabled.
require.False(t, entitlements.Features[featureName].Enabled)
require.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[featureName].Entitlement)
require.Contains(t, entitlements.Warnings, fmt.Sprintf("%s is enabled but your license is not entitled to this feature.", niceName))
}
})
t.Run("TooManyUsers", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertUser(context.Background(), database.InsertUserParams{})
db.InsertUser(context.Background(), database.InsertUserParams{})
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
UserLimit: 1,
}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.Contains(t, entitlements.Warnings, "Your deployment has 2 active users but is only licensed for 1.")
})
t.Run("MultipleLicenseEnabled", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
// One trial
db.InsertLicense(context.Background(), database.InsertLicenseParams{
Exp: time.Now().Add(time.Hour),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
Trial: true,
}),
})
// One not
db.InsertLicense(context.Background(), database.InsertLicenseParams{
Exp: time.Now().Add(time.Hour),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
Trial: false,
}),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
})
t.Run("AllFeatures", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
Exp: time.Now().Add(time.Hour),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
AllFeatures: true,
}),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
for _, featureName := range codersdk.FeatureNames {
if featureName == codersdk.FeatureUserLimit {
continue
}
require.True(t, entitlements.Features[featureName].Enabled)
require.Equal(t, codersdk.EntitlementEntitled, entitlements.Features[featureName].Entitlement)
}
})
}

View File

@ -23,19 +23,13 @@ import (
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd/license"
)
const (
CurrentVersion = 3
HeaderKeyID = "kid"
AccountTypeSalesforce = "salesforce"
VersionClaim = "version"
PubsubEventLicenses = "licenses"
)
var ValidMethods = []string{"EdDSA"}
// key20220812 is the Coder license public key with id 2022-08-12 used to validate licenses signed
// by our signing infrastructure
//
@ -44,34 +38,6 @@ var key20220812 []byte
var Keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)}
type Features struct {
UserLimit int64 `json:"user_limit"`
AuditLog int64 `json:"audit_log"`
BrowserOnly int64 `json:"browser_only"`
SCIM int64 `json:"scim"`
WorkspaceQuota int64 `json:"workspace_quota"`
}
type Claims struct {
jwt.RegisteredClaims
// LicenseExpires is the end of the legit license term, and the start of the grace period, if
// there is one. The standard JWT claim "exp" (ExpiresAt in jwt.RegisteredClaims, above) is
// the end of the grace period (identical to LicenseExpires if there is no grace period).
// The reason we use the standard claim for the end of the grace period is that we want JWT
// processing libraries to consider the token "valid" until then.
LicenseExpires *jwt.NumericDate `json:"license_expires,omitempty"`
AccountType string `json:"account_type,omitempty"`
AccountID string `json:"account_id,omitempty"`
Version uint64 `json:"version"`
Features Features `json:"features"`
}
var (
ErrInvalidVersion = xerrors.New("license must be version 3")
ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID)
ErrMissingLicenseExpires = xerrors.New("license missing license_expires")
)
// postLicense adds a new Enterprise license to the cluster. We allow multiple different licenses
// in the cluster at one time for several reasons:
//
@ -93,7 +59,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
return
}
claims, err := parseLicense(addLicense.License, api.Keys)
claims, err := license.Parse(addLicense.License, api.Keys)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
@ -261,65 +227,3 @@ func decodeClaims(l database.License) (jwt.MapClaims, error) {
err = d.Decode(&c)
return c, err
}
// parseLicense parses the license and returns the claims. If the license's signature is invalid or
// is not parsable, an error is returned.
func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) {
tok, err := jwt.Parse(
l,
keyFunc(keys),
jwt.WithValidMethods(ValidMethods),
)
if err != nil {
return nil, err
}
if claims, ok := tok.Claims.(jwt.MapClaims); ok && tok.Valid {
version, ok := claims[VersionClaim].(float64)
if !ok {
return nil, ErrInvalidVersion
}
if int64(version) != CurrentVersion {
return nil, ErrInvalidVersion
}
return claims, nil
}
return nil, xerrors.New("unable to parse Claims")
}
// validateDBLicense validates a database.License record, and if valid, returns the claims. If
// unparsable or invalid, it returns an error
func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) {
tok, err := jwt.ParseWithClaims(
l.JWT,
&Claims{},
keyFunc(keys),
jwt.WithValidMethods(ValidMethods),
)
if err != nil {
return nil, err
}
if claims, ok := tok.Claims.(*Claims); ok && tok.Valid {
if claims.Version != uint64(CurrentVersion) {
return nil, ErrInvalidVersion
}
if claims.LicenseExpires == nil {
return nil, ErrMissingLicenseExpires
}
return claims, nil
}
return nil, xerrors.New("unable to parse Claims")
}
func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) {
return func(j *jwt.Token) (interface{}, error) {
keyID, ok := j.Header[HeaderKeyID].(string)
if !ok {
return nil, ErrMissingKeyID
}
k, ok := keys[keyID]
if !ok {
return nil, xerrors.Errorf("no key with ID %s", keyID)
}
return k, nil
}
}

View File

@ -12,8 +12,8 @@ import (
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/enterprise/coderd/license"
"github.com/coder/coder/testutil"
)
@ -25,7 +25,7 @@ func TestPostLicense(t *testing.T) {
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountType: coderd.AccountTypeSalesforce,
AccountType: license.AccountTypeSalesforce,
AccountID: "testing",
AuditLog: true,
})
@ -89,6 +89,7 @@ func TestGetLicense(t *testing.T) {
AuditLog: true,
SCIM: true,
BrowserOnly: true,
Trial: true,
UserLimit: 200,
})
@ -106,6 +107,7 @@ func TestGetLicense(t *testing.T) {
}, licenses[0].Claims["features"])
assert.Equal(t, int32(2), licenses[1].ID)
assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
assert.Equal(t, true, licenses[1].Claims["trial"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("200"),
codersdk.FeatureAuditLog: json.Number("1"),

View File

@ -21,10 +21,10 @@ import (
func (api *API) scimEnabledMW(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
api.entitlementsMu.RLock()
scim := api.entitlements.scim
scim := api.entitlements.Features[codersdk.FeatureSCIM].Enabled
api.entitlementsMu.RUnlock()
if scim == codersdk.EntitlementNotEntitled {
if !scim {
httpapi.RouteNotFound(rw)
return
}

View File

@ -10,9 +10,9 @@ import (
func (api *API) shouldBlockNonBrowserConnections(rw http.ResponseWriter) bool {
api.entitlementsMu.Lock()
browserOnly := api.entitlements.browserOnly
browserOnly := api.entitlements.Features[codersdk.FeatureBrowserOnly].Enabled
api.entitlementsMu.Unlock()
if api.BrowserOnly && browserOnly != codersdk.EntitlementNotEntitled {
if browserOnly {
httpapi.Write(context.Background(), rw, http.StatusConflict, codersdk.Response{
Message: "Non-browser connections are disabled for your deployment.",
})

View File

@ -30,6 +30,7 @@ export const defaultEntitlements = (): TypesGen.Entitlements => {
has_license: false,
warnings: [],
experimental: false,
trial: false,
}
}

View File

@ -245,6 +245,7 @@ export interface Entitlements {
readonly warnings: string[]
readonly has_license: boolean
readonly experimental: boolean
readonly trial: boolean
}
// From codersdk/features.go

View File

@ -777,12 +777,14 @@ export const MockEntitlements: TypesGen.Entitlements = {
has_license: false,
features: {},
experimental: false,
trial: false,
}
export const MockEntitlementsWithWarnings: TypesGen.Entitlements = {
warnings: ["You are over your active user limit.", "And another thing."],
has_license: true,
experimental: false,
trial: false,
features: {
user_limit: {
enabled: true,
@ -805,6 +807,7 @@ export const MockEntitlementsWithAuditLog: TypesGen.Entitlements = {
warnings: [],
has_license: true,
experimental: false,
trial: false,
features: {
audit_log: {
enabled: true,

View File

@ -24,6 +24,7 @@ const emptyEntitlements = {
features: {},
has_license: false,
experimental: false,
trial: false,
}
export const entitlementsMachine = createMachine(