chore: Refactor Enterprise code to layer on top of AGPL (#4034)

* chore: Refactor Enterprise code to layer on top of AGPL

This is an experiment to invert the import order of the Enterprise
code to layer on top of AGPL.

* Fix Garrett's comments

* Add pointer.Handle to atomically obtain references

This uses a context to ensure the same value persists through
multiple executions to `Load()`.

* Remove entitlements API from AGPL coderd

* Remove AGPL Coder entitlements endpoint test

* Fix warnings output

* Add command-line flag to toggle audit logging

* Fix hasLicense being set

* Remove features interface

* Fix audit logging default

* Add bash as a dependency

* Add comment

* Add tests for resync and pubsub, and add back previous exp backoff retry

* Separate authz code again

* Add pointer loading example from comment

* Fix duplicate test, remove pointer.Handle

* Fix expired license

* Add entitlements struct

* Fix context passing
This commit is contained in:
Kyle Carberry 2022-09-19 23:11:01 -05:00 committed by GitHub
parent 714c366d16
commit db0ba8588e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 1402 additions and 2069 deletions

View File

@ -20,8 +20,6 @@ import (
"github.com/stretchr/testify/require"
gossh "golang.org/x/crypto/ssh"
"cdr.dev/slog"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
@ -83,18 +81,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str
errC <- cmd.ExecuteContext(ctx)
}()
t.Cleanup(func() { require.NoError(t, <-errC) })
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
require.NoError(t, err)
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
require.NoError(t, err)
defer dialer.Close()
require.Eventually(t, func() bool {
_, err = dialer.Ping()
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
return agentClient, agentToken, pubkey
}

View File

@ -91,12 +91,13 @@ func Core() []*cobra.Command {
users(),
versionCmd(),
workspaceAgent(),
features(),
}
}
func AGPL() []*cobra.Command {
all := append(Core(), Server(coderd.New))
all := append(Core(), Server(func(_ context.Context, o *coderd.Options) (*coderd.API, error) {
return coderd.New(o), nil
}))
return all
}
@ -548,13 +549,11 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {
defer cancel()
entitlements, err := client.Entitlements(ctx)
if err != nil {
return xerrors.Errorf("get entitlements to show warnings: %w", err)
if err == nil {
for _, w := range entitlements.Warnings {
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w))
}
}
for _, w := range entitlements.Warnings {
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w))
}
return nil
}

View File

@ -68,7 +68,7 @@ import (
)
// nolint:gocyclo
func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
var (
accessURL string
address string
@ -489,7 +489,10 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
), promAddress, "prometheus")()
}
coderAPI := newAPI(options)
coderAPI, err := newAPI(ctx, options)
if err != nil {
return err
}
defer coderAPI.Close()
client := codersdk.New(localURL)
@ -536,7 +539,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
// These errors are typically noise like "TLS: EOF". Vault does similar:
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
ErrorLog: log.New(io.Discard, "", 0),
Handler: coderAPI.Handler,
Handler: coderAPI.RootHandler,
BaseContext: func(_ net.Listener) context.Context {
return shutdownConnsCtx
},

View File

@ -12,14 +12,13 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/features"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/tracing"
)
type RequestParams struct {
Features features.Service
Log slog.Logger
Audit Auditor
Log slog.Logger
Request *http.Request
Action database.AuditAction
@ -102,15 +101,6 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
params: p,
}
feats := struct {
Audit Auditor
}{}
err := p.Features.Get(&feats)
if err != nil {
p.Log.Error(p.Request.Context(), "unable to get auditor interface", slog.Error(err))
return req, func() {}
}
return req, func() {
ctx := context.Background()
logCtx := p.Request.Context()
@ -120,7 +110,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
return
}
diff := Diff(feats.Audit, req.Old, req.New)
diff := Diff(p.Audit, req.Old, req.New)
diffRaw, _ := json.Marshal(diff)
ip, err := parseIP(p.Request.RemoteAddr)
@ -128,7 +118,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
p.Log.Warn(logCtx, "parse ip", slog.Error(err))
}
err = feats.Audit.Export(ctx, database.AuditLog{
err = p.Audit.Export(ctx, database.AuditLog{
ID: uuid.New(),
Time: database.Now(),
UserID: httpmw.APIKey(p.Request).UserID,

View File

@ -43,7 +43,7 @@ type HTTPAuthorizer struct {
// return
// }
func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
return api.httpAuth.Authorize(r, action, object)
return api.HTTPAuth.Authorize(r, action, object)
}
// Authorize will return false if the user is not authorized to do the action.

View File

@ -7,6 +7,7 @@ import (
"net/url"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/andybalholm/brotli"
@ -24,9 +25,9 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/awsidentity"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/features"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
@ -50,6 +51,7 @@ type Options struct {
// CacheDir is used for caching files served by the API.
CacheDir string
Auditor audit.Auditor
AgentConnectionUpdateFrequency time.Duration
AgentInactiveDisconnectTimeout time.Duration
// APIRateLimit is the minutely throughput rate limit per user or ip.
@ -68,8 +70,6 @@ type Options struct {
Telemetry telemetry.Reporter
TracerProvider trace.TracerProvider
AutoImportTemplates []AutoImportTemplate
LicenseHandler http.Handler
FeaturesService features.Service
TailnetCoordinator *tailnet.Coordinator
DERPMap *tailcfg.DERPMap
@ -80,6 +80,9 @@ type Options struct {
// New constructs a Coder API handler.
func New(options *Options) *API {
if options == nil {
options = &Options{}
}
if options.AgentConnectionUpdateFrequency == 0 {
options.AgentConnectionUpdateFrequency = 3 * time.Second
}
@ -117,11 +120,8 @@ func New(options *Options) *API {
if options.TailnetCoordinator == nil {
options.TailnetCoordinator = tailnet.NewCoordinator()
}
if options.LicenseHandler == nil {
options.LicenseHandler = licenses()
}
if options.FeaturesService == nil {
options.FeaturesService = &featuresService{}
if options.Auditor == nil {
options.Auditor = audit.NewNop()
}
siteCacheDir := options.CacheDir
@ -142,14 +142,16 @@ func New(options *Options) *API {
r := chi.NewRouter()
api := &API{
Options: options,
Handler: r,
RootHandler: r,
siteHandler: site.Handler(site.FS(), binFS),
httpAuth: &HTTPAuthorizer{
HTTPAuth: &HTTPAuthorizer{
Authorizer: options.Authorizer,
Logger: options.Logger,
},
metricsCache: metricsCache,
Auditor: atomic.Pointer[audit.Auditor]{},
}
api.Auditor.Store(&options.Auditor)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
oauthConfigs := &httpmw.OAuth2Configs{
@ -218,6 +220,8 @@ func New(options *Options) *API {
})
r.Route("/api/v2", func(r chi.Router) {
api.APIHandler = r
r.NotFound(func(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: "Route not found.",
@ -473,14 +477,6 @@ func New(options *Options) *API {
r.Get("/resources", api.workspaceBuildResources)
r.Get("/state", api.workspaceBuildState)
})
r.Route("/entitlements", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/", api.FeaturesService.EntitlementsAPI)
})
r.Route("/licenses", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Mount("/", options.LicenseHandler)
})
})
r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP)
@ -489,17 +485,20 @@ func New(options *Options) *API {
type API struct {
*Options
Auditor atomic.Pointer[audit.Auditor]
HTTPAuth *HTTPAuthorizer
derpServer *derp.Server
// APIHandler serves "/api/v2"
APIHandler chi.Router
// RootHandler serves "/"
RootHandler chi.Router
Handler chi.Router
derpServer *derp.Server
metricsCache *metricscache.Cache
siteHandler http.Handler
websocketWaitMutex sync.Mutex
websocketWaitGroup sync.WaitGroup
workspaceAgentCache *wsconncache.Cache
httpAuth *HTTPAuthorizer
metricsCache *metricscache.Cache
}
// Close waits for all WebSocket connections to drain before returning.

View File

@ -38,16 +38,6 @@ func TestBuildInfo(t *testing.T) {
require.Equal(t, buildinfo.Version(), buildInfo.Version, "version")
}
// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered.
func TestAuthorizeAllEndpoints(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
a := coderdtest.NewAuthTester(ctx, t, nil)
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
a.Test(ctx, assertRoute, skipRoutes)
}
func TestDERP(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)

View File

@ -22,143 +22,6 @@ import (
"github.com/coder/coder/testutil"
)
type RouteCheck struct {
NoAuthorize bool
AssertAction rbac.Action
AssertObject rbac.Object
StatusCode int
}
type AuthTester struct {
t *testing.T
api *coderd.API
authorizer *recordingAuthorizer
Client *codersdk.Client
Workspace codersdk.Workspace
Organization codersdk.Organization
Admin codersdk.CreateFirstUserResponse
Template codersdk.Template
Version codersdk.TemplateVersion
WorkspaceResource codersdk.WorkspaceResource
File codersdk.UploadResponse
TemplateVersionDryRun codersdk.ProvisionerJob
TemplateParam codersdk.Parameter
URLParams map[string]string
}
func NewAuthTester(ctx context.Context, t *testing.T, options *Options) *AuthTester {
authorizer := &recordingAuthorizer{}
if options == nil {
options = &Options{}
}
if options.Authorizer != nil {
t.Error("NewAuthTester cannot be called with custom Authorizer")
}
options.Authorizer = authorizer
options.IncludeProvisionerDaemon = true
client, _, api := newWithAPI(t, options)
admin := CreateFirstUser(t, client)
// The provisioner will call to coderd and register itself. This is async,
// so we wait for it to occur.
require.Eventually(t, func() bool {
provisionerds, err := client.ProvisionerDaemons(ctx)
return assert.NoError(t, err) && len(provisionerds) > 0
}, testutil.WaitLong, testutil.IntervalSlow)
provisionerds, err := client.ProvisionerDaemons(ctx)
require.NoError(t, err, "fetch provisioners")
require.Len(t, provisionerds, 1)
organization, err := client.Organization(ctx, admin.OrganizationID)
require.NoError(t, err, "fetch org")
// Setup some data in the database.
version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
// Return a workspace resource
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agents: []*proto.Agent{{
Name: "agent",
Id: "something",
Auth: &proto.Agent_Token{},
Apps: []*proto.App{{
Name: "testapp",
Url: "http://localhost:3000",
}},
}},
}},
},
},
}},
})
AwaitTemplateVersionJob(t, client, version.ID)
template := CreateTemplate(t, client, admin.OrganizationID, version.ID)
workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID)
AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024))
require.NoError(t, err, "upload file")
workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
require.NoError(t, err, "workspace resources")
templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{
ParameterValues: []codersdk.CreateParameterRequest{},
})
require.NoError(t, err, "template version dry-run")
templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{
Name: "test-param",
SourceValue: "hello world",
SourceScheme: codersdk.ParameterSourceSchemeData,
DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable,
})
require.NoError(t, err, "create template param")
urlParameters := map[string]string{
"{organization}": admin.OrganizationID.String(),
"{user}": admin.UserID.String(),
"{organizationname}": organization.Name,
"{workspace}": workspace.ID.String(),
"{workspacebuild}": workspace.LatestBuild.ID.String(),
"{workspacename}": workspace.Name,
"{workspaceagent}": workspaceResources[0].Agents[0].ID.String(),
"{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10),
"{template}": template.ID.String(),
"{hash}": file.Hash,
"{workspaceresource}": workspaceResources[0].ID.String(),
"{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name,
"{templateversion}": version.ID.String(),
"{jobID}": templateVersionDryRun.ID.String(),
"{templatename}": template.Name,
"{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name,
// Only checking template scoped params here
"parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s",
string(templateParam.Scope), templateParam.ScopeID.String()),
}
return &AuthTester{
t: t,
api: api,
authorizer: authorizer,
Client: client,
Workspace: workspace,
Organization: organization,
Admin: admin,
Template: template,
Version: version,
WorkspaceResource: workspaceResources[0],
File: file,
TemplateVersionDryRun: templateVersionDryRun,
TemplateParam: templateParam,
URLParams: urlParameters,
}
}
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
// Some quick reused objects
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
@ -181,7 +44,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"POST:/api/v2/users/login": {NoAuthorize: true},
"GET:/api/v2/users/authmethods": {NoAuthorize: true},
"POST:/api/v2/csp/reports": {NoAuthorize: true},
"GET:/api/v2/entitlements": {NoAuthorize: true},
// Has it's own auth
"GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true},
@ -408,6 +270,134 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
return skipRoutes, assertRoute
}
type RouteCheck struct {
NoAuthorize bool
AssertAction rbac.Action
AssertObject rbac.Object
StatusCode int
}
type AuthTester struct {
t *testing.T
api *coderd.API
authorizer *RecordingAuthorizer
Client *codersdk.Client
Workspace codersdk.Workspace
Organization codersdk.Organization
Admin codersdk.CreateFirstUserResponse
Template codersdk.Template
Version codersdk.TemplateVersion
WorkspaceResource codersdk.WorkspaceResource
File codersdk.UploadResponse
TemplateVersionDryRun codersdk.ProvisionerJob
TemplateParam codersdk.Parameter
URLParams map[string]string
}
func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, api *coderd.API, admin codersdk.CreateFirstUserResponse) *AuthTester {
authorizer, ok := api.Authorizer.(*RecordingAuthorizer)
if !ok {
t.Fail()
}
// The provisioner will call to coderd and register itself. This is async,
// so we wait for it to occur.
require.Eventually(t, func() bool {
provisionerds, err := client.ProvisionerDaemons(ctx)
return assert.NoError(t, err) && len(provisionerds) > 0
}, testutil.WaitLong, testutil.IntervalSlow)
provisionerds, err := client.ProvisionerDaemons(ctx)
require.NoError(t, err, "fetch provisioners")
require.Len(t, provisionerds, 1)
organization, err := client.Organization(ctx, admin.OrganizationID)
require.NoError(t, err, "fetch org")
// Setup some data in the database.
version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
// Return a workspace resource
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agents: []*proto.Agent{{
Name: "agent",
Id: "something",
Auth: &proto.Agent_Token{},
Apps: []*proto.App{{
Name: "testapp",
Url: "http://localhost:3000",
}},
}},
}},
},
},
}},
})
AwaitTemplateVersionJob(t, client, version.ID)
template := CreateTemplate(t, client, admin.OrganizationID, version.ID)
workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID)
AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024))
require.NoError(t, err, "upload file")
workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
require.NoError(t, err, "workspace resources")
templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{
ParameterValues: []codersdk.CreateParameterRequest{},
})
require.NoError(t, err, "template version dry-run")
templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{
Name: "test-param",
SourceValue: "hello world",
SourceScheme: codersdk.ParameterSourceSchemeData,
DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable,
})
require.NoError(t, err, "create template param")
urlParameters := map[string]string{
"{organization}": admin.OrganizationID.String(),
"{user}": admin.UserID.String(),
"{organizationname}": organization.Name,
"{workspace}": workspace.ID.String(),
"{workspacebuild}": workspace.LatestBuild.ID.String(),
"{workspacename}": workspace.Name,
"{workspaceagent}": workspaceResources[0].Agents[0].ID.String(),
"{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10),
"{template}": template.ID.String(),
"{hash}": file.Hash,
"{workspaceresource}": workspaceResources[0].ID.String(),
"{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name,
"{templateversion}": version.ID.String(),
"{jobID}": templateVersionDryRun.ID.String(),
"{templatename}": template.Name,
"{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name,
// Only checking template scoped params here
"parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s",
string(templateParam.Scope), templateParam.ScopeID.String()),
}
return &AuthTester{
t: t,
api: api,
authorizer: authorizer,
Client: client,
Workspace: workspace,
Organization: organization,
Admin: admin,
Template: template,
Version: version,
WorkspaceResource: workspaceResources[0],
File: file,
TemplateVersionDryRun: templateVersionDryRun,
TemplateParam: templateParam,
URLParams: urlParameters,
}
}
func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) {
// Always fail auth from this point forward
a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil)
@ -433,7 +423,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
}
err := chi.Walk(
a.api.Handler,
a.api.RootHandler,
func(
method string,
route string,
@ -513,14 +503,14 @@ type authCall struct {
Object rbac.Object
}
type recordingAuthorizer struct {
type RecordingAuthorizer struct {
Called *authCall
AlwaysReturn error
}
var _ rbac.Authorizer = (*recordingAuthorizer)(nil)
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
r.Called = &authCall{
SubjectID: subjectID,
Roles: roleNames,
@ -531,7 +521,7 @@ func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro
return r.AlwaysReturn
}
func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
return &fakePreparedAuthorizer{
Original: r,
SubjectID: subjectID,
@ -541,12 +531,12 @@ func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID str
}, nil
}
func (r *recordingAuthorizer) reset() {
func (r *RecordingAuthorizer) reset() {
r.Called = nil
}
type fakePreparedAuthorizer struct {
Original *recordingAuthorizer
Original *RecordingAuthorizer
SubjectID string
Roles []string
Scope rbac.Scope

View File

@ -0,0 +1,20 @@
package coderdtest_test
import (
"context"
"testing"
"github.com/coder/coder/coderd/coderdtest"
)
func TestAuthorizeAllEndpoints(t *testing.T) {
t.Parallel()
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Authorizer: &coderdtest.RecordingAuthorizer{},
IncludeProvisionerDaemon: true,
})
admin := coderdtest.CreateFirstUser(t, client)
a := coderdtest.NewAuthTester(context.Background(), t, client, api, admin)
skipRoute, assertRoute := coderdtest.AGPLRoutes(a)
a.Test(context.Background(), assertRoute, skipRoute)
}

View File

@ -80,7 +80,6 @@ type Options struct {
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
IncludeProvisionerDaemon bool
APIBuilder func(*coderd.Options) *coderd.API
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
}
@ -112,14 +111,11 @@ func NewWithProvisionerCloser(t *testing.T, options *Options) (*codersdk.Client,
// and is a temporary measure while the API to register provisioners is ironed
// out.
func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) {
client, closer, _ := newWithAPI(t, options)
client, closer, _ := NewWithAPI(t, options)
return client, closer
}
// newWithAPI constructs an in-memory API instance and returns a client to talk to it.
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
// Do not expose the API or wrath shall descend upon thee.
func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) {
if options == nil {
options = &Options{}
}
@ -140,9 +136,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
close(options.AutobuildStats)
})
}
if options.APIBuilder == nil {
options.APIBuilder = coderd.New
}
// This can be hotswapped for a live database instance.
db := databasefake.New()
@ -166,8 +159,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
}
ctx, cancelFunc := context.WithCancel(context.Background())
defer t.Cleanup(cancelFunc) // Defer to ensure cancelFunc is executed first.
lifecycleExecutor := executor.New(
ctx,
db,
@ -201,13 +192,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
}
features := coderd.DisabledImplementations
if options.Auditor != nil {
features.Auditor = options.Auditor
}
// We set the handler after server creation for the access URL.
coderAPI := options.APIBuilder(&coderd.Options{
return srv, cancelFunc, &coderd.Options{
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
// Force a long disconnection timeout to ensure
// agents are not marked as disconnected during slow tests.
@ -218,6 +203,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
Database: db,
Pubsub: pubsub,
Auditor: options.Auditor,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
@ -248,22 +234,30 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
AutoImportTemplates: options.AutoImportTemplates,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
FeaturesService: coderd.NewMockFeaturesService(features),
})
t.Cleanup(func() {
_ = coderAPI.Close()
})
srv.Config.Handler = coderAPI.Handler
}
}
// NewWithAPI constructs an in-memory API instance and returns a client to talk to it.
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
// Do not expose the API or wrath shall descend upon thee.
func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
if options == nil {
options = &Options{}
}
srv, cancelFunc, newOptions := NewOptions(t, options)
// We set the handler after server creation for the access URL.
coderAPI := coderd.New(newOptions)
srv.Config.Handler = coderAPI.RootHandler
var provisionerCloser io.Closer = nopcloser{}
if options.IncludeProvisionerDaemon {
provisionerCloser = NewProvisionerDaemon(t, coderAPI)
}
t.Cleanup(func() {
cancelFunc()
_ = provisionerCloser.Close()
_ = coderAPI.Close()
})
return codersdk.New(serverURL), provisionerCloser, coderAPI
return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
}
// NewProvisionerDaemon launches a provisionerd instance configured to work

View File

@ -1,97 +0,0 @@
package coderd
import (
"net/http"
"reflect"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/features"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
func NewMockFeaturesService(feats FeatureInterfaces) features.Service {
return &featuresService{
feats: &feats,
}
}
type featuresService struct {
feats *FeatureInterfaces
}
func (*featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) {
feats := make(map[string]codersdk.Feature)
for _, f := range codersdk.FeatureNames {
feats[f] = codersdk.Feature{
Entitlement: codersdk.EntitlementNotEntitled,
Enabled: false,
}
}
httpapi.Write(rw, http.StatusOK, codersdk.Entitlements{
Features: feats,
Warnings: []string{},
HasLicense: false,
})
}
// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a
// struct type containing feature interfaces as fields. The AGPL featureService always returns the
// "disabled" version of the feature interface because it doesn't include any enterprise features
// by definition.
func (f *featuresService) Get(ps any) error {
if reflect.TypeOf(ps).Kind() != reflect.Pointer {
return xerrors.New("input must be pointer to struct")
}
vs := reflect.ValueOf(ps).Elem()
if vs.Kind() != reflect.Struct {
return xerrors.New("input must be pointer to struct")
}
for i := 0; i < vs.NumField(); i++ {
vf := vs.Field(i)
tf := vf.Type()
if tf.Kind() != reflect.Interface {
return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String())
}
err := f.setImplementation(vf, tf)
if err != nil {
return err
}
}
return nil
}
// setImplementation finds the correct implementation for the field's type, and sets it on the
// struct. It returns an error if unsuccessful
func (f *featuresService) setImplementation(vf reflect.Value, tf reflect.Type) error {
feats := f.feats
if feats == nil {
feats = &DisabledImplementations
}
// when we get more than a few features it might make sense to have a data structure for finding
// the correct implementation that's faster than just a linear search, but for now just spin
// through the implementations we have.
vd := reflect.ValueOf(*feats)
for j := 0; j < vd.NumField(); j++ {
vdf := vd.Field(j)
if vdf.Type() == tf {
vf.Set(vdf)
return nil
}
}
return xerrors.Errorf("unable to find implementation of interface %s", tf.String())
}
// FeatureInterfaces contains a field for each interface controlled by an enterprise feature.
type FeatureInterfaces struct {
Auditor audit.Auditor
}
// DisabledImplementations includes all the implementations of turned-off features. There are no
// turned-on implementations in AGPL code.
var DisabledImplementations = FeatureInterfaces{
Auditor: audit.NewNop(),
}

View File

@ -1,13 +0,0 @@
package features
import "net/http"
// Service is the interface for interacting with enterprise features.
type Service interface {
EntitlementsAPI(w http.ResponseWriter, r *http.Request)
// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a
// struct type containing feature interfaces as fields. The FeatureService sets all fields to
// the correct implementations depending on whether the features are turned on.
Get(s any) error
}

View File

@ -1,100 +0,0 @@
package coderd
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/codersdk"
)
func TestEntitlements(t *testing.T) {
t.Parallel()
t.Run("GET", func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
rw := httptest.NewRecorder()
(&featuresService{}).EntitlementsAPI(rw, r)
resp := rw.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
dec := json.NewDecoder(resp.Body)
var result codersdk.Entitlements
err := dec.Decode(&result)
require.NoError(t, err)
assert.False(t, result.HasLicense)
assert.Empty(t, result.Warnings)
for _, f := range codersdk.FeatureNames {
require.Contains(t, result.Features, f)
fe := result.Features[f]
assert.False(t, fe.Enabled)
assert.Equal(t, codersdk.EntitlementNotEntitled, fe.Entitlement)
}
})
}
func TestFeaturesServiceGet(t *testing.T) {
t.Parallel()
t.Run("Auditor", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
Auditor audit.Auditor
}{}
err := uut.Get(&target)
require.NoError(t, err)
assert.NotNil(t, target.Auditor)
})
t.Run("NotPointer", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
Auditor audit.Auditor
}{}
err := uut.Get(target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
t.Run("UnknownInterface", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
test testInterface
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.test)
})
t.Run("PointerToNonStruct", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
var target audit.Auditor
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target)
})
t.Run("StructWithNonInterfaces", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
N int64
Auditor audit.Auditor
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
}
type testInterface interface {
Test() error
}

View File

@ -1,24 +0,0 @@
package coderd
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
func licenses() http.Handler {
r := chi.NewRouter()
r.NotFound(unsupported)
return r
}
func unsupported(rw http.ResponseWriter, _ *http.Request) {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: "Unsupported",
Detail: "These endpoints are not supported in AGPL-licensed Coder",
Validations: nil,
})
}

View File

@ -48,7 +48,7 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
if daemons == nil {
daemons = []database.ProvisionerDaemon{}
}
daemons, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, daemons)
daemons, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, daemons)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",

View File

@ -41,7 +41,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) {
api := New(&opts)
defer api.Close()
server := httptest.NewServer(api.Handler)
server := httptest.NewServer(api.RootHandler)
defer server.Close()
userID := uuid.New()
keyID, keySecret, err := generateAPIKeyIDSecret()

View File

@ -85,11 +85,12 @@ func (api *API) template(rw http.ResponseWriter, r *http.Request) {
func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
var (
template = httpmw.TemplateParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionDelete,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionDelete,
})
)
defer commitAudit()
@ -139,17 +140,18 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
createTemplate codersdk.CreateTemplateRequest
organization = httpmw.OrganizationParam(r)
apiKey = httpmw.APIKey(r)
auditor = *api.Auditor.Load()
templateAudit, commitTemplateAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
})
templateVersionAudit, commitTemplateVersionAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitTemplateAudit()
@ -340,7 +342,7 @@ func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request)
}
// Filter templates based on rbac permissions
templates, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, templates)
templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching templates.",
@ -435,11 +437,12 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re
func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
var (
template = httpmw.TemplateParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()

View File

@ -559,11 +559,12 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) {
func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Request) {
var (
template = httpmw.TemplateParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -631,11 +632,12 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
var (
apiKey = httpmw.APIKey(r)
organization = httpmw.OrganizationParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
})
req codersdk.CreateTemplateVersionRequest

View File

@ -220,7 +220,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
return
}
users, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, users)
users, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, users)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching users.",
@ -255,11 +255,12 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
// Creates a new user.
func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
auditor := *api.Auditor.Load()
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
})
defer commitAudit()
@ -339,12 +340,13 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
}
func (api *API) deleteUser(rw http.ResponseWriter, r *http.Request) {
auditor := *api.Auditor.Load()
user := httpmw.UserParam(r)
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionDelete,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionDelete,
})
aReq.Old = user
defer commitAudit()
@ -414,11 +416,12 @@ func (api *API) userByName(rw http.ResponseWriter, r *http.Request) {
func (api *API) putUserProfile(rw http.ResponseWriter, r *http.Request) {
var (
user = httpmw.UserParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -494,11 +497,12 @@ func (api *API) putUserStatus(status database.UserStatus) func(rw http.ResponseW
var (
user = httpmw.UserParam(r)
apiKey = httpmw.APIKey(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -560,11 +564,12 @@ func (api *API) putUserPassword(rw http.ResponseWriter, r *http.Request) {
var (
user = httpmw.UserParam(r)
params codersdk.UpdateUserPasswordRequest
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -673,7 +678,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) {
}
// Only include ones we can read from RBAC.
memberships, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, memberships)
memberships, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, memberships)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching memberships.",
@ -698,11 +703,12 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) {
user = httpmw.UserParam(r)
actorRoles = httpmw.UserAuthorization(r)
apiKey = httpmw.APIKey(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -812,7 +818,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
}
// Only return orgs the user can read.
organizations, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, organizations)
organizations, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, organizations)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching organizations.",
@ -1176,9 +1182,9 @@ func (api *API) createUser(ctx context.Context, store database.Store, req create
func (api *API) setAuthCookie(rw http.ResponseWriter, cookie *http.Cookie) {
http.SetCookie(rw, cookie)
devurlCookie := api.applicationCookie(cookie)
if devurlCookie != nil {
http.SetCookie(rw, devurlCookie)
appCookie := api.applicationCookie(cookie)
if appCookie != nil {
http.SetCookie(rw, appCookie)
}
}

View File

@ -32,7 +32,11 @@ func TestFirstUser(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{})
has, err := client.HasFirstUser(context.Background())
require.NoError(t, err)
require.False(t, has)
_, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{})
require.Error(t, err)
})

View File

@ -119,7 +119,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
}
// Only return workspaces the user can read
workspaces, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, workspaces)
workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspaces.",
@ -217,11 +217,12 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
var (
organization = httpmw.OrganizationParam(r)
apiKey = httpmw.APIKey(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
})
)
defer commitAudit()
@ -480,11 +481,12 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
var (
workspace = httpmw.WorkspaceParam(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -556,11 +558,12 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) {
var (
workspace = httpmw.WorkspaceParam(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -616,11 +619,12 @@ func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) {
func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
var (
workspace = httpmw.WorkspaceParam(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()

View File

@ -3,12 +3,15 @@ package cli
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
agpl "github.com/coder/coder/cli"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
)
@ -36,11 +39,15 @@ func featuresList() *cobra.Command {
Use: "list",
Aliases: []string{"ls"},
RunE: func(cmd *cobra.Command, args []string) error {
client, err := CreateClient(cmd)
client, err := agpl.CreateClient(cmd)
if err != nil {
return err
}
entitlements, err := client.Entitlements(cmd.Context())
var apiError *codersdk.Error
if errors.As(err, &apiError) && apiError.StatusCode() == http.StatusNotFound {
return xerrors.New("You are on the AGPL licensed version of Coder that does not have Enterprise functionality!")
}
if err != nil {
return err
}

View File

@ -11,6 +11,8 @@ import (
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/cli"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/pty/ptytest"
)
@ -18,9 +20,9 @@ func TestFeaturesList(t *testing.T) {
t.Parallel()
t.Run("Table", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
client := coderdenttest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.New(t, "features", "list")
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "features", "list")
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
@ -36,9 +38,9 @@ func TestFeaturesList(t *testing.T) {
t.Run("JSON", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
client := coderdenttest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.New(t, "features", "list", "-o", "json")
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "features", "list", "-o", "json")
clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})

View File

@ -23,7 +23,7 @@ import (
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/cli"
"github.com/coder/coder/enterprise/coderd"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/testutil"
)
@ -124,7 +124,7 @@ func TestLicensesAddReal(t *testing.T) {
t.Parallel()
t.Run("Fails", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise})
client := coderdenttest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(),
"licenses", "add", "-l", fakeLicenseJWT)
@ -175,7 +175,7 @@ func TestLicensesListReal(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise})
client := coderdenttest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(),
"licenses", "list")
@ -219,7 +219,7 @@ func TestLicensesDeleteReal(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise})
client := coderdenttest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(),
"licenses", "delete", "1")

View File

@ -4,12 +4,12 @@ import (
"github.com/spf13/cobra"
agpl "github.com/coder/coder/cli"
"github.com/coder/coder/enterprise/coderd"
)
func enterpriseOnly() []*cobra.Command {
return []*cobra.Command{
agpl.Server(coderd.NewEnterprise),
server(),
features(),
licenses(),
}
}

33
enterprise/cli/server.go Normal file
View File

@ -0,0 +1,33 @@
package cli
import (
"context"
"github.com/spf13/cobra"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/enterprise/coderd"
agpl "github.com/coder/coder/cli"
agplcoderd "github.com/coder/coder/coderd"
)
func server() *cobra.Command {
var (
auditLogging bool
)
cmd := agpl.Server(func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) {
api, err := coderd.New(ctx, &coderd.Options{
AuditLogging: auditLogging,
Options: options,
})
if err != nil {
return nil, err
}
return api.AGPL, nil
})
cliflag.BoolVarP(cmd.Flags(), &auditLogging, "audit-logging", "", "CODER_AUDIT_LOGGING", true,
"Specifies whether audit logging is enabled.")
return cmd
}

View File

@ -1,80 +0,0 @@
package coderd
import (
"context"
"crypto/ed25519"
"crypto/rand"
"fmt"
"net/http"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered.
// these tests patch the map of license keys, so cannot be run in parallel
// nolint:paralleltest
func TestAuthorizeAllEndpoints(t *testing.T) {
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := "testing"
oldKeys := keys
defer func() {
t.Log("restoring keys")
keys = oldKeys
}()
keys = map[string]ed25519.PublicKey{keyID: pubKey}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
a := coderdtest.NewAuthTester(ctx, t, &coderdtest.Options{APIBuilder: NewEnterprise})
// We need a license in the DB, so that when we call GET api/v2/licenses there is one in the
// list to check authz on.
claims := &Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@coder.test",
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
},
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
AccountType: AccountTypeSalesforce,
AccountID: "testing",
Version: CurrentVersion,
Features: Features{
UserLimit: 0,
AuditLog: 1,
},
}
lic, err := makeLicense(claims, privKey, keyID)
require.NoError(t, err)
license, err := a.Client.AddLicense(ctx, codersdk.AddLicenseRequest{
License: lic,
})
require.NoError(t, err)
a.URLParams["licenses/{id}"] = fmt.Sprintf("licenses/%d", license.ID)
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionCreate,
AssertObject: rbac.ResourceLicense,
}
assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{
StatusCode: http.StatusOK,
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceLicense,
}
assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionDelete,
AssertObject: rbac.ResourceLicense,
}
a.Test(ctx, assertRoute, skipRoutes)
}

View File

@ -2,48 +2,282 @@ package coderd
import (
"context"
"os"
"strings"
"crypto/ed25519"
"fmt"
"net/http"
"sync"
"time"
"golang.org/x/xerrors"
"github.com/cenkalti/backoff/v4"
"github.com/go-chi/chi/v5"
"cdr.dev/slog"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/rbac"
agplaudit "github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/audit"
"github.com/coder/coder/enterprise/audit/backends"
)
const EnvAuditLogEnable = "CODER_AUDIT_LOG_ENABLE"
// New constructs an Enterprise coderd API instance.
// This handler is designed to wrap the AGPL Coder code and
// layer Enterprise functionality on top as much as possible.
func New(ctx context.Context, options *Options) (*API, error) {
if options.EntitlementsUpdateInterval == 0 {
options.EntitlementsUpdateInterval = 10 * time.Minute
}
if options.Keys == nil {
options.Keys = Keys
}
ctx, cancelFunc := context.WithCancel(ctx)
api := &API{
AGPL: coderd.New(options.Options),
Options: options,
func NewEnterprise(options *coderd.Options) *coderd.API {
var eOpts = *options
if eOpts.Authorizer == nil {
var err error
eOpts.Authorizer, err = rbac.NewAuthorizer()
entitlements: entitlements{
activeUsers: codersdk.Feature{
Entitlement: codersdk.EntitlementNotEntitled,
Enabled: false,
},
auditLogs: codersdk.EntitlementNotEntitled,
},
cancelEntitlementsLoop: cancelFunc,
}
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
OIDC: options.OIDCConfig,
}
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
api.AGPL.APIHandler.Group(func(r chi.Router) {
r.Get("/entitlements", api.serveEntitlements)
r.Route("/licenses", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Post("/", api.postLicense)
r.Get("/", api.licenses)
r.Delete("/{id}", api.deleteLicense)
})
})
err := api.updateEntitlements(ctx)
if err != nil {
return nil, xerrors.Errorf("update entitlements: %w", err)
}
go api.runEntitlementsLoop(ctx)
return api, nil
}
type Options struct {
*coderd.Options
AuditLogging bool
EntitlementsUpdateInterval time.Duration
Keys map[string]ed25519.PublicKey
}
type API struct {
AGPL *coderd.API
*Options
cancelEntitlementsLoop func()
entitlementsMu sync.RWMutex
entitlements entitlements
}
type entitlements struct {
hasLicense bool
activeUsers codersdk.Feature
auditLogs codersdk.Entitlement
}
func (api *API) Close() error {
api.cancelEntitlementsLoop()
return api.AGPL.Close()
}
func (api *API) updateEntitlements(ctx context.Context) error {
licenses, err := api.Database.GetUnexpiredLicenses(ctx)
if err != nil {
return err
}
api.entitlementsMu.Lock()
defer api.entitlementsMu.Unlock()
now := time.Now()
// Default all entitlements to be disabled.
entitlements := entitlements{
hasLicense: false,
activeUsers: codersdk.Feature{
Enabled: false,
Entitlement: codersdk.EntitlementNotEntitled,
},
auditLogs: codersdk.EntitlementNotEntitled,
}
// Here we loop through licenses to detect enabled features.
for _, l := range licenses {
claims, err := validateDBLicense(l, api.Keys)
if err != nil {
// This should never happen, as the unit tests would fail if the
// default built in authorizer failed.
panic(xerrors.Errorf("rego authorize panic: %w", err))
api.Logger.Debug(ctx, "skipping invalid license",
slog.F("id", l.ID), slog.Error(err))
continue
}
entitlements.hasLicense = true
entitlement := codersdk.EntitlementEntitled
if now.After(claims.LicenseExpires.Time) {
// if the grace period were over, the validation fails, so if we are after
// LicenseExpires we must be in grace period.
entitlement = codersdk.EntitlementGracePeriod
}
if claims.Features.UserLimit > 0 {
entitlements.activeUsers = codersdk.Feature{
Enabled: true,
Entitlement: entitlement,
}
currentLimit := int64(0)
if entitlements.activeUsers.Limit != nil {
currentLimit = *entitlements.activeUsers.Limit
}
limit := max(currentLimit, claims.Features.UserLimit)
entitlements.activeUsers.Limit = &limit
}
if claims.Features.AuditLog > 0 {
entitlements.auditLogs = entitlement
}
}
eOpts.LicenseHandler = newLicenseAPI(
eOpts.Logger,
eOpts.Database,
eOpts.Pubsub,
&coderd.HTTPAuthorizer{
Authorizer: eOpts.Authorizer,
Logger: eOpts.Logger,
}).handler()
en := Enablements{AuditLogs: true}
auditLog := os.Getenv(EnvAuditLogEnable)
auditLog = strings.ToLower(auditLog)
if auditLog == "disable" || auditLog == "false" || auditLog == "0" || auditLog == "no" {
en.AuditLogs = false
if entitlements.auditLogs != api.entitlements.auditLogs {
auditor := agplaudit.NewNop()
// A flag could be added to the options that would allow disabling
// enhanced audit logging here!
if entitlements.auditLogs == codersdk.EntitlementEntitled && api.AuditLogging {
auditor = audit.NewAuditor(
audit.DefaultFilter,
backends.NewPostgres(api.Database, true),
backends.NewSlog(api.Logger),
)
}
api.AGPL.Auditor.Store(&auditor)
}
eOpts.FeaturesService = newFeaturesService(
context.Background(),
eOpts.Logger,
eOpts.Database,
eOpts.Pubsub,
en,
)
return coderd.New(&eOpts)
api.entitlements = entitlements
return nil
}
func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
api.entitlementsMu.RLock()
entitlements := api.entitlements
api.entitlementsMu.RUnlock()
resp := codersdk.Entitlements{
Features: make(map[string]codersdk.Feature),
Warnings: make([]string, 0),
HasLicense: entitlements.hasLicense,
}
if entitlements.activeUsers.Limit != nil {
activeUserCount, err := api.Database.GetActiveUserCount(r.Context())
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Unable to query database",
Detail: err.Error(),
})
return
}
entitlements.activeUsers.Actual = &activeUserCount
if activeUserCount > *entitlements.activeUsers.Limit {
resp.Warnings = append(resp.Warnings,
fmt.Sprintf(
"Your deployment has %d active users but is only licensed for %d.",
activeUserCount, *entitlements.activeUsers.Limit))
}
}
resp.Features[codersdk.FeatureUserLimit] = entitlements.activeUsers
// Audit logs
resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{
Entitlement: entitlements.auditLogs,
Enabled: api.AuditLogging,
}
if entitlements.auditLogs == codersdk.EntitlementGracePeriod && api.AuditLogging {
resp.Warnings = append(resp.Warnings,
"Audit logging is enabled but your license for this feature is expired.")
}
httpapi.Write(rw, http.StatusOK, resp)
}
func (api *API) runEntitlementsLoop(ctx context.Context) {
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0 // retry indefinitely
b := backoff.WithContext(eb, ctx)
updates := make(chan struct{}, 1)
subscribed := false
for {
select {
case <-ctx.Done():
return
default:
// pass
}
if !subscribed {
cancel, err := api.Pubsub.Subscribe(PubsubEventLicenses, func(_ context.Context, _ []byte) {
// don't block. If the channel is full, drop the event, as there is a resync
// scheduled already.
select {
case updates <- struct{}{}:
// pass
default:
// pass
}
})
if err != nil {
api.Logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err))
select {
case <-ctx.Done():
return
case <-time.After(b.NextBackOff()):
}
continue
}
// nolint: revive
defer cancel()
subscribed = true
api.Logger.Debug(ctx, "successfully subscribed to pubsub")
}
api.Logger.Info(ctx, "syncing licensed entitlements")
err := api.updateEntitlements(ctx)
if err != nil {
api.Logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err))
time.Sleep(b.NextBackOff())
continue
}
b.Reset()
api.Logger.Debug(ctx, "synced licensed entitlements")
select {
case <-ctx.Done():
return
case <-time.After(api.EntitlementsUpdateInterval):
continue
case <-updates:
api.Logger.Debug(ctx, "got pubsub update")
continue
}
}
}
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@ -0,0 +1,204 @@
package coderd_test
import (
"context"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
agplaudit "github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/audit"
"github.com/coder/coder/enterprise/coderd"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestEntitlements(t *testing.T) {
t.Parallel()
t.Run("NoLicense", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
res, err := client.Entitlements(context.Background())
require.NoError(t, err)
require.False(t, res.HasLicense)
require.Empty(t, res.Warnings)
})
t.Run("FullLicense", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
})
res, err := client.Entitlements(context.Background())
require.NoError(t, err)
assert.True(t, res.HasLicense)
ul := res.Features[codersdk.FeatureUserLimit]
assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement)
assert.Equal(t, int64(100), *ul.Limit)
assert.Equal(t, int64(1), *ul.Actual)
assert.True(t, ul.Enabled)
al := res.Features[codersdk.FeatureAuditLog]
assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement)
assert.True(t, al.Enabled)
assert.Nil(t, al.Limit)
assert.Nil(t, al.Actual)
assert.Empty(t, res.Warnings)
})
t.Run("FullLicenseToNone", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
})
res, err := client.Entitlements(context.Background())
require.NoError(t, err)
assert.True(t, res.HasLicense)
al := res.Features[codersdk.FeatureAuditLog]
assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement)
assert.True(t, al.Enabled)
err = client.DeleteLicense(context.Background(), license.ID)
require.NoError(t, err)
res, err = client.Entitlements(context.Background())
require.NoError(t, err)
assert.False(t, res.HasLicense)
al = res.Features[codersdk.FeatureAuditLog]
assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement)
assert.True(t, al.Enabled)
})
t.Run("Warnings", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
for i := 0; i < 4; i++ {
coderdtest.CreateAnotherUser(t, client, first.OrganizationID)
}
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
UserLimit: 4,
AuditLog: true,
GraceAt: time.Now().Add(-time.Second),
})
res, err := client.Entitlements(context.Background())
require.NoError(t, err)
assert.True(t, res.HasLicense)
ul := res.Features[codersdk.FeatureUserLimit]
assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement)
assert.Equal(t, int64(4), *ul.Limit)
assert.Equal(t, int64(5), *ul.Actual)
assert.True(t, ul.Enabled)
al := res.Features[codersdk.FeatureAuditLog]
assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement)
assert.True(t, al.Enabled)
assert.Nil(t, al.Limit)
assert.Nil(t, al.Actual)
assert.Len(t, res.Warnings, 2)
assert.Contains(t, res.Warnings,
"Your deployment has 5 active users but is only licensed for 4.")
assert.Contains(t, res.Warnings,
"Audit logging is enabled but your license for this feature is expired.")
})
t.Run("Pubsub", func(t *testing.T) {
t.Parallel()
client, _, api := coderdenttest.NewWithAPI(t, nil)
entitlements, err := client.Entitlements(context.Background())
require.NoError(t, err)
require.False(t, entitlements.HasLicense)
coderdtest.CreateFirstUser(t, client)
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
UploadedAt: database.Now(),
Exp: database.Now().AddDate(1, 0, 0),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
AuditLog: true,
}),
})
require.NoError(t, err)
err = api.Pubsub.Publish(coderd.PubsubEventLicenses, []byte{})
require.NoError(t, err)
require.Eventually(t, func() bool {
entitlements, err := client.Entitlements(context.Background())
assert.NoError(t, err)
return entitlements.HasLicense
}, testutil.WaitShort, testutil.IntervalFast)
})
t.Run("Resync", func(t *testing.T) {
t.Parallel()
client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
EntitlementsUpdateInterval: 25 * time.Millisecond,
})
entitlements, err := client.Entitlements(context.Background())
require.NoError(t, err)
require.False(t, entitlements.HasLicense)
coderdtest.CreateFirstUser(t, client)
// Valid
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
UploadedAt: database.Now(),
Exp: database.Now().AddDate(1, 0, 0),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
AuditLog: true,
}),
})
require.NoError(t, err)
// Expired
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
UploadedAt: database.Now(),
Exp: database.Now().AddDate(-1, 0, 0),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
ExpiresAt: database.Now().AddDate(-1, 0, 0),
}),
})
require.NoError(t, err)
// Invalid
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
UploadedAt: database.Now(),
Exp: database.Now().AddDate(1, 0, 0),
JWT: "invalid",
})
require.NoError(t, err)
require.Eventually(t, func() bool {
entitlements, err := client.Entitlements(context.Background())
assert.NoError(t, err)
return entitlements.HasLicense
}, testutil.WaitShort, testutil.IntervalFast)
})
}
func TestAuditLogging(t *testing.T) {
t.Parallel()
t.Run("Enabled", func(t *testing.T) {
t.Parallel()
client, _, api := coderdenttest.NewWithAPI(t, nil)
coderdtest.CreateFirstUser(t, client)
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AuditLog: true,
})
auditor := *api.AGPL.Auditor.Load()
ea := audit.NewAuditor(audit.DefaultFilter)
t.Logf("%T = %T", auditor, ea)
assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type())
})
t.Run("Disabled", func(t *testing.T) {
t.Parallel()
client, _, api := coderdenttest.NewWithAPI(t, nil)
coderdtest.CreateFirstUser(t, client)
auditor := *api.AGPL.Auditor.Load()
ea := agplaudit.NewNop()
t.Logf("%T = %T", auditor, ea)
assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type())
})
}

View File

@ -0,0 +1,133 @@
package coderdenttest
import (
"context"
"crypto/ed25519"
"crypto/rand"
"io"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd"
)
const (
testKeyID = "enterprise-test"
)
var (
testPrivateKey ed25519.PrivateKey
testPublicKey ed25519.PublicKey
)
func init() {
var err error
testPublicKey, testPrivateKey, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
}
type Options struct {
*coderdtest.Options
EntitlementsUpdateInterval time.Duration
}
// New constructs a codersdk client connected to an in-memory Enterprise API instance.
func New(t *testing.T, options *Options) *codersdk.Client {
client, _, _ := NewWithAPI(t, options)
return client
}
func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
if options == nil {
options = &Options{}
}
if options.Options == nil {
options.Options = &coderdtest.Options{}
}
srv, cancelFunc, oop := coderdtest.NewOptions(t, options.Options)
coderAPI, err := coderd.New(context.Background(), &coderd.Options{
AuditLogging: true,
Options: oop,
EntitlementsUpdateInterval: options.EntitlementsUpdateInterval,
Keys: map[string]ed25519.PublicKey{
testKeyID: testPublicKey,
},
})
assert.NoError(t, err)
srv.Config.Handler = coderAPI.AGPL.RootHandler
var provisionerCloser io.Closer = nopcloser{}
if options.IncludeProvisionerDaemon {
provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL)
}
t.Cleanup(func() {
cancelFunc()
_ = provisionerCloser.Close()
_ = coderAPI.Close()
})
return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
}
type LicenseOptions struct {
AccountType string
AccountID string
GraceAt time.Time
ExpiresAt time.Time
UserLimit int64
AuditLog bool
}
// AddLicense generates a new license with the options provided and inserts it.
func AddLicense(t *testing.T, client *codersdk.Client, options LicenseOptions) codersdk.License {
license, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
License: GenerateLicense(t, options),
})
require.NoError(t, err)
return license
}
// GenerateLicense returns a signed JWT using the test key.
func GenerateLicense(t *testing.T, options LicenseOptions) string {
if options.ExpiresAt.IsZero() {
options.ExpiresAt = time.Now().Add(time.Hour)
}
if options.GraceAt.IsZero() {
options.GraceAt = time.Now().Add(time.Hour)
}
auditLog := int64(0)
if options.AuditLog {
auditLog = 1
}
c := &coderd.Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@testing.test",
ExpiresAt: jwt.NewNumericDate(options.ExpiresAt),
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
},
LicenseExpires: jwt.NewNumericDate(options.GraceAt),
AccountType: options.AccountType,
AccountID: options.AccountID,
Version: coderd.CurrentVersion,
Features: coderd.Features{
UserLimit: options.UserLimit,
AuditLog: auditLog,
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
tok.Header[coderd.HeaderKeyID] = testKeyID
signedTok, err := tok.SignedString(testPrivateKey)
require.NoError(t, err)
return signedTok
}
type nopcloser struct{}
func (nopcloser) Close() error { return nil }

View File

@ -0,0 +1,51 @@
package coderdenttest_test
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
)
func TestNew(t *testing.T) {
t.Parallel()
_ = coderdenttest.New(t, nil)
}
func TestAuthorizeAllEndpoints(t *testing.T) {
t.Parallel()
client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Authorizer: &coderdtest.RecordingAuthorizer{},
IncludeProvisionerDaemon: true,
},
})
admin := coderdtest.CreateFirstUser(t, client)
license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{})
a := coderdtest.NewAuthTester(context.Background(), t, client, api.AGPL, admin)
a.URLParams["licenses/{id}"] = fmt.Sprintf("licenses/%d", license.ID)
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
assertRoute["GET:/api/v2/entitlements"] = coderdtest.RouteCheck{
NoAuthorize: true,
}
assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionCreate,
AssertObject: rbac.ResourceLicense,
}
assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{
StatusCode: http.StatusOK,
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceLicense,
}
assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionDelete,
AssertObject: rbac.ResourceLicense,
}
a.Test(context.Background(), assertRoute, skipRoutes)
}

View File

@ -1,327 +0,0 @@
package coderd
import (
"context"
"crypto/ed25519"
"fmt"
"net/http"
"reflect"
"sync"
"time"
"github.com/coder/coder/enterprise/audit/backends"
"github.com/cenkalti/backoff/v4"
"golang.org/x/xerrors"
"cdr.dev/slog"
agpl "github.com/coder/coder/coderd"
agplAudit "github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/features"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/audit"
)
type Enablements struct {
AuditLogs bool
}
type featuresService struct {
logger slog.Logger
database database.Store
pubsub database.Pubsub
keys map[string]ed25519.PublicKey
enablements Enablements
resyncInterval time.Duration
// enabledImplementations includes an "enabled" implementation of every feature. This is
// initialized at start of day and remains static. The consequence of this is that these things
// are hanging around using memory even if not licensed or in use, but it greatly simplifies the
// logic because we don't have to bother creating and destroying them as entitlements change.
// If we have a particularly memory-hungry feature in future, we might wish to reconsider this
// choice.
enabledImplementations agpl.FeatureInterfaces
mu sync.RWMutex
entitlements entitlements
}
// newFeaturesService creates a FeaturesService and starts it. It will continue running for the
// duration of the passed ctx.
func newFeaturesService(
ctx context.Context,
logger slog.Logger,
db database.Store,
pubsub database.Pubsub,
enablements Enablements,
) features.Service {
fs := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: keys,
enablements: enablements,
enabledImplementations: agpl.FeatureInterfaces{
Auditor: audit.NewAuditor(
audit.DefaultFilter,
backends.NewPostgres(db, true),
backends.NewSlog(logger),
),
},
resyncInterval: 10 * time.Minute,
entitlements: entitlements{
activeUsers: numericalEntitlement{
entitlementLimit: entitlementLimit{
unlimited: true,
},
},
},
}
go fs.syncEntitlements(ctx)
return fs
}
func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Request) {
s.mu.RLock()
e := s.entitlements
s.mu.RUnlock()
resp := codersdk.Entitlements{
Features: make(map[string]codersdk.Feature),
Warnings: make([]string, 0),
HasLicense: e.hasLicense,
}
// User limit
uf := codersdk.Feature{
Entitlement: e.activeUsers.state.toSDK(),
Enabled: true,
}
if !e.activeUsers.unlimited {
n, err := s.database.GetActiveUserCount(r.Context())
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Unable to query database",
Detail: err.Error(),
})
return
}
uf.Actual = &n
uf.Limit = &e.activeUsers.limit
if n > e.activeUsers.limit {
resp.Warnings = append(resp.Warnings,
fmt.Sprintf(
"Your deployment has %d active users but is only licensed for %d.",
n, e.activeUsers.limit))
}
}
resp.Features[codersdk.FeatureUserLimit] = uf
// Audit logs
resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{
Entitlement: e.auditLogs.state.toSDK(),
Enabled: s.enablements.AuditLogs,
}
if e.auditLogs.state == gracePeriod && s.enablements.AuditLogs {
resp.Warnings = append(resp.Warnings,
"Audit logging is enabled but your license for this feature is expired.")
}
httpapi.Write(rw, http.StatusOK, resp)
}
type entitlementState int
const (
notEntitled entitlementState = iota
gracePeriod
entitled
)
type entitlementLimit struct {
unlimited bool
limit int64
}
type entitlement struct {
state entitlementState
}
func (s entitlementState) toSDK() codersdk.Entitlement {
switch s {
case notEntitled:
return codersdk.EntitlementNotEntitled
case gracePeriod:
return codersdk.EntitlementGracePeriod
case entitled:
return codersdk.EntitlementEntitled
default:
panic("unknown entitlementState")
}
}
type numericalEntitlement struct {
entitlement
entitlementLimit
}
type entitlements struct {
hasLicense bool
activeUsers numericalEntitlement
auditLogs entitlement
}
func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) {
licenses, err := s.database.GetUnexpiredLicenses(ctx)
if err != nil {
return entitlements{}, err
}
now := time.Now()
e := entitlements{
activeUsers: numericalEntitlement{
entitlementLimit: entitlementLimit{
unlimited: true,
},
},
}
for _, l := range licenses {
claims, err := validateDBLicense(l, s.keys)
if err != nil {
s.logger.Debug(ctx, "skipping invalid license",
slog.F("id", l.ID), slog.Error(err))
continue
}
e.hasLicense = true
thisEntitlement := entitled
if now.After(claims.LicenseExpires.Time) {
// if the grace period were over, the validation fails, so if we are after
// LicenseExpires we must be in grace period.
thisEntitlement = gracePeriod
}
if claims.Features.UserLimit > 0 {
e.activeUsers.state = thisEntitlement
e.activeUsers.unlimited = false
e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit)
}
if claims.Features.AuditLog > 0 {
e.auditLogs.state = thisEntitlement
}
}
return e, nil
}
func (s *featuresService) syncEntitlements(ctx context.Context) {
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0 // retry indefinitely
b := backoff.WithContext(eb, ctx)
updates := make(chan struct{}, 1)
subscribed := false
for {
select {
case <-ctx.Done():
return
default:
// pass
}
if !subscribed {
cancel, err := s.pubsub.Subscribe(PubSubEventLicenses, func(_ context.Context, _ []byte) {
// don't block. If the channel is full, drop the event, as there is a resync
// scheduled already.
select {
case updates <- struct{}{}:
// pass
default:
// pass
}
})
if err != nil {
s.logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err))
time.Sleep(b.NextBackOff())
continue
}
// nolint: revive
defer cancel()
subscribed = true
s.logger.Debug(ctx, "successfully subscribed to pubsub")
}
s.logger.Info(ctx, "syncing licensed entitlements")
ents, err := s.getEntitlements(ctx)
if err != nil {
s.logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err))
time.Sleep(b.NextBackOff())
continue
}
b.Reset()
s.mu.Lock()
s.entitlements = ents
s.mu.Unlock()
s.logger.Debug(ctx, "synced licensed entitlements")
select {
case <-ctx.Done():
return
case <-time.After(s.resyncInterval):
continue
case <-updates:
s.logger.Debug(ctx, "got pubsub update")
continue
}
}
}
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}
func (s *featuresService) Get(ps any) error {
if reflect.TypeOf(ps).Kind() != reflect.Pointer {
return xerrors.New("input must be pointer to struct")
}
vs := reflect.ValueOf(ps).Elem()
if vs.Kind() != reflect.Struct {
return xerrors.New("input must be pointer to struct")
}
// grab a local copy of entitlements so that we have a consistent set, but aren't keeping it
// locked from updates while we process.
s.mu.RLock()
ent := s.entitlements
s.mu.RUnlock()
for i := 0; i < vs.NumField(); i++ {
vf := vs.Field(i)
tf := vf.Type()
if tf.Kind() != reflect.Interface {
return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String())
}
err := s.setImplementation(ent, vf, tf)
if err != nil {
return err
}
}
return nil
}
func (s *featuresService) setImplementation(ent entitlements, vf reflect.Value, tf reflect.Type) error {
// c.f. https://stackoverflow.com/questions/7132848/how-to-get-the-reflect-type-of-an-interface
switch tf {
case reflect.TypeOf((*agplAudit.Auditor)(nil)).Elem():
// Audit logging
if !s.enablements.AuditLogs || ent.auditLogs.state == notEntitled {
vf.Set(reflect.ValueOf(agpl.DisabledImplementations.Auditor))
return nil
}
vf.Set(reflect.ValueOf(s.enabledImplementations.Auditor))
return nil
default:
return xerrors.Errorf("unable to find implementation of interface %s", tf.String())
}
}

View File

@ -1,545 +0,0 @@
package coderd
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
agplCoderd "github.com/coder/coder/coderd"
agplAudit "github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/features"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/audit"
"github.com/coder/coder/enterprise/audit/backends"
"github.com/coder/coder/testutil"
)
func TestFeaturesService_EntitlementsAPI(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// Note that these are not actually used because we don't run the syncEntitlements
// routine in this test.
pubsub := database.NewPubsubInMemory()
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := "testing"
t.Run("NoLicense", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
result := requestEntitlements(t, uut)
assert.False(t, result.HasLicense)
assert.Empty(t, result.Warnings)
assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureUserLimit].Entitlement)
assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureAuditLog].Entitlement)
})
t.Run("FullLicense", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
db := databasefake.New()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
entitlements: entitlements{
hasLicense: true,
activeUsers: numericalEntitlement{
entitlement{entitled},
entitlementLimit{
unlimited: false,
limit: 100,
},
},
auditLogs: entitlement{entitled},
},
}
_, err := db.InsertUser(ctx, database.InsertUserParams{
ID: uuid.UUID{},
Email: "",
Username: "",
HashedPassword: nil,
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
RBACRoles: nil,
LoginType: "",
})
require.NoError(t, err)
result := requestEntitlements(t, uut)
assert.True(t, result.HasLicense)
ul := result.Features[codersdk.FeatureUserLimit]
assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement)
assert.Equal(t, int64(100), *ul.Limit)
assert.Equal(t, int64(1), *ul.Actual)
assert.True(t, ul.Enabled)
al := result.Features[codersdk.FeatureAuditLog]
assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement)
assert.True(t, al.Enabled)
assert.Nil(t, al.Limit)
assert.Nil(t, al.Actual)
assert.Empty(t, result.Warnings)
})
t.Run("Warnings", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
db := databasefake.New()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
entitlements: entitlements{
hasLicense: true,
activeUsers: numericalEntitlement{
entitlement{gracePeriod},
entitlementLimit{
unlimited: false,
limit: 4,
},
},
auditLogs: entitlement{gracePeriod},
},
}
for i := byte(0); i < 5; i++ {
_, err := db.InsertUser(ctx, database.InsertUserParams{
ID: uuid.UUID{i},
Email: "",
Username: "",
HashedPassword: nil,
CreatedAt: time.Time{},
UpdatedAt: time.Time{},
RBACRoles: nil,
LoginType: "",
})
require.NoError(t, err)
}
result := requestEntitlements(t, uut)
assert.True(t, result.HasLicense)
ul := result.Features[codersdk.FeatureUserLimit]
assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement)
assert.Equal(t, int64(4), *ul.Limit)
assert.Equal(t, int64(5), *ul.Actual)
assert.True(t, ul.Enabled)
al := result.Features[codersdk.FeatureAuditLog]
assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement)
assert.True(t, al.Enabled)
assert.Nil(t, al.Limit)
assert.Nil(t, al.Actual)
assert.Len(t, result.Warnings, 2)
assert.Contains(t, result.Warnings,
"Your deployment has 5 active users but is only licensed for 4.")
assert.Contains(t, result.Warnings,
"Audit logging is enabled but your license for this feature is expired.")
})
}
func TestFeaturesServiceSyncEntitlements(t *testing.T) {
t.Parallel()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := "testing"
// This tests that pubsub updates work by setting the resync interval very long
t.Run("PubSub", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logger := slogtest.Make(t, nil)
pubsub := database.NewPubsubInMemory()
db := databasefake.New()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
resyncInterval: time.Hour, // no resyncs during test
entitlements: entitlements{},
}
_, invalidKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Start of day, 3 licenses, one expired, one invalid
_ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour)
_ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour)
l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour)
go uut.syncEntitlements(ctx)
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
// New license
l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour)
err = pubsub.Publish(PubSubEventLicenses, []byte("add"))
require.NoError(t, err)
// User limit goes up, because 305 > 300
testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast)
// New license with lower limit
_ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour)
err = pubsub.Publish(PubSubEventLicenses, []byte("add"))
require.NoError(t, err)
// Need to delete the others before the limit lowers
_, err = db.DeleteLicense(ctx, l1.ID)
require.NoError(t, err)
err = pubsub.Publish(PubSubEventLicenses, []byte("delete"))
require.NoError(t, err)
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
_, err = db.DeleteLicense(ctx, l0.ID)
require.NoError(t, err)
err = pubsub.Publish(PubSubEventLicenses, []byte("delete"))
require.NoError(t, err)
testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast)
})
// This tests that periodic resyncs work by setting the resync interval very fast and
// not sending any pubsub updates.
t.Run("Resyncs", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logger := slogtest.Make(t, nil)
pubsub := database.NewPubsubInMemory()
db := databasefake.New()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
resyncInterval: 10 * time.Millisecond,
entitlements: entitlements{},
}
_, invalidKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Start of day, 3 licenses, one expired, one invalid
_ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour)
_ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour)
l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour)
go uut.syncEntitlements(ctx)
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
// New license
l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour)
// User limit goes up, because 305 > 300
testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast)
// New license with lower limit
_ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour)
// Need to delete the others before the limit lowers
_, err = db.DeleteLicense(ctx, l1.ID)
require.NoError(t, err)
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
_, err = db.DeleteLicense(ctx, l0.ID)
require.NoError(t, err)
testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast)
})
}
func requestEntitlements(t *testing.T, uut features.Service) codersdk.Entitlements {
t.Helper()
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
rw := httptest.NewRecorder()
uut.EntitlementsAPI(rw, r)
resp := rw.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
dec := json.NewDecoder(resp.Body)
var result codersdk.Entitlements
err := dec.Decode(&result)
require.NoError(t, err)
return result
}
func putLicense(
ctx context.Context, t *testing.T, db database.Store,
k ed25519.PrivateKey, keyID string, userLimit int64,
timeToGrace, timeToExpire time.Duration,
) database.License {
t.Helper()
c := &Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@testing.test",
ExpiresAt: jwt.NewNumericDate(time.Now().Add(timeToExpire)),
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
},
LicenseExpires: jwt.NewNumericDate(time.Now().Add(timeToGrace)),
Version: CurrentVersion,
Features: Features{
UserLimit: userLimit,
AuditLog: 1,
},
}
j, err := makeLicense(c, k, keyID)
require.NoError(t, err)
l, err := db.InsertLicense(ctx, database.InsertLicenseParams{
UploadedAt: c.IssuedAt.Time,
JWT: j,
Exp: c.ExpiresAt.Time,
})
require.NoError(t, err)
return l
}
func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool {
return func(_ context.Context) bool {
fs.mu.RLock()
defer fs.mu.RUnlock()
return fs.entitlements.activeUsers.limit == limit
}
}
func TestFeaturesServiceGet(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// Note that these are not actually used because we don't run the syncEntitlements
// routine in this test.
pubsub := database.NewPubsubInMemory()
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := "testing"
db := databasefake.New()
t.Run("AuditorOff", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
target := struct {
Auditor agplAudit.Auditor
}{}
err := uut.Get(&target)
require.NoError(t, err)
assert.NotNil(t, target.Auditor)
nop := agplAudit.NewNop()
assert.Equal(t, reflect.ValueOf(nop).Type(), reflect.ValueOf(target.Auditor).Type())
})
t.Run("AuditorOn", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{entitled},
},
}
target := struct {
Auditor agplAudit.Auditor
}{}
err := uut.Get(&target)
require.NoError(t, err)
assert.NotNil(t, target.Auditor)
ea := audit.NewAuditor(
audit.DefaultFilter,
backends.NewPostgres(db, true),
backends.NewSlog(logger),
)
assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(target.Auditor).Type())
})
t.Run("NotPointer", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
target := struct {
Auditor agplAudit.Auditor
}{}
err := uut.Get(target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
t.Run("UnknownInterface", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
target := struct {
test testInterface
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.test)
})
t.Run("PointerToNonStruct", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
var target agplAudit.Auditor
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target)
})
t.Run("StructWithNonInterfaces", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
target := struct {
N int64
Auditor agplAudit.Auditor
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
}
type testInterface interface {
Test() error
}

View File

@ -30,7 +30,8 @@ const (
HeaderKeyID = "kid"
AccountTypeSalesforce = "salesforce"
VersionClaim = "version"
PubSubEventLicenses = "licenses"
PubsubEventLicenses = "licenses"
)
var ValidMethods = []string{"EdDSA"}
@ -41,7 +42,7 @@ var ValidMethods = []string{"EdDSA"}
//go:embed keys/2022-08-12
var key20220812 []byte
var keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)}
var Keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)}
type Features struct {
UserLimit int64 `json:"user_limit"`
@ -68,6 +69,193 @@ var (
ErrMissingLicenseExpires = xerrors.New("license missing license_expires")
)
// postLicense adds a new Enterprise license to the cluster. We allow multiple different licenses
// in the cluster at one time for several reasons:
//
// 1. Upgrades --- if the license format changes from one version of Coder to the next, during a
// rolling update you will have different Coder servers that need different licenses to function.
// 2. Avoid abrupt feature breakage --- when an admin uploads a new license with different features
// we generally don't want the old features to immediately break without warning. With a grace
// period on the license, features will continue to work from the old license until its grace
// period, then the users will get a warning allowing them to gracefully stop using the feature.
func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceLicense) {
httpapi.Forbidden(rw)
return
}
var addLicense codersdk.AddLicenseRequest
if !httpapi.Read(rw, r, &addLicense) {
return
}
claims, err := parseLicense(addLicense.License, api.Keys)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Detail: err.Error(),
})
return
}
exp, ok := claims["exp"].(float64)
if !ok {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Detail: "exp claim missing or not parsable",
})
return
}
expTime := time.Unix(int64(exp), 0)
dl, err := api.Database.InsertLicense(r.Context(), database.InsertLicenseParams{
UploadedAt: database.Now(),
JWT: addLicense.License,
Exp: expTime,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Unable to add license to database",
Detail: err.Error(),
})
return
}
err = api.updateEntitlements(r.Context())
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update entitlements",
Detail: err.Error(),
})
return
}
err = api.Pubsub.Publish(PubsubEventLicenses, []byte("add"))
if err != nil {
api.Logger.Error(context.Background(), "failed to publish license add", slog.Error(err))
// don't fail the HTTP request, since we did write it successfully to the database
}
httpapi.Write(rw, http.StatusCreated, convertLicense(dl, claims))
}
func (api *API) licenses(rw http.ResponseWriter, r *http.Request) {
licenses, err := api.Database.GetLicenses(r.Context())
if xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusOK, []codersdk.License{})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching licenses.",
Detail: err.Error(),
})
return
}
licenses, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, rbac.ActionRead, licenses)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching licenses.",
Detail: err.Error(),
})
return
}
sdkLicenses, err := convertLicenses(licenses)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error parsing licenses.",
Detail: err.Error(),
})
return
}
httpapi.Write(rw, http.StatusOK, sdkLicenses)
}
func (api *API) deleteLicense(rw http.ResponseWriter, r *http.Request) {
if !api.AGPL.Authorize(r, rbac.ActionDelete, rbac.ResourceLicense) {
httpapi.Forbidden(rw)
return
}
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 32)
if err != nil {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: "License ID must be an integer",
})
return
}
_, err = api.Database.DeleteLicense(r.Context(), int32(id))
if xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: "Unknown license ID",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error deleting license",
Detail: err.Error(),
})
return
}
err = api.updateEntitlements(r.Context())
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update entitlements",
Detail: err.Error(),
})
return
}
err = api.Pubsub.Publish(PubsubEventLicenses, []byte("delete"))
if err != nil {
api.Logger.Error(context.Background(), "failed to publish license delete", slog.Error(err))
// don't fail the HTTP request, since we did write it successfully to the database
}
rw.WriteHeader(http.StatusOK)
}
func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License {
return codersdk.License{
ID: dl.ID,
UploadedAt: dl.UploadedAt,
Claims: c,
}
}
func convertLicenses(licenses []database.License) ([]codersdk.License, error) {
var out []codersdk.License
for _, l := range licenses {
c, err := decodeClaims(l)
if err != nil {
return nil, err
}
out = append(out, convertLicense(l, c))
}
return out, nil
}
// decodeClaims decodes the JWT claims from the stored JWT. Note here we do not validate the JWT
// and just return the claims verbatim. We want to include all licenses on the GET response, even
// if they are expired, or signed by a key this version of Coder no longer considers valid.
//
// Also, we do not return the whole JWT itself because a signed JWT is a bearer token and we
// want to limit the chance of it being accidentally leaked.
func decodeClaims(l database.License) (jwt.MapClaims, error) {
parts := strings.Split(l.JWT, ".")
if len(parts) != 3 {
return nil, xerrors.Errorf("Unable to parse license %d as JWT", l.ID)
}
cb, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, xerrors.Errorf("Unable to decode license %d claims: %w", l.ID, err)
}
c := make(jwt.MapClaims)
d := json.NewDecoder(bytes.NewBuffer(cb))
d.UseNumber()
err = d.Decode(&c)
return c, err
}
// parseLicense parses the license and returns the claims. If the license's signature is invalid or
// is not parsable, an error is returned.
func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) {
@ -129,203 +317,3 @@ func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, e
return k, nil
}
}
// licenseAPI handles enterprise licenses, and attaches to the main coderd.API via the
// LicenseHandler option, so that it serves all routes under /api/v2/licenses
type licenseAPI struct {
router chi.Router
logger slog.Logger
database database.Store
pubsub database.Pubsub
auth *coderd.HTTPAuthorizer
}
func newLicenseAPI(
l slog.Logger,
db database.Store,
ps database.Pubsub,
auth *coderd.HTTPAuthorizer,
) *licenseAPI {
r := chi.NewRouter()
a := &licenseAPI{router: r, logger: l, database: db, pubsub: ps, auth: auth}
r.Post("/", a.postLicense)
r.Get("/", a.licenses)
r.Delete("/{id}", a.delete)
return a
}
func (a *licenseAPI) handler() http.Handler {
return a.router
}
// postLicense adds a new Enterprise license to the cluster. We allow multiple different licenses
// in the cluster at one time for several reasons:
//
// 1. Upgrades --- if the license format changes from one version of Coder to the next, during a
// rolling update you will have different Coder servers that need different licenses to function.
// 2. Avoid abrupt feature breakage --- when an admin uploads a new license with different features
// we generally don't want the old features to immediately break without warning. With a grace
// period on the license, features will continue to work from the old license until its grace
// period, then the users will get a warning allowing them to gracefully stop using the feature.
func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) {
if !a.auth.Authorize(r, rbac.ActionCreate, rbac.ResourceLicense) {
httpapi.Forbidden(rw)
return
}
var addLicense codersdk.AddLicenseRequest
if !httpapi.Read(rw, r, &addLicense) {
return
}
claims, err := parseLicense(addLicense.License, keys)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Detail: err.Error(),
})
return
}
exp, ok := claims["exp"].(float64)
if !ok {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Detail: "exp claim missing or not parsable",
})
return
}
expTime := time.Unix(int64(exp), 0)
dl, err := a.database.InsertLicense(r.Context(), database.InsertLicenseParams{
UploadedAt: database.Now(),
JWT: addLicense.License,
Exp: expTime,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Unable to add license to database",
Detail: err.Error(),
})
return
}
err = a.pubsub.Publish(PubSubEventLicenses, []byte("add"))
if err != nil {
a.logger.Error(context.Background(), "failed to publish license add", slog.Error(err))
// don't fail the HTTP request, since we did write it successfully to the database
}
httpapi.Write(rw, http.StatusCreated, convertLicense(dl, claims))
}
func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License {
return codersdk.License{
ID: dl.ID,
UploadedAt: dl.UploadedAt,
Claims: c,
}
}
func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) {
licenses, err := a.database.GetLicenses(r.Context())
if xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusOK, []codersdk.License{})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching licenses.",
Detail: err.Error(),
})
return
}
licenses, err = coderd.AuthorizeFilter(a.auth, r, rbac.ActionRead, licenses)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching licenses.",
Detail: err.Error(),
})
return
}
sdkLicenses, err := convertLicenses(licenses)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error parsing licenses.",
Detail: err.Error(),
})
return
}
httpapi.Write(rw, http.StatusOK, sdkLicenses)
}
func convertLicenses(licenses []database.License) ([]codersdk.License, error) {
var out []codersdk.License
for _, l := range licenses {
c, err := decodeClaims(l)
if err != nil {
return nil, err
}
out = append(out, convertLicense(l, c))
}
return out, nil
}
// decodeClaims decodes the JWT claims from the stored JWT. Note here we do not validate the JWT
// and just return the claims verbatim. We want to include all licenses on the GET response, even
// if they are expired, or signed by a key this version of Coder no longer considers valid.
//
// Also, we do not return the whole JWT itself because a signed JWT is a bearer token and we
// want to limit the chance of it being accidentally leaked.
func decodeClaims(l database.License) (jwt.MapClaims, error) {
parts := strings.Split(l.JWT, ".")
if len(parts) != 3 {
return nil, xerrors.Errorf("Unable to parse license %d as JWT", l.ID)
}
cb, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, xerrors.Errorf("Unable to decode license %d claims: %w", l.ID, err)
}
c := make(jwt.MapClaims)
d := json.NewDecoder(bytes.NewBuffer(cb))
d.UseNumber()
err = d.Decode(&c)
return c, err
}
func (a *licenseAPI) delete(rw http.ResponseWriter, r *http.Request) {
if !a.auth.Authorize(r, rbac.ActionDelete, rbac.ResourceLicense) {
httpapi.Forbidden(rw)
return
}
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 32)
if err != nil {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: "License ID must be an integer",
})
return
}
_, err = a.database.DeleteLicense(r.Context(), int32(id))
if xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: "Unknown license ID",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error deleting license",
Detail: err.Error(),
})
return
}
err = a.pubsub.Publish(PubSubEventLicenses, []byte("delete"))
if err != nil {
a.logger.Error(context.Background(), "failed to publish license delete", slog.Error(err))
// don't fail the HTTP request, since we did write it successfully to the database
}
rw.WriteHeader(http.StatusOK)
}

View File

@ -1,316 +0,0 @@
package coderd
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"net/http"
"testing"
"time"
"golang.org/x/xerrors"
"github.com/stretchr/testify/assert"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
// these tests patch the map of license keys, so cannot be run in parallel
// nolint:paralleltest
func TestPostLicense(t *testing.T) {
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := "testing"
oldKeys := keys
defer func() {
t.Log("restoring keys")
keys = oldKeys
}()
keys = map[string]ed25519.PublicKey{keyID: pubKey}
t.Run("POST", func(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
claims := &Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@coder.test",
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
},
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
AccountType: AccountTypeSalesforce,
AccountID: "testing",
Version: CurrentVersion,
Features: Features{
UserLimit: 0,
AuditLog: 1,
},
}
lic, err := makeLicense(claims, privKey, keyID)
require.NoError(t, err)
respLic, err := client.AddLicense(ctx, codersdk.AddLicenseRequest{
License: lic,
})
require.NoError(t, err)
assert.GreaterOrEqual(t, respLic.ID, int32(0))
// just a couple spot checks for sanity
assert.Equal(t, claims.AccountID, respLic.Claims["account_id"])
features, ok := respLic.Claims["features"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog])
})
t.Run("POST_unathorized", func(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
claims := &Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@coder.test",
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
},
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
AccountType: AccountTypeSalesforce,
AccountID: "testing",
Version: CurrentVersion,
Features: Features{
UserLimit: 0,
AuditLog: 1,
},
}
lic, err := makeLicense(claims, privKey, keyID)
require.NoError(t, err)
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
License: lic,
})
errResp := &codersdk.Error{}
if xerrors.As(err, &errResp) {
assert.Equal(t, 401, errResp.StatusCode())
} else {
t.Error("expected to get error status 401")
}
})
t.Run("POST_corrupted", func(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
claims := &Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@coder.test",
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
},
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
AccountType: AccountTypeSalesforce,
AccountID: "testing",
Version: CurrentVersion,
Features: Features{
UserLimit: 0,
AuditLog: 1,
},
}
lic, err := makeLicense(claims, privKey, keyID)
require.NoError(t, err)
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
License: "h" + lic,
})
errResp := &codersdk.Error{}
if xerrors.As(err, &errResp) {
assert.Equal(t, 400, errResp.StatusCode())
} else {
t.Error("expected to get error status 400")
}
})
}
// these tests patch the map of license keys, so cannot be run in parallel
// nolint:paralleltest
func TestGetLicense(t *testing.T) {
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := "testing"
oldKeys := keys
defer func() {
t.Log("restoring keys")
keys = oldKeys
}()
keys = map[string]ed25519.PublicKey{keyID: pubKey}
t.Run("GET", func(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
claims := &Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@coder.test",
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
},
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
AccountType: AccountTypeSalesforce,
AccountID: "testing",
Version: CurrentVersion,
Features: Features{
UserLimit: 0,
AuditLog: 1,
},
}
lic, err := makeLicense(claims, privKey, keyID)
require.NoError(t, err)
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
License: lic,
})
require.NoError(t, err)
// 2nd license
claims.AccountID = "testing2"
claims.Features.UserLimit = 200
lic2, err := makeLicense(claims, privKey, keyID)
require.NoError(t, err)
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
License: lic2,
})
require.NoError(t, err)
licenses, err := client.Licenses(ctx)
require.NoError(t, err)
require.Len(t, licenses, 2)
assert.Equal(t, int32(1), licenses[0].ID)
assert.Equal(t, "testing", licenses[0].Claims["account_id"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("0"),
codersdk.FeatureAuditLog: json.Number("1"),
}, licenses[0].Claims["features"])
assert.Equal(t, int32(2), licenses[1].ID)
assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("200"),
codersdk.FeatureAuditLog: json.Number("1"),
}, licenses[1].Claims["features"])
})
}
// these tests patch the map of license keys, so cannot be run in parallel
// nolint:paralleltest
func TestDeleteLicense(t *testing.T) {
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := "testing"
oldKeys := keys
defer func() {
t.Log("restoring keys")
keys = oldKeys
}()
keys = map[string]ed25519.PublicKey{keyID: pubKey}
t.Run("DELETE_empty", func(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
err := client.DeleteLicense(ctx, 1)
errResp := &codersdk.Error{}
if xerrors.As(err, &errResp) {
assert.Equal(t, 404, errResp.StatusCode())
} else {
t.Error("expected to get error status 404")
}
})
t.Run("DELETE_bad_id", func(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
resp, err := client.Request(ctx, http.MethodDelete, "/api/v2/licenses/drivers", nil)
require.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
require.NoError(t, resp.Body.Close())
})
t.Run("DELETE", func(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
claims := &Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@coder.test",
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
},
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
AccountType: AccountTypeSalesforce,
AccountID: "testing",
Version: CurrentVersion,
Features: Features{
UserLimit: 0,
AuditLog: 1,
},
}
lic, err := makeLicense(claims, privKey, keyID)
require.NoError(t, err)
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
License: lic,
})
require.NoError(t, err)
// 2nd license
claims.AccountID = "testing2"
claims.Features.UserLimit = 200
lic2, err := makeLicense(claims, privKey, keyID)
require.NoError(t, err)
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
License: lic2,
})
require.NoError(t, err)
licenses, err := client.Licenses(ctx)
require.NoError(t, err)
assert.Len(t, licenses, 2)
for _, l := range licenses {
err = client.DeleteLicense(ctx, l.ID)
require.NoError(t, err)
}
licenses, err = client.Licenses(ctx)
require.NoError(t, err)
assert.Len(t, licenses, 0)
})
}
func makeLicense(c *Claims, privateKey ed25519.PrivateKey, keyID string) (string, error) {
tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
tok.Header[HeaderKeyID] = keyID
signedTok, err := tok.SignedString(privateKey)
if err != nil {
return "", xerrors.Errorf("sign license: %w", err)
}
return signedTok, nil
}

View File

@ -0,0 +1,168 @@
package coderd_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/testutil"
)
func TestPostLicense(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountType: coderd.AccountTypeSalesforce,
AccountID: "testing",
AuditLog: true,
})
assert.GreaterOrEqual(t, respLic.ID, int32(0))
// just a couple spot checks for sanity
assert.Equal(t, "testing", respLic.Claims["account_id"])
features, ok := respLic.Claims["features"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog])
})
t.Run("Unauthorized", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
License: "content",
})
errResp := &codersdk.Error{}
if xerrors.As(err, &errResp) {
assert.Equal(t, 401, errResp.StatusCode())
} else {
t.Error("expected to get error status 401")
}
})
t.Run("Corrupted", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{})
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
License: "invalid",
})
errResp := &codersdk.Error{}
if xerrors.As(err, &errResp) {
assert.Equal(t, 400, errResp.StatusCode())
} else {
t.Error("expected to get error status 400")
}
})
}
func TestGetLicense(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountID: "testing",
AuditLog: true,
})
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountID: "testing2",
AuditLog: true,
UserLimit: 200,
})
licenses, err := client.Licenses(ctx)
require.NoError(t, err)
require.Len(t, licenses, 2)
assert.Equal(t, int32(1), licenses[0].ID)
assert.Equal(t, "testing", licenses[0].Claims["account_id"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("0"),
codersdk.FeatureAuditLog: json.Number("1"),
}, licenses[0].Claims["features"])
assert.Equal(t, int32(2), licenses[1].ID)
assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("200"),
codersdk.FeatureAuditLog: json.Number("1"),
}, licenses[1].Claims["features"])
})
}
func TestDeleteLicense(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
err := client.DeleteLicense(ctx, 1)
errResp := &codersdk.Error{}
if xerrors.As(err, &errResp) {
assert.Equal(t, 404, errResp.StatusCode())
} else {
t.Error("expected to get error status 404")
}
})
t.Run("BadID", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
resp, err := client.Request(ctx, http.MethodDelete, "/api/v2/licenses/drivers", nil)
require.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
require.NoError(t, resp.Body.Close())
})
t.Run("Success", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountID: "testing",
AuditLog: true,
})
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountID: "testing2",
AuditLog: true,
UserLimit: 200,
})
licenses, err := client.Licenses(ctx)
require.NoError(t, err)
assert.Len(t, licenses, 2)
for _, l := range licenses {
err = client.DeleteLicense(ctx, l.ID)
require.NoError(t, err)
}
licenses, err = client.Licenses(ctx)
require.NoError(t, err)
assert.Len(t, licenses, 0)
})
}

View File

@ -16,6 +16,7 @@
formatter = pkgs.nixpkgs-fmt;
devShells.default = pkgs.mkShell {
buildInputs = with pkgs; [
bash
bat
drpc.defaultPackage.${system}
exa

View File

@ -16,6 +16,22 @@ export const hardCodedCSRFCookie = (): string => {
return csrfToken
}
// defaultEntitlements has a default set of disabled functionality.
export const defaultEntitlements = (): TypesGen.Entitlements => {
const features: TypesGen.Entitlements["features"] = {}
for (const feature in Types.FeatureNames) {
features[feature] = {
enabled: false,
entitlement: "not_entitled",
}
}
return {
features: features,
has_license: false,
warnings: [],
}
}
// Always attach CSRF token to all requests.
// In puppeteer the document is undefined. In those cases, just
// do nothing.
@ -424,8 +440,15 @@ export const putWorkspaceExtension = async (
}
export const getEntitlements = async (): Promise<TypesGen.Entitlements> => {
const response = await axios.get("/api/v2/entitlements")
return response.data
try {
const response = await axios.get("/api/v2/entitlements")
return response.data
} catch (error) {
if (axios.isAxiosError(error) && error.response?.status === 404) {
return defaultEntitlements()
}
throw error
}
}
export const getAuditLogs = async (

View File

@ -84,7 +84,7 @@ export const entitlementsMachine = createMachine(
}),
},
services: {
getEntitlements: () => API.getEntitlements(),
getEntitlements: API.getEntitlements,
},
},
)