mirror of https://github.com/coder/coder.git
chore: Refactor Enterprise code to layer on top of AGPL (#4034)
* chore: Refactor Enterprise code to layer on top of AGPL This is an experiment to invert the import order of the Enterprise code to layer on top of AGPL. * Fix Garrett's comments * Add pointer.Handle to atomically obtain references This uses a context to ensure the same value persists through multiple executions to `Load()`. * Remove entitlements API from AGPL coderd * Remove AGPL Coder entitlements endpoint test * Fix warnings output * Add command-line flag to toggle audit logging * Fix hasLicense being set * Remove features interface * Fix audit logging default * Add bash as a dependency * Add comment * Add tests for resync and pubsub, and add back previous exp backoff retry * Separate authz code again * Add pointer loading example from comment * Fix duplicate test, remove pointer.Handle * Fix expired license * Add entitlements struct * Fix context passing
This commit is contained in:
parent
714c366d16
commit
db0ba8588e
|
@ -20,8 +20,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
@ -83,18 +81,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str
|
|||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
t.Cleanup(func() { require.NoError(t, <-errC) })
|
||||
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
require.Eventually(t, func() bool {
|
||||
_, err = dialer.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
return agentClient, agentToken, pubkey
|
||||
}
|
||||
|
||||
|
|
15
cli/root.go
15
cli/root.go
|
@ -91,12 +91,13 @@ func Core() []*cobra.Command {
|
|||
users(),
|
||||
versionCmd(),
|
||||
workspaceAgent(),
|
||||
features(),
|
||||
}
|
||||
}
|
||||
|
||||
func AGPL() []*cobra.Command {
|
||||
all := append(Core(), Server(coderd.New))
|
||||
all := append(Core(), Server(func(_ context.Context, o *coderd.Options) (*coderd.API, error) {
|
||||
return coderd.New(o), nil
|
||||
}))
|
||||
return all
|
||||
}
|
||||
|
||||
|
@ -548,13 +549,11 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {
|
|||
defer cancel()
|
||||
|
||||
entitlements, err := client.Entitlements(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get entitlements to show warnings: %w", err)
|
||||
if err == nil {
|
||||
for _, w := range entitlements.Warnings {
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w))
|
||||
}
|
||||
}
|
||||
for _, w := range entitlements.Warnings {
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ import (
|
|||
)
|
||||
|
||||
// nolint:gocyclo
|
||||
func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
|
||||
var (
|
||||
accessURL string
|
||||
address string
|
||||
|
@ -489,7 +489,10 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
|||
), promAddress, "prometheus")()
|
||||
}
|
||||
|
||||
coderAPI := newAPI(options)
|
||||
coderAPI, err := newAPI(ctx, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer coderAPI.Close()
|
||||
|
||||
client := codersdk.New(localURL)
|
||||
|
@ -536,7 +539,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
|||
// These errors are typically noise like "TLS: EOF". Vault does similar:
|
||||
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
|
||||
ErrorLog: log.New(io.Discard, "", 0),
|
||||
Handler: coderAPI.Handler,
|
||||
Handler: coderAPI.RootHandler,
|
||||
BaseContext: func(_ net.Listener) context.Context {
|
||||
return shutdownConnsCtx
|
||||
},
|
||||
|
|
|
@ -12,14 +12,13 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/features"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
type RequestParams struct {
|
||||
Features features.Service
|
||||
Log slog.Logger
|
||||
Audit Auditor
|
||||
Log slog.Logger
|
||||
|
||||
Request *http.Request
|
||||
Action database.AuditAction
|
||||
|
@ -102,15 +101,6 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
|
|||
params: p,
|
||||
}
|
||||
|
||||
feats := struct {
|
||||
Audit Auditor
|
||||
}{}
|
||||
err := p.Features.Get(&feats)
|
||||
if err != nil {
|
||||
p.Log.Error(p.Request.Context(), "unable to get auditor interface", slog.Error(err))
|
||||
return req, func() {}
|
||||
}
|
||||
|
||||
return req, func() {
|
||||
ctx := context.Background()
|
||||
logCtx := p.Request.Context()
|
||||
|
@ -120,7 +110,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
|
|||
return
|
||||
}
|
||||
|
||||
diff := Diff(feats.Audit, req.Old, req.New)
|
||||
diff := Diff(p.Audit, req.Old, req.New)
|
||||
diffRaw, _ := json.Marshal(diff)
|
||||
|
||||
ip, err := parseIP(p.Request.RemoteAddr)
|
||||
|
@ -128,7 +118,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
|
|||
p.Log.Warn(logCtx, "parse ip", slog.Error(err))
|
||||
}
|
||||
|
||||
err = feats.Audit.Export(ctx, database.AuditLog{
|
||||
err = p.Audit.Export(ctx, database.AuditLog{
|
||||
ID: uuid.New(),
|
||||
Time: database.Now(),
|
||||
UserID: httpmw.APIKey(p.Request).UserID,
|
||||
|
|
|
@ -43,7 +43,7 @@ type HTTPAuthorizer struct {
|
|||
// return
|
||||
// }
|
||||
func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
|
||||
return api.httpAuth.Authorize(r, action, object)
|
||||
return api.HTTPAuth.Authorize(r, action, object)
|
||||
}
|
||||
|
||||
// Authorize will return false if the user is not authorized to do the action.
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/url"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
|
@ -24,9 +25,9 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/buildinfo"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/awsidentity"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/features"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
|
@ -50,6 +51,7 @@ type Options struct {
|
|||
// CacheDir is used for caching files served by the API.
|
||||
CacheDir string
|
||||
|
||||
Auditor audit.Auditor
|
||||
AgentConnectionUpdateFrequency time.Duration
|
||||
AgentInactiveDisconnectTimeout time.Duration
|
||||
// APIRateLimit is the minutely throughput rate limit per user or ip.
|
||||
|
@ -68,8 +70,6 @@ type Options struct {
|
|||
Telemetry telemetry.Reporter
|
||||
TracerProvider trace.TracerProvider
|
||||
AutoImportTemplates []AutoImportTemplate
|
||||
LicenseHandler http.Handler
|
||||
FeaturesService features.Service
|
||||
|
||||
TailnetCoordinator *tailnet.Coordinator
|
||||
DERPMap *tailcfg.DERPMap
|
||||
|
@ -80,6 +80,9 @@ type Options struct {
|
|||
|
||||
// New constructs a Coder API handler.
|
||||
func New(options *Options) *API {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
if options.AgentConnectionUpdateFrequency == 0 {
|
||||
options.AgentConnectionUpdateFrequency = 3 * time.Second
|
||||
}
|
||||
|
@ -117,11 +120,8 @@ func New(options *Options) *API {
|
|||
if options.TailnetCoordinator == nil {
|
||||
options.TailnetCoordinator = tailnet.NewCoordinator()
|
||||
}
|
||||
if options.LicenseHandler == nil {
|
||||
options.LicenseHandler = licenses()
|
||||
}
|
||||
if options.FeaturesService == nil {
|
||||
options.FeaturesService = &featuresService{}
|
||||
if options.Auditor == nil {
|
||||
options.Auditor = audit.NewNop()
|
||||
}
|
||||
|
||||
siteCacheDir := options.CacheDir
|
||||
|
@ -142,14 +142,16 @@ func New(options *Options) *API {
|
|||
r := chi.NewRouter()
|
||||
api := &API{
|
||||
Options: options,
|
||||
Handler: r,
|
||||
RootHandler: r,
|
||||
siteHandler: site.Handler(site.FS(), binFS),
|
||||
httpAuth: &HTTPAuthorizer{
|
||||
HTTPAuth: &HTTPAuthorizer{
|
||||
Authorizer: options.Authorizer,
|
||||
Logger: options.Logger,
|
||||
},
|
||||
metricsCache: metricsCache,
|
||||
Auditor: atomic.Pointer[audit.Auditor]{},
|
||||
}
|
||||
api.Auditor.Store(&options.Auditor)
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
|
||||
oauthConfigs := &httpmw.OAuth2Configs{
|
||||
|
@ -218,6 +220,8 @@ func New(options *Options) *API {
|
|||
})
|
||||
|
||||
r.Route("/api/v2", func(r chi.Router) {
|
||||
api.APIHandler = r
|
||||
|
||||
r.NotFound(func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Route not found.",
|
||||
|
@ -473,14 +477,6 @@ func New(options *Options) *API {
|
|||
r.Get("/resources", api.workspaceBuildResources)
|
||||
r.Get("/state", api.workspaceBuildState)
|
||||
})
|
||||
r.Route("/entitlements", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/", api.FeaturesService.EntitlementsAPI)
|
||||
})
|
||||
r.Route("/licenses", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Mount("/", options.LicenseHandler)
|
||||
})
|
||||
})
|
||||
|
||||
r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP)
|
||||
|
@ -489,17 +485,20 @@ func New(options *Options) *API {
|
|||
|
||||
type API struct {
|
||||
*Options
|
||||
Auditor atomic.Pointer[audit.Auditor]
|
||||
HTTPAuth *HTTPAuthorizer
|
||||
|
||||
derpServer *derp.Server
|
||||
// APIHandler serves "/api/v2"
|
||||
APIHandler chi.Router
|
||||
// RootHandler serves "/"
|
||||
RootHandler chi.Router
|
||||
|
||||
Handler chi.Router
|
||||
derpServer *derp.Server
|
||||
metricsCache *metricscache.Cache
|
||||
siteHandler http.Handler
|
||||
websocketWaitMutex sync.Mutex
|
||||
websocketWaitGroup sync.WaitGroup
|
||||
workspaceAgentCache *wsconncache.Cache
|
||||
httpAuth *HTTPAuthorizer
|
||||
|
||||
metricsCache *metricscache.Cache
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
|
|
|
@ -38,16 +38,6 @@ func TestBuildInfo(t *testing.T) {
|
|||
require.Equal(t, buildinfo.Version(), buildInfo.Version, "version")
|
||||
}
|
||||
|
||||
// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered.
|
||||
func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
a := coderdtest.NewAuthTester(ctx, t, nil)
|
||||
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
|
||||
a.Test(ctx, assertRoute, skipRoutes)
|
||||
}
|
||||
|
||||
func TestDERP(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
|
|
|
@ -22,143 +22,6 @@ import (
|
|||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
type RouteCheck struct {
|
||||
NoAuthorize bool
|
||||
AssertAction rbac.Action
|
||||
AssertObject rbac.Object
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
type AuthTester struct {
|
||||
t *testing.T
|
||||
api *coderd.API
|
||||
authorizer *recordingAuthorizer
|
||||
|
||||
Client *codersdk.Client
|
||||
Workspace codersdk.Workspace
|
||||
Organization codersdk.Organization
|
||||
Admin codersdk.CreateFirstUserResponse
|
||||
Template codersdk.Template
|
||||
Version codersdk.TemplateVersion
|
||||
WorkspaceResource codersdk.WorkspaceResource
|
||||
File codersdk.UploadResponse
|
||||
TemplateVersionDryRun codersdk.ProvisionerJob
|
||||
TemplateParam codersdk.Parameter
|
||||
URLParams map[string]string
|
||||
}
|
||||
|
||||
func NewAuthTester(ctx context.Context, t *testing.T, options *Options) *AuthTester {
|
||||
authorizer := &recordingAuthorizer{}
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
if options.Authorizer != nil {
|
||||
t.Error("NewAuthTester cannot be called with custom Authorizer")
|
||||
}
|
||||
options.Authorizer = authorizer
|
||||
options.IncludeProvisionerDaemon = true
|
||||
|
||||
client, _, api := newWithAPI(t, options)
|
||||
admin := CreateFirstUser(t, client)
|
||||
// The provisioner will call to coderd and register itself. This is async,
|
||||
// so we wait for it to occur.
|
||||
require.Eventually(t, func() bool {
|
||||
provisionerds, err := client.ProvisionerDaemons(ctx)
|
||||
return assert.NoError(t, err) && len(provisionerds) > 0
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
|
||||
provisionerds, err := client.ProvisionerDaemons(ctx)
|
||||
require.NoError(t, err, "fetch provisioners")
|
||||
require.Len(t, provisionerds, 1)
|
||||
|
||||
organization, err := client.Organization(ctx, admin.OrganizationID)
|
||||
require.NoError(t, err, "fetch org")
|
||||
|
||||
// Setup some data in the database.
|
||||
version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
// Return a workspace resource
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agents: []*proto.Agent{{
|
||||
Name: "agent",
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Apps: []*proto.App{{
|
||||
Name: "testapp",
|
||||
Url: "http://localhost:3000",
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
AwaitTemplateVersionJob(t, client, version.ID)
|
||||
template := CreateTemplate(t, client, admin.OrganizationID, version.ID)
|
||||
workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID)
|
||||
AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024))
|
||||
require.NoError(t, err, "upload file")
|
||||
workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err, "workspace resources")
|
||||
templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{
|
||||
ParameterValues: []codersdk.CreateParameterRequest{},
|
||||
})
|
||||
require.NoError(t, err, "template version dry-run")
|
||||
|
||||
templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{
|
||||
Name: "test-param",
|
||||
SourceValue: "hello world",
|
||||
SourceScheme: codersdk.ParameterSourceSchemeData,
|
||||
DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable,
|
||||
})
|
||||
require.NoError(t, err, "create template param")
|
||||
|
||||
urlParameters := map[string]string{
|
||||
"{organization}": admin.OrganizationID.String(),
|
||||
"{user}": admin.UserID.String(),
|
||||
"{organizationname}": organization.Name,
|
||||
"{workspace}": workspace.ID.String(),
|
||||
"{workspacebuild}": workspace.LatestBuild.ID.String(),
|
||||
"{workspacename}": workspace.Name,
|
||||
"{workspaceagent}": workspaceResources[0].Agents[0].ID.String(),
|
||||
"{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10),
|
||||
"{template}": template.ID.String(),
|
||||
"{hash}": file.Hash,
|
||||
"{workspaceresource}": workspaceResources[0].ID.String(),
|
||||
"{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name,
|
||||
"{templateversion}": version.ID.String(),
|
||||
"{jobID}": templateVersionDryRun.ID.String(),
|
||||
"{templatename}": template.Name,
|
||||
"{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name,
|
||||
// Only checking template scoped params here
|
||||
"parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s",
|
||||
string(templateParam.Scope), templateParam.ScopeID.String()),
|
||||
}
|
||||
|
||||
return &AuthTester{
|
||||
t: t,
|
||||
api: api,
|
||||
authorizer: authorizer,
|
||||
Client: client,
|
||||
Workspace: workspace,
|
||||
Organization: organization,
|
||||
Admin: admin,
|
||||
Template: template,
|
||||
Version: version,
|
||||
WorkspaceResource: workspaceResources[0],
|
||||
File: file,
|
||||
TemplateVersionDryRun: templateVersionDryRun,
|
||||
TemplateParam: templateParam,
|
||||
URLParams: urlParameters,
|
||||
}
|
||||
}
|
||||
|
||||
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
// Some quick reused objects
|
||||
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||
|
@ -181,7 +44,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
"POST:/api/v2/users/login": {NoAuthorize: true},
|
||||
"GET:/api/v2/users/authmethods": {NoAuthorize: true},
|
||||
"POST:/api/v2/csp/reports": {NoAuthorize: true},
|
||||
"GET:/api/v2/entitlements": {NoAuthorize: true},
|
||||
|
||||
// Has it's own auth
|
||||
"GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true},
|
||||
|
@ -408,6 +270,134 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
return skipRoutes, assertRoute
|
||||
}
|
||||
|
||||
type RouteCheck struct {
|
||||
NoAuthorize bool
|
||||
AssertAction rbac.Action
|
||||
AssertObject rbac.Object
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
type AuthTester struct {
|
||||
t *testing.T
|
||||
api *coderd.API
|
||||
authorizer *RecordingAuthorizer
|
||||
|
||||
Client *codersdk.Client
|
||||
Workspace codersdk.Workspace
|
||||
Organization codersdk.Organization
|
||||
Admin codersdk.CreateFirstUserResponse
|
||||
Template codersdk.Template
|
||||
Version codersdk.TemplateVersion
|
||||
WorkspaceResource codersdk.WorkspaceResource
|
||||
File codersdk.UploadResponse
|
||||
TemplateVersionDryRun codersdk.ProvisionerJob
|
||||
TemplateParam codersdk.Parameter
|
||||
URLParams map[string]string
|
||||
}
|
||||
|
||||
func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, api *coderd.API, admin codersdk.CreateFirstUserResponse) *AuthTester {
|
||||
authorizer, ok := api.Authorizer.(*RecordingAuthorizer)
|
||||
if !ok {
|
||||
t.Fail()
|
||||
}
|
||||
// The provisioner will call to coderd and register itself. This is async,
|
||||
// so we wait for it to occur.
|
||||
require.Eventually(t, func() bool {
|
||||
provisionerds, err := client.ProvisionerDaemons(ctx)
|
||||
return assert.NoError(t, err) && len(provisionerds) > 0
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
|
||||
provisionerds, err := client.ProvisionerDaemons(ctx)
|
||||
require.NoError(t, err, "fetch provisioners")
|
||||
require.Len(t, provisionerds, 1)
|
||||
|
||||
organization, err := client.Organization(ctx, admin.OrganizationID)
|
||||
require.NoError(t, err, "fetch org")
|
||||
|
||||
// Setup some data in the database.
|
||||
version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
// Return a workspace resource
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agents: []*proto.Agent{{
|
||||
Name: "agent",
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Apps: []*proto.App{{
|
||||
Name: "testapp",
|
||||
Url: "http://localhost:3000",
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
AwaitTemplateVersionJob(t, client, version.ID)
|
||||
template := CreateTemplate(t, client, admin.OrganizationID, version.ID)
|
||||
workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID)
|
||||
AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024))
|
||||
require.NoError(t, err, "upload file")
|
||||
workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err, "workspace resources")
|
||||
templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{
|
||||
ParameterValues: []codersdk.CreateParameterRequest{},
|
||||
})
|
||||
require.NoError(t, err, "template version dry-run")
|
||||
|
||||
templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{
|
||||
Name: "test-param",
|
||||
SourceValue: "hello world",
|
||||
SourceScheme: codersdk.ParameterSourceSchemeData,
|
||||
DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable,
|
||||
})
|
||||
require.NoError(t, err, "create template param")
|
||||
urlParameters := map[string]string{
|
||||
"{organization}": admin.OrganizationID.String(),
|
||||
"{user}": admin.UserID.String(),
|
||||
"{organizationname}": organization.Name,
|
||||
"{workspace}": workspace.ID.String(),
|
||||
"{workspacebuild}": workspace.LatestBuild.ID.String(),
|
||||
"{workspacename}": workspace.Name,
|
||||
"{workspaceagent}": workspaceResources[0].Agents[0].ID.String(),
|
||||
"{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10),
|
||||
"{template}": template.ID.String(),
|
||||
"{hash}": file.Hash,
|
||||
"{workspaceresource}": workspaceResources[0].ID.String(),
|
||||
"{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name,
|
||||
"{templateversion}": version.ID.String(),
|
||||
"{jobID}": templateVersionDryRun.ID.String(),
|
||||
"{templatename}": template.Name,
|
||||
"{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name,
|
||||
// Only checking template scoped params here
|
||||
"parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s",
|
||||
string(templateParam.Scope), templateParam.ScopeID.String()),
|
||||
}
|
||||
|
||||
return &AuthTester{
|
||||
t: t,
|
||||
api: api,
|
||||
authorizer: authorizer,
|
||||
Client: client,
|
||||
Workspace: workspace,
|
||||
Organization: organization,
|
||||
Admin: admin,
|
||||
Template: template,
|
||||
Version: version,
|
||||
WorkspaceResource: workspaceResources[0],
|
||||
File: file,
|
||||
TemplateVersionDryRun: templateVersionDryRun,
|
||||
TemplateParam: templateParam,
|
||||
URLParams: urlParameters,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) {
|
||||
// Always fail auth from this point forward
|
||||
a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil)
|
||||
|
@ -433,7 +423,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
|
|||
}
|
||||
|
||||
err := chi.Walk(
|
||||
a.api.Handler,
|
||||
a.api.RootHandler,
|
||||
func(
|
||||
method string,
|
||||
route string,
|
||||
|
@ -513,14 +503,14 @@ type authCall struct {
|
|||
Object rbac.Object
|
||||
}
|
||||
|
||||
type recordingAuthorizer struct {
|
||||
type RecordingAuthorizer struct {
|
||||
Called *authCall
|
||||
AlwaysReturn error
|
||||
}
|
||||
|
||||
var _ rbac.Authorizer = (*recordingAuthorizer)(nil)
|
||||
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
|
||||
|
||||
func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
|
||||
func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
|
||||
r.Called = &authCall{
|
||||
SubjectID: subjectID,
|
||||
Roles: roleNames,
|
||||
|
@ -531,7 +521,7 @@ func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro
|
|||
return r.AlwaysReturn
|
||||
}
|
||||
|
||||
func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
|
||||
func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
|
||||
return &fakePreparedAuthorizer{
|
||||
Original: r,
|
||||
SubjectID: subjectID,
|
||||
|
@ -541,12 +531,12 @@ func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID str
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (r *recordingAuthorizer) reset() {
|
||||
func (r *RecordingAuthorizer) reset() {
|
||||
r.Called = nil
|
||||
}
|
||||
|
||||
type fakePreparedAuthorizer struct {
|
||||
Original *recordingAuthorizer
|
||||
Original *RecordingAuthorizer
|
||||
SubjectID string
|
||||
Roles []string
|
||||
Scope rbac.Scope
|
|
@ -0,0 +1,20 @@
|
|||
package coderdtest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
)
|
||||
|
||||
func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
Authorizer: &coderdtest.RecordingAuthorizer{},
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
admin := coderdtest.CreateFirstUser(t, client)
|
||||
a := coderdtest.NewAuthTester(context.Background(), t, client, api, admin)
|
||||
skipRoute, assertRoute := coderdtest.AGPLRoutes(a)
|
||||
a.Test(context.Background(), assertRoute, skipRoute)
|
||||
}
|
|
@ -80,7 +80,6 @@ type Options struct {
|
|||
|
||||
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
|
||||
IncludeProvisionerDaemon bool
|
||||
APIBuilder func(*coderd.Options) *coderd.API
|
||||
MetricsCacheRefreshInterval time.Duration
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
}
|
||||
|
@ -112,14 +111,11 @@ func NewWithProvisionerCloser(t *testing.T, options *Options) (*codersdk.Client,
|
|||
// and is a temporary measure while the API to register provisioners is ironed
|
||||
// out.
|
||||
func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) {
|
||||
client, closer, _ := newWithAPI(t, options)
|
||||
client, closer, _ := NewWithAPI(t, options)
|
||||
return client, closer
|
||||
}
|
||||
|
||||
// newWithAPI constructs an in-memory API instance and returns a client to talk to it.
|
||||
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
|
||||
// Do not expose the API or wrath shall descend upon thee.
|
||||
func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
|
||||
func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
|
@ -140,9 +136,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
|
|||
close(options.AutobuildStats)
|
||||
})
|
||||
}
|
||||
if options.APIBuilder == nil {
|
||||
options.APIBuilder = coderd.New
|
||||
}
|
||||
|
||||
// This can be hotswapped for a live database instance.
|
||||
db := databasefake.New()
|
||||
|
@ -166,8 +159,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
|
|||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer t.Cleanup(cancelFunc) // Defer to ensure cancelFunc is executed first.
|
||||
|
||||
lifecycleExecutor := executor.New(
|
||||
ctx,
|
||||
db,
|
||||
|
@ -201,13 +192,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
|
|||
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
|
||||
}
|
||||
|
||||
features := coderd.DisabledImplementations
|
||||
if options.Auditor != nil {
|
||||
features.Auditor = options.Auditor
|
||||
}
|
||||
|
||||
// We set the handler after server creation for the access URL.
|
||||
coderAPI := options.APIBuilder(&coderd.Options{
|
||||
return srv, cancelFunc, &coderd.Options{
|
||||
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
|
||||
// Force a long disconnection timeout to ensure
|
||||
// agents are not marked as disconnected during slow tests.
|
||||
|
@ -218,6 +203,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
|
|||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
|
||||
Auditor: options.Auditor,
|
||||
AWSCertificates: options.AWSCertificates,
|
||||
AzureCertificates: options.AzureCertificates,
|
||||
GithubOAuth2Config: options.GithubOAuth2Config,
|
||||
|
@ -248,22 +234,30 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
|
|||
AutoImportTemplates: options.AutoImportTemplates,
|
||||
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
|
||||
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
|
||||
FeaturesService: coderd.NewMockFeaturesService(features),
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = coderAPI.Close()
|
||||
})
|
||||
srv.Config.Handler = coderAPI.Handler
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithAPI constructs an in-memory API instance and returns a client to talk to it.
|
||||
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
|
||||
// Do not expose the API or wrath shall descend upon thee.
|
||||
func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
srv, cancelFunc, newOptions := NewOptions(t, options)
|
||||
// We set the handler after server creation for the access URL.
|
||||
coderAPI := coderd.New(newOptions)
|
||||
srv.Config.Handler = coderAPI.RootHandler
|
||||
var provisionerCloser io.Closer = nopcloser{}
|
||||
if options.IncludeProvisionerDaemon {
|
||||
provisionerCloser = NewProvisionerDaemon(t, coderAPI)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cancelFunc()
|
||||
_ = provisionerCloser.Close()
|
||||
_ = coderAPI.Close()
|
||||
})
|
||||
|
||||
return codersdk.New(serverURL), provisionerCloser, coderAPI
|
||||
return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
|
||||
}
|
||||
|
||||
// NewProvisionerDaemon launches a provisionerd instance configured to work
|
||||
|
|
|
@ -1,97 +0,0 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/features"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func NewMockFeaturesService(feats FeatureInterfaces) features.Service {
|
||||
return &featuresService{
|
||||
feats: &feats,
|
||||
}
|
||||
}
|
||||
|
||||
type featuresService struct {
|
||||
feats *FeatureInterfaces
|
||||
}
|
||||
|
||||
func (*featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) {
|
||||
feats := make(map[string]codersdk.Feature)
|
||||
for _, f := range codersdk.FeatureNames {
|
||||
feats[f] = codersdk.Feature{
|
||||
Entitlement: codersdk.EntitlementNotEntitled,
|
||||
Enabled: false,
|
||||
}
|
||||
}
|
||||
httpapi.Write(rw, http.StatusOK, codersdk.Entitlements{
|
||||
Features: feats,
|
||||
Warnings: []string{},
|
||||
HasLicense: false,
|
||||
})
|
||||
}
|
||||
|
||||
// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a
|
||||
// struct type containing feature interfaces as fields. The AGPL featureService always returns the
|
||||
// "disabled" version of the feature interface because it doesn't include any enterprise features
|
||||
// by definition.
|
||||
func (f *featuresService) Get(ps any) error {
|
||||
if reflect.TypeOf(ps).Kind() != reflect.Pointer {
|
||||
return xerrors.New("input must be pointer to struct")
|
||||
}
|
||||
vs := reflect.ValueOf(ps).Elem()
|
||||
if vs.Kind() != reflect.Struct {
|
||||
return xerrors.New("input must be pointer to struct")
|
||||
}
|
||||
for i := 0; i < vs.NumField(); i++ {
|
||||
vf := vs.Field(i)
|
||||
tf := vf.Type()
|
||||
if tf.Kind() != reflect.Interface {
|
||||
return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String())
|
||||
}
|
||||
err := f.setImplementation(vf, tf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setImplementation finds the correct implementation for the field's type, and sets it on the
|
||||
// struct. It returns an error if unsuccessful
|
||||
func (f *featuresService) setImplementation(vf reflect.Value, tf reflect.Type) error {
|
||||
feats := f.feats
|
||||
if feats == nil {
|
||||
feats = &DisabledImplementations
|
||||
}
|
||||
|
||||
// when we get more than a few features it might make sense to have a data structure for finding
|
||||
// the correct implementation that's faster than just a linear search, but for now just spin
|
||||
// through the implementations we have.
|
||||
vd := reflect.ValueOf(*feats)
|
||||
for j := 0; j < vd.NumField(); j++ {
|
||||
vdf := vd.Field(j)
|
||||
if vdf.Type() == tf {
|
||||
vf.Set(vdf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return xerrors.Errorf("unable to find implementation of interface %s", tf.String())
|
||||
}
|
||||
|
||||
// FeatureInterfaces contains a field for each interface controlled by an enterprise feature.
|
||||
type FeatureInterfaces struct {
|
||||
Auditor audit.Auditor
|
||||
}
|
||||
|
||||
// DisabledImplementations includes all the implementations of turned-off features. There are no
|
||||
// turned-on implementations in AGPL code.
|
||||
var DisabledImplementations = FeatureInterfaces{
|
||||
Auditor: audit.NewNop(),
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package features
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Service is the interface for interacting with enterprise features.
|
||||
type Service interface {
|
||||
EntitlementsAPI(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a
|
||||
// struct type containing feature interfaces as fields. The FeatureService sets all fields to
|
||||
// the correct implementations depending on whether the features are turned on.
|
||||
Get(s any) error
|
||||
}
|
|
@ -1,100 +0,0 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func TestEntitlements(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("GET", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
(&featuresService{}).EntitlementsAPI(rw, r)
|
||||
resp := rw.Result()
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
dec := json.NewDecoder(resp.Body)
|
||||
var result codersdk.Entitlements
|
||||
err := dec.Decode(&result)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result.HasLicense)
|
||||
assert.Empty(t, result.Warnings)
|
||||
for _, f := range codersdk.FeatureNames {
|
||||
require.Contains(t, result.Features, f)
|
||||
fe := result.Features[f]
|
||||
assert.False(t, fe.Enabled)
|
||||
assert.Equal(t, codersdk.EntitlementNotEntitled, fe.Entitlement)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFeaturesServiceGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Auditor", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := featuresService{}
|
||||
target := struct {
|
||||
Auditor audit.Auditor
|
||||
}{}
|
||||
err := uut.Get(&target)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, target.Auditor)
|
||||
})
|
||||
|
||||
t.Run("NotPointer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := featuresService{}
|
||||
target := struct {
|
||||
Auditor audit.Auditor
|
||||
}{}
|
||||
err := uut.Get(target)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, target.Auditor)
|
||||
})
|
||||
|
||||
t.Run("UnknownInterface", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := featuresService{}
|
||||
target := struct {
|
||||
test testInterface
|
||||
}{}
|
||||
err := uut.Get(&target)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, target.test)
|
||||
})
|
||||
|
||||
t.Run("PointerToNonStruct", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := featuresService{}
|
||||
var target audit.Auditor
|
||||
err := uut.Get(&target)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, target)
|
||||
})
|
||||
|
||||
t.Run("StructWithNonInterfaces", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := featuresService{}
|
||||
target := struct {
|
||||
N int64
|
||||
Auditor audit.Auditor
|
||||
}{}
|
||||
err := uut.Get(&target)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, target.Auditor)
|
||||
})
|
||||
}
|
||||
|
||||
type testInterface interface {
|
||||
Test() error
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func licenses() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.NotFound(unsupported)
|
||||
return r
|
||||
}
|
||||
|
||||
func unsupported(rw http.ResponseWriter, _ *http.Request) {
|
||||
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Unsupported",
|
||||
Detail: "These endpoints are not supported in AGPL-licensed Coder",
|
||||
Validations: nil,
|
||||
})
|
||||
}
|
|
@ -48,7 +48,7 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
|
|||
if daemons == nil {
|
||||
daemons = []database.ProvisionerDaemon{}
|
||||
}
|
||||
daemons, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, daemons)
|
||||
daemons, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, daemons)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner daemons.",
|
||||
|
|
|
@ -41,7 +41,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) {
|
|||
api := New(&opts)
|
||||
defer api.Close()
|
||||
|
||||
server := httptest.NewServer(api.Handler)
|
||||
server := httptest.NewServer(api.RootHandler)
|
||||
defer server.Close()
|
||||
userID := uuid.New()
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
|
|
|
@ -85,11 +85,12 @@ func (api *API) template(rw http.ResponseWriter, r *http.Request) {
|
|||
func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
template = httpmw.TemplateParam(r)
|
||||
auditor = *api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionDelete,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionDelete,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -139,17 +140,18 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
|
|||
createTemplate codersdk.CreateTemplateRequest
|
||||
organization = httpmw.OrganizationParam(r)
|
||||
apiKey = httpmw.APIKey(r)
|
||||
auditor = *api.Auditor.Load()
|
||||
templateAudit, commitTemplateAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
})
|
||||
templateVersionAudit, commitTemplateVersionAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitTemplateAudit()
|
||||
|
@ -340,7 +342,7 @@ func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
// Filter templates based on rbac permissions
|
||||
templates, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, templates)
|
||||
templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching templates.",
|
||||
|
@ -435,11 +437,12 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re
|
|||
func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
template = httpmw.TemplateParam(r)
|
||||
auditor = *api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
|
|
@ -559,11 +559,12 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) {
|
|||
func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
template = httpmw.TemplateParam(r)
|
||||
auditor = *api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -631,11 +632,12 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
|
|||
var (
|
||||
apiKey = httpmw.APIKey(r)
|
||||
organization = httpmw.OrganizationParam(r)
|
||||
auditor = *api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
})
|
||||
|
||||
req codersdk.CreateTemplateVersionRequest
|
||||
|
|
|
@ -220,7 +220,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
users, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, users)
|
||||
users, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, users)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching users.",
|
||||
|
@ -255,11 +255,12 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Creates a new user.
|
||||
func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
||||
auditor := *api.Auditor.Load()
|
||||
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
})
|
||||
defer commitAudit()
|
||||
|
||||
|
@ -339,12 +340,13 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (api *API) deleteUser(rw http.ResponseWriter, r *http.Request) {
|
||||
auditor := *api.Auditor.Load()
|
||||
user := httpmw.UserParam(r)
|
||||
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionDelete,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionDelete,
|
||||
})
|
||||
aReq.Old = user
|
||||
defer commitAudit()
|
||||
|
@ -414,11 +416,12 @@ func (api *API) userByName(rw http.ResponseWriter, r *http.Request) {
|
|||
func (api *API) putUserProfile(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
user = httpmw.UserParam(r)
|
||||
auditor = *api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -494,11 +497,12 @@ func (api *API) putUserStatus(status database.UserStatus) func(rw http.ResponseW
|
|||
var (
|
||||
user = httpmw.UserParam(r)
|
||||
apiKey = httpmw.APIKey(r)
|
||||
auditor = *api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -560,11 +564,12 @@ func (api *API) putUserPassword(rw http.ResponseWriter, r *http.Request) {
|
|||
var (
|
||||
user = httpmw.UserParam(r)
|
||||
params codersdk.UpdateUserPasswordRequest
|
||||
auditor = *api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -673,7 +678,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Only include ones we can read from RBAC.
|
||||
memberships, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, memberships)
|
||||
memberships, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, memberships)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching memberships.",
|
||||
|
@ -698,11 +703,12 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) {
|
|||
user = httpmw.UserParam(r)
|
||||
actorRoles = httpmw.UserAuthorization(r)
|
||||
apiKey = httpmw.APIKey(r)
|
||||
auditor = *api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -812,7 +818,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Only return orgs the user can read.
|
||||
organizations, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, organizations)
|
||||
organizations, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, organizations)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching organizations.",
|
||||
|
@ -1176,9 +1182,9 @@ func (api *API) createUser(ctx context.Context, store database.Store, req create
|
|||
func (api *API) setAuthCookie(rw http.ResponseWriter, cookie *http.Cookie) {
|
||||
http.SetCookie(rw, cookie)
|
||||
|
||||
devurlCookie := api.applicationCookie(cookie)
|
||||
if devurlCookie != nil {
|
||||
http.SetCookie(rw, devurlCookie)
|
||||
appCookie := api.applicationCookie(cookie)
|
||||
if appCookie != nil {
|
||||
http.SetCookie(rw, appCookie)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,11 @@ func TestFirstUser(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{})
|
||||
has, err := client.HasFirstUser(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.False(t, has)
|
||||
|
||||
_, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
|
|
|
@ -119,7 +119,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Only return workspaces the user can read
|
||||
workspaces, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, workspaces)
|
||||
workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspaces.",
|
||||
|
@ -217,11 +217,12 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
|
|||
var (
|
||||
organization = httpmw.OrganizationParam(r)
|
||||
apiKey = httpmw.APIKey(r)
|
||||
auditor = api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
Audit: *auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -480,11 +481,12 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
|
|||
func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
workspace = httpmw.WorkspaceParam(r)
|
||||
auditor = api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: *auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -556,11 +558,12 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
|
|||
func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
workspace = httpmw.WorkspaceParam(r)
|
||||
auditor = api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: *auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
@ -616,11 +619,12 @@ func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) {
|
|||
func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
workspace = httpmw.WorkspaceParam(r)
|
||||
auditor = api.Auditor.Load()
|
||||
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
|
||||
Features: api.FeaturesService,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
Audit: *auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
})
|
||||
)
|
||||
defer commitAudit()
|
||||
|
|
|
@ -3,12 +3,15 @@ package cli
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
agpl "github.com/coder/coder/cli"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
@ -36,11 +39,15 @@ func featuresList() *cobra.Command {
|
|||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
client, err := agpl.CreateClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entitlements, err := client.Entitlements(cmd.Context())
|
||||
var apiError *codersdk.Error
|
||||
if errors.As(err, &apiError) && apiError.StatusCode() == http.StatusNotFound {
|
||||
return xerrors.New("You are on the AGPL licensed version of Coder that does not have Enterprise functionality!")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
|
@ -11,6 +11,8 @@ import (
|
|||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/cli"
|
||||
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
)
|
||||
|
||||
|
@ -18,9 +20,9 @@ func TestFeaturesList(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Table", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
client := coderdenttest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
cmd, root := clitest.New(t, "features", "list")
|
||||
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "features", "list")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t)
|
||||
cmd.SetIn(pty.Input())
|
||||
|
@ -36,9 +38,9 @@ func TestFeaturesList(t *testing.T) {
|
|||
t.Run("JSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
client := coderdenttest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
cmd, root := clitest.New(t, "features", "list", "-o", "json")
|
||||
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(), "features", "list", "-o", "json")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
doneChan := make(chan struct{})
|
||||
|
|
@ -23,7 +23,7 @@ import (
|
|||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/cli"
|
||||
"github.com/coder/coder/enterprise/coderd"
|
||||
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
@ -124,7 +124,7 @@ func TestLicensesAddReal(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Fails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise})
|
||||
client := coderdenttest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(),
|
||||
"licenses", "add", "-l", fakeLicenseJWT)
|
||||
|
@ -175,7 +175,7 @@ func TestLicensesListReal(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise})
|
||||
client := coderdenttest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(),
|
||||
"licenses", "list")
|
||||
|
@ -219,7 +219,7 @@ func TestLicensesDeleteReal(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: coderd.NewEnterprise})
|
||||
client := coderdenttest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
cmd, root := clitest.NewWithSubcommands(t, cli.EnterpriseSubcommands(),
|
||||
"licenses", "delete", "1")
|
||||
|
|
|
@ -4,12 +4,12 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
|
||||
agpl "github.com/coder/coder/cli"
|
||||
"github.com/coder/coder/enterprise/coderd"
|
||||
)
|
||||
|
||||
func enterpriseOnly() []*cobra.Command {
|
||||
return []*cobra.Command{
|
||||
agpl.Server(coderd.NewEnterprise),
|
||||
server(),
|
||||
features(),
|
||||
licenses(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/enterprise/coderd"
|
||||
|
||||
agpl "github.com/coder/coder/cli"
|
||||
agplcoderd "github.com/coder/coder/coderd"
|
||||
)
|
||||
|
||||
func server() *cobra.Command {
|
||||
var (
|
||||
auditLogging bool
|
||||
)
|
||||
cmd := agpl.Server(func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) {
|
||||
api, err := coderd.New(ctx, &coderd.Options{
|
||||
AuditLogging: auditLogging,
|
||||
Options: options,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return api.AGPL, nil
|
||||
})
|
||||
cliflag.BoolVarP(cmd.Flags(), &auditLogging, "audit-logging", "", "CODER_AUDIT_LOGGING", true,
|
||||
"Specifies whether audit logging is enabled.")
|
||||
|
||||
return cmd
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered.
|
||||
// these tests patch the map of license keys, so cannot be run in parallel
|
||||
// nolint:paralleltest
|
||||
func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
oldKeys := keys
|
||||
defer func() {
|
||||
t.Log("restoring keys")
|
||||
keys = oldKeys
|
||||
}()
|
||||
keys = map[string]ed25519.PublicKey{keyID: pubKey}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
a := coderdtest.NewAuthTester(ctx, t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
|
||||
// We need a license in the DB, so that when we call GET api/v2/licenses there is one in the
|
||||
// list to check authz on.
|
||||
claims := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@coder.test",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
AccountType: AccountTypeSalesforce,
|
||||
AccountID: "testing",
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: 0,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
lic, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
license, err := a.Client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
a.URLParams["licenses/{id}"] = fmt.Sprintf("licenses/%d", license.ID)
|
||||
|
||||
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
|
||||
assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: rbac.ResourceLicense,
|
||||
}
|
||||
assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{
|
||||
StatusCode: http.StatusOK,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceLicense,
|
||||
}
|
||||
assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{
|
||||
AssertAction: rbac.ActionDelete,
|
||||
AssertObject: rbac.ResourceLicense,
|
||||
}
|
||||
a.Test(ctx, assertRoute, skipRoutes)
|
||||
}
|
|
@ -2,48 +2,282 @@ package coderd
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
agplaudit "github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/audit"
|
||||
"github.com/coder/coder/enterprise/audit/backends"
|
||||
)
|
||||
|
||||
const EnvAuditLogEnable = "CODER_AUDIT_LOG_ENABLE"
|
||||
// New constructs an Enterprise coderd API instance.
|
||||
// This handler is designed to wrap the AGPL Coder code and
|
||||
// layer Enterprise functionality on top as much as possible.
|
||||
func New(ctx context.Context, options *Options) (*API, error) {
|
||||
if options.EntitlementsUpdateInterval == 0 {
|
||||
options.EntitlementsUpdateInterval = 10 * time.Minute
|
||||
}
|
||||
if options.Keys == nil {
|
||||
options.Keys = Keys
|
||||
}
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
api := &API{
|
||||
AGPL: coderd.New(options.Options),
|
||||
Options: options,
|
||||
|
||||
func NewEnterprise(options *coderd.Options) *coderd.API {
|
||||
var eOpts = *options
|
||||
if eOpts.Authorizer == nil {
|
||||
var err error
|
||||
eOpts.Authorizer, err = rbac.NewAuthorizer()
|
||||
entitlements: entitlements{
|
||||
activeUsers: codersdk.Feature{
|
||||
Entitlement: codersdk.EntitlementNotEntitled,
|
||||
Enabled: false,
|
||||
},
|
||||
auditLogs: codersdk.EntitlementNotEntitled,
|
||||
},
|
||||
cancelEntitlementsLoop: cancelFunc,
|
||||
}
|
||||
oauthConfigs := &httpmw.OAuth2Configs{
|
||||
Github: options.GithubOAuth2Config,
|
||||
OIDC: options.OIDCConfig,
|
||||
}
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
|
||||
|
||||
api.AGPL.APIHandler.Group(func(r chi.Router) {
|
||||
r.Get("/entitlements", api.serveEntitlements)
|
||||
r.Route("/licenses", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Post("/", api.postLicense)
|
||||
r.Get("/", api.licenses)
|
||||
r.Delete("/{id}", api.deleteLicense)
|
||||
})
|
||||
})
|
||||
|
||||
err := api.updateEntitlements(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update entitlements: %w", err)
|
||||
}
|
||||
go api.runEntitlementsLoop(ctx)
|
||||
|
||||
return api, nil
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
*coderd.Options
|
||||
|
||||
AuditLogging bool
|
||||
EntitlementsUpdateInterval time.Duration
|
||||
Keys map[string]ed25519.PublicKey
|
||||
}
|
||||
|
||||
type API struct {
|
||||
AGPL *coderd.API
|
||||
*Options
|
||||
|
||||
cancelEntitlementsLoop func()
|
||||
entitlementsMu sync.RWMutex
|
||||
entitlements entitlements
|
||||
}
|
||||
|
||||
type entitlements struct {
|
||||
hasLicense bool
|
||||
activeUsers codersdk.Feature
|
||||
auditLogs codersdk.Entitlement
|
||||
}
|
||||
|
||||
func (api *API) Close() error {
|
||||
api.cancelEntitlementsLoop()
|
||||
return api.AGPL.Close()
|
||||
}
|
||||
|
||||
func (api *API) updateEntitlements(ctx context.Context) error {
|
||||
licenses, err := api.Database.GetUnexpiredLicenses(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
api.entitlementsMu.Lock()
|
||||
defer api.entitlementsMu.Unlock()
|
||||
now := time.Now()
|
||||
|
||||
// Default all entitlements to be disabled.
|
||||
entitlements := entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: codersdk.Feature{
|
||||
Enabled: false,
|
||||
Entitlement: codersdk.EntitlementNotEntitled,
|
||||
},
|
||||
auditLogs: codersdk.EntitlementNotEntitled,
|
||||
}
|
||||
|
||||
// Here we loop through licenses to detect enabled features.
|
||||
for _, l := range licenses {
|
||||
claims, err := validateDBLicense(l, api.Keys)
|
||||
if err != nil {
|
||||
// This should never happen, as the unit tests would fail if the
|
||||
// default built in authorizer failed.
|
||||
panic(xerrors.Errorf("rego authorize panic: %w", err))
|
||||
api.Logger.Debug(ctx, "skipping invalid license",
|
||||
slog.F("id", l.ID), slog.Error(err))
|
||||
continue
|
||||
}
|
||||
entitlements.hasLicense = true
|
||||
entitlement := codersdk.EntitlementEntitled
|
||||
if now.After(claims.LicenseExpires.Time) {
|
||||
// if the grace period were over, the validation fails, so if we are after
|
||||
// LicenseExpires we must be in grace period.
|
||||
entitlement = codersdk.EntitlementGracePeriod
|
||||
}
|
||||
if claims.Features.UserLimit > 0 {
|
||||
entitlements.activeUsers = codersdk.Feature{
|
||||
Enabled: true,
|
||||
Entitlement: entitlement,
|
||||
}
|
||||
currentLimit := int64(0)
|
||||
if entitlements.activeUsers.Limit != nil {
|
||||
currentLimit = *entitlements.activeUsers.Limit
|
||||
}
|
||||
limit := max(currentLimit, claims.Features.UserLimit)
|
||||
entitlements.activeUsers.Limit = &limit
|
||||
}
|
||||
if claims.Features.AuditLog > 0 {
|
||||
entitlements.auditLogs = entitlement
|
||||
}
|
||||
}
|
||||
eOpts.LicenseHandler = newLicenseAPI(
|
||||
eOpts.Logger,
|
||||
eOpts.Database,
|
||||
eOpts.Pubsub,
|
||||
&coderd.HTTPAuthorizer{
|
||||
Authorizer: eOpts.Authorizer,
|
||||
Logger: eOpts.Logger,
|
||||
}).handler()
|
||||
en := Enablements{AuditLogs: true}
|
||||
auditLog := os.Getenv(EnvAuditLogEnable)
|
||||
auditLog = strings.ToLower(auditLog)
|
||||
if auditLog == "disable" || auditLog == "false" || auditLog == "0" || auditLog == "no" {
|
||||
en.AuditLogs = false
|
||||
|
||||
if entitlements.auditLogs != api.entitlements.auditLogs {
|
||||
auditor := agplaudit.NewNop()
|
||||
// A flag could be added to the options that would allow disabling
|
||||
// enhanced audit logging here!
|
||||
if entitlements.auditLogs == codersdk.EntitlementEntitled && api.AuditLogging {
|
||||
auditor = audit.NewAuditor(
|
||||
audit.DefaultFilter,
|
||||
backends.NewPostgres(api.Database, true),
|
||||
backends.NewSlog(api.Logger),
|
||||
)
|
||||
}
|
||||
api.AGPL.Auditor.Store(&auditor)
|
||||
}
|
||||
eOpts.FeaturesService = newFeaturesService(
|
||||
context.Background(),
|
||||
eOpts.Logger,
|
||||
eOpts.Database,
|
||||
eOpts.Pubsub,
|
||||
en,
|
||||
)
|
||||
return coderd.New(&eOpts)
|
||||
|
||||
api.entitlements = entitlements
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
|
||||
api.entitlementsMu.RLock()
|
||||
entitlements := api.entitlements
|
||||
api.entitlementsMu.RUnlock()
|
||||
|
||||
resp := codersdk.Entitlements{
|
||||
Features: make(map[string]codersdk.Feature),
|
||||
Warnings: make([]string, 0),
|
||||
HasLicense: entitlements.hasLicense,
|
||||
}
|
||||
|
||||
if entitlements.activeUsers.Limit != nil {
|
||||
activeUserCount, err := api.Database.GetActiveUserCount(r.Context())
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Unable to query database",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
entitlements.activeUsers.Actual = &activeUserCount
|
||||
if activeUserCount > *entitlements.activeUsers.Limit {
|
||||
resp.Warnings = append(resp.Warnings,
|
||||
fmt.Sprintf(
|
||||
"Your deployment has %d active users but is only licensed for %d.",
|
||||
activeUserCount, *entitlements.activeUsers.Limit))
|
||||
}
|
||||
}
|
||||
resp.Features[codersdk.FeatureUserLimit] = entitlements.activeUsers
|
||||
|
||||
// Audit logs
|
||||
resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{
|
||||
Entitlement: entitlements.auditLogs,
|
||||
Enabled: api.AuditLogging,
|
||||
}
|
||||
if entitlements.auditLogs == codersdk.EntitlementGracePeriod && api.AuditLogging {
|
||||
resp.Warnings = append(resp.Warnings,
|
||||
"Audit logging is enabled but your license for this feature is expired.")
|
||||
}
|
||||
|
||||
httpapi.Write(rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (api *API) runEntitlementsLoop(ctx context.Context) {
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0 // retry indefinitely
|
||||
b := backoff.WithContext(eb, ctx)
|
||||
updates := make(chan struct{}, 1)
|
||||
subscribed := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// pass
|
||||
}
|
||||
if !subscribed {
|
||||
cancel, err := api.Pubsub.Subscribe(PubsubEventLicenses, func(_ context.Context, _ []byte) {
|
||||
// don't block. If the channel is full, drop the event, as there is a resync
|
||||
// scheduled already.
|
||||
select {
|
||||
case updates <- struct{}{}:
|
||||
// pass
|
||||
default:
|
||||
// pass
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(b.NextBackOff()):
|
||||
}
|
||||
continue
|
||||
}
|
||||
// nolint: revive
|
||||
defer cancel()
|
||||
subscribed = true
|
||||
api.Logger.Debug(ctx, "successfully subscribed to pubsub")
|
||||
}
|
||||
|
||||
api.Logger.Info(ctx, "syncing licensed entitlements")
|
||||
err := api.updateEntitlements(ctx)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err))
|
||||
time.Sleep(b.NextBackOff())
|
||||
continue
|
||||
}
|
||||
b.Reset()
|
||||
api.Logger.Debug(ctx, "synced licensed entitlements")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(api.EntitlementsUpdateInterval):
|
||||
continue
|
||||
case <-updates:
|
||||
api.Logger.Debug(ctx, "got pubsub update")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func max(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -0,0 +1,204 @@
|
|||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
agplaudit "github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/audit"
|
||||
"github.com/coder/coder/enterprise/coderd"
|
||||
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestEntitlements(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NoLicense", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
res, err := client.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.False(t, res.HasLicense)
|
||||
require.Empty(t, res.Warnings)
|
||||
})
|
||||
t.Run("FullLicense", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
UserLimit: 100,
|
||||
AuditLog: true,
|
||||
})
|
||||
res, err := client.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.HasLicense)
|
||||
ul := res.Features[codersdk.FeatureUserLimit]
|
||||
assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement)
|
||||
assert.Equal(t, int64(100), *ul.Limit)
|
||||
assert.Equal(t, int64(1), *ul.Actual)
|
||||
assert.True(t, ul.Enabled)
|
||||
al := res.Features[codersdk.FeatureAuditLog]
|
||||
assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement)
|
||||
assert.True(t, al.Enabled)
|
||||
assert.Nil(t, al.Limit)
|
||||
assert.Nil(t, al.Actual)
|
||||
assert.Empty(t, res.Warnings)
|
||||
})
|
||||
t.Run("FullLicenseToNone", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
UserLimit: 100,
|
||||
AuditLog: true,
|
||||
})
|
||||
res, err := client.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.HasLicense)
|
||||
al := res.Features[codersdk.FeatureAuditLog]
|
||||
assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement)
|
||||
assert.True(t, al.Enabled)
|
||||
|
||||
err = client.DeleteLicense(context.Background(), license.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err = client.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res.HasLicense)
|
||||
al = res.Features[codersdk.FeatureAuditLog]
|
||||
assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement)
|
||||
assert.True(t, al.Enabled)
|
||||
})
|
||||
t.Run("Warnings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
for i := 0; i < 4; i++ {
|
||||
coderdtest.CreateAnotherUser(t, client, first.OrganizationID)
|
||||
}
|
||||
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
UserLimit: 4,
|
||||
AuditLog: true,
|
||||
GraceAt: time.Now().Add(-time.Second),
|
||||
})
|
||||
res, err := client.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.HasLicense)
|
||||
ul := res.Features[codersdk.FeatureUserLimit]
|
||||
assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement)
|
||||
assert.Equal(t, int64(4), *ul.Limit)
|
||||
assert.Equal(t, int64(5), *ul.Actual)
|
||||
assert.True(t, ul.Enabled)
|
||||
al := res.Features[codersdk.FeatureAuditLog]
|
||||
assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement)
|
||||
assert.True(t, al.Enabled)
|
||||
assert.Nil(t, al.Limit)
|
||||
assert.Nil(t, al.Actual)
|
||||
assert.Len(t, res.Warnings, 2)
|
||||
assert.Contains(t, res.Warnings,
|
||||
"Your deployment has 5 active users but is only licensed for 4.")
|
||||
assert.Contains(t, res.Warnings,
|
||||
"Audit logging is enabled but your license for this feature is expired.")
|
||||
})
|
||||
t.Run("Pubsub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _, api := coderdenttest.NewWithAPI(t, nil)
|
||||
entitlements, err := client.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.False(t, entitlements.HasLicense)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
Exp: database.Now().AddDate(1, 0, 0),
|
||||
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
|
||||
AuditLog: true,
|
||||
}),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = api.Pubsub.Publish(coderd.PubsubEventLicenses, []byte{})
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
entitlements, err := client.Entitlements(context.Background())
|
||||
assert.NoError(t, err)
|
||||
return entitlements.HasLicense
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
})
|
||||
t.Run("Resync", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
EntitlementsUpdateInterval: 25 * time.Millisecond,
|
||||
})
|
||||
entitlements, err := client.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.False(t, entitlements.HasLicense)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
// Valid
|
||||
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
Exp: database.Now().AddDate(1, 0, 0),
|
||||
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
|
||||
AuditLog: true,
|
||||
}),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Expired
|
||||
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
Exp: database.Now().AddDate(-1, 0, 0),
|
||||
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
|
||||
ExpiresAt: database.Now().AddDate(-1, 0, 0),
|
||||
}),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Invalid
|
||||
_, err = api.Database.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
Exp: database.Now().AddDate(1, 0, 0),
|
||||
JWT: "invalid",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
entitlements, err := client.Entitlements(context.Background())
|
||||
assert.NoError(t, err)
|
||||
return entitlements.HasLicense
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuditLogging(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Enabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _, api := coderdenttest.NewWithAPI(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
AuditLog: true,
|
||||
})
|
||||
auditor := *api.AGPL.Auditor.Load()
|
||||
ea := audit.NewAuditor(audit.DefaultFilter)
|
||||
t.Logf("%T = %T", auditor, ea)
|
||||
assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type())
|
||||
})
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _, api := coderdenttest.NewWithAPI(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
auditor := *api.AGPL.Auditor.Load()
|
||||
ea := agplaudit.NewNop()
|
||||
t.Logf("%T = %T", auditor, ea)
|
||||
assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(auditor).Type())
|
||||
})
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
package coderdenttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/coderd"
|
||||
)
|
||||
|
||||
const (
|
||||
testKeyID = "enterprise-test"
|
||||
)
|
||||
|
||||
var (
|
||||
testPrivateKey ed25519.PrivateKey
|
||||
testPublicKey ed25519.PublicKey
|
||||
)
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
testPublicKey, testPrivateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
*coderdtest.Options
|
||||
EntitlementsUpdateInterval time.Duration
|
||||
}
|
||||
|
||||
// New constructs a codersdk client connected to an in-memory Enterprise API instance.
|
||||
func New(t *testing.T, options *Options) *codersdk.Client {
|
||||
client, _, _ := NewWithAPI(t, options)
|
||||
return client
|
||||
}
|
||||
|
||||
func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
if options.Options == nil {
|
||||
options.Options = &coderdtest.Options{}
|
||||
}
|
||||
srv, cancelFunc, oop := coderdtest.NewOptions(t, options.Options)
|
||||
coderAPI, err := coderd.New(context.Background(), &coderd.Options{
|
||||
AuditLogging: true,
|
||||
Options: oop,
|
||||
EntitlementsUpdateInterval: options.EntitlementsUpdateInterval,
|
||||
Keys: map[string]ed25519.PublicKey{
|
||||
testKeyID: testPublicKey,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
srv.Config.Handler = coderAPI.AGPL.RootHandler
|
||||
var provisionerCloser io.Closer = nopcloser{}
|
||||
if options.IncludeProvisionerDaemon {
|
||||
provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cancelFunc()
|
||||
_ = provisionerCloser.Close()
|
||||
_ = coderAPI.Close()
|
||||
})
|
||||
return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
|
||||
}
|
||||
|
||||
type LicenseOptions struct {
|
||||
AccountType string
|
||||
AccountID string
|
||||
GraceAt time.Time
|
||||
ExpiresAt time.Time
|
||||
UserLimit int64
|
||||
AuditLog bool
|
||||
}
|
||||
|
||||
// AddLicense generates a new license with the options provided and inserts it.
|
||||
func AddLicense(t *testing.T, client *codersdk.Client, options LicenseOptions) codersdk.License {
|
||||
license, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
|
||||
License: GenerateLicense(t, options),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return license
|
||||
}
|
||||
|
||||
// GenerateLicense returns a signed JWT using the test key.
|
||||
func GenerateLicense(t *testing.T, options LicenseOptions) string {
|
||||
if options.ExpiresAt.IsZero() {
|
||||
options.ExpiresAt = time.Now().Add(time.Hour)
|
||||
}
|
||||
if options.GraceAt.IsZero() {
|
||||
options.GraceAt = time.Now().Add(time.Hour)
|
||||
}
|
||||
auditLog := int64(0)
|
||||
if options.AuditLog {
|
||||
auditLog = 1
|
||||
}
|
||||
c := &coderd.Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@testing.test",
|
||||
ExpiresAt: jwt.NewNumericDate(options.ExpiresAt),
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(options.GraceAt),
|
||||
AccountType: options.AccountType,
|
||||
AccountID: options.AccountID,
|
||||
Version: coderd.CurrentVersion,
|
||||
Features: coderd.Features{
|
||||
UserLimit: options.UserLimit,
|
||||
AuditLog: auditLog,
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
|
||||
tok.Header[coderd.HeaderKeyID] = testKeyID
|
||||
signedTok, err := tok.SignedString(testPrivateKey)
|
||||
require.NoError(t, err)
|
||||
return signedTok
|
||||
}
|
||||
|
||||
type nopcloser struct{}
|
||||
|
||||
func (nopcloser) Close() error { return nil }
|
|
@ -0,0 +1,51 @@
|
|||
package coderdenttest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
_ = coderdenttest.New(t, nil)
|
||||
}
|
||||
|
||||
func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _, api := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Authorizer: &coderdtest.RecordingAuthorizer{},
|
||||
IncludeProvisionerDaemon: true,
|
||||
},
|
||||
})
|
||||
admin := coderdtest.CreateFirstUser(t, client)
|
||||
license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{})
|
||||
a := coderdtest.NewAuthTester(context.Background(), t, client, api.AGPL, admin)
|
||||
a.URLParams["licenses/{id}"] = fmt.Sprintf("licenses/%d", license.ID)
|
||||
|
||||
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
|
||||
assertRoute["GET:/api/v2/entitlements"] = coderdtest.RouteCheck{
|
||||
NoAuthorize: true,
|
||||
}
|
||||
assertRoute["POST:/api/v2/licenses"] = coderdtest.RouteCheck{
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: rbac.ResourceLicense,
|
||||
}
|
||||
assertRoute["GET:/api/v2/licenses"] = coderdtest.RouteCheck{
|
||||
StatusCode: http.StatusOK,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceLicense,
|
||||
}
|
||||
assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{
|
||||
AssertAction: rbac.ActionDelete,
|
||||
AssertObject: rbac.ResourceLicense,
|
||||
}
|
||||
|
||||
a.Test(context.Background(), assertRoute, skipRoutes)
|
||||
}
|
|
@ -1,327 +0,0 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/enterprise/audit/backends"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
agpl "github.com/coder/coder/coderd"
|
||||
agplAudit "github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/features"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/audit"
|
||||
)
|
||||
|
||||
type Enablements struct {
|
||||
AuditLogs bool
|
||||
}
|
||||
|
||||
type featuresService struct {
|
||||
logger slog.Logger
|
||||
database database.Store
|
||||
pubsub database.Pubsub
|
||||
keys map[string]ed25519.PublicKey
|
||||
enablements Enablements
|
||||
resyncInterval time.Duration
|
||||
// enabledImplementations includes an "enabled" implementation of every feature. This is
|
||||
// initialized at start of day and remains static. The consequence of this is that these things
|
||||
// are hanging around using memory even if not licensed or in use, but it greatly simplifies the
|
||||
// logic because we don't have to bother creating and destroying them as entitlements change.
|
||||
// If we have a particularly memory-hungry feature in future, we might wish to reconsider this
|
||||
// choice.
|
||||
enabledImplementations agpl.FeatureInterfaces
|
||||
|
||||
mu sync.RWMutex
|
||||
entitlements entitlements
|
||||
}
|
||||
|
||||
// newFeaturesService creates a FeaturesService and starts it. It will continue running for the
|
||||
// duration of the passed ctx.
|
||||
func newFeaturesService(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
pubsub database.Pubsub,
|
||||
enablements Enablements,
|
||||
) features.Service {
|
||||
fs := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: keys,
|
||||
enablements: enablements,
|
||||
enabledImplementations: agpl.FeatureInterfaces{
|
||||
Auditor: audit.NewAuditor(
|
||||
audit.DefaultFilter,
|
||||
backends.NewPostgres(db, true),
|
||||
backends.NewSlog(logger),
|
||||
),
|
||||
},
|
||||
resyncInterval: 10 * time.Minute,
|
||||
entitlements: entitlements{
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlementLimit: entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
go fs.syncEntitlements(ctx)
|
||||
return fs
|
||||
}
|
||||
|
||||
func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
e := s.entitlements
|
||||
s.mu.RUnlock()
|
||||
|
||||
resp := codersdk.Entitlements{
|
||||
Features: make(map[string]codersdk.Feature),
|
||||
Warnings: make([]string, 0),
|
||||
HasLicense: e.hasLicense,
|
||||
}
|
||||
|
||||
// User limit
|
||||
uf := codersdk.Feature{
|
||||
Entitlement: e.activeUsers.state.toSDK(),
|
||||
Enabled: true,
|
||||
}
|
||||
if !e.activeUsers.unlimited {
|
||||
n, err := s.database.GetActiveUserCount(r.Context())
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Unable to query database",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
uf.Actual = &n
|
||||
uf.Limit = &e.activeUsers.limit
|
||||
if n > e.activeUsers.limit {
|
||||
resp.Warnings = append(resp.Warnings,
|
||||
fmt.Sprintf(
|
||||
"Your deployment has %d active users but is only licensed for %d.",
|
||||
n, e.activeUsers.limit))
|
||||
}
|
||||
}
|
||||
resp.Features[codersdk.FeatureUserLimit] = uf
|
||||
|
||||
// Audit logs
|
||||
resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{
|
||||
Entitlement: e.auditLogs.state.toSDK(),
|
||||
Enabled: s.enablements.AuditLogs,
|
||||
}
|
||||
if e.auditLogs.state == gracePeriod && s.enablements.AuditLogs {
|
||||
resp.Warnings = append(resp.Warnings,
|
||||
"Audit logging is enabled but your license for this feature is expired.")
|
||||
}
|
||||
|
||||
httpapi.Write(rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
type entitlementState int
|
||||
|
||||
const (
|
||||
notEntitled entitlementState = iota
|
||||
gracePeriod
|
||||
entitled
|
||||
)
|
||||
|
||||
type entitlementLimit struct {
|
||||
unlimited bool
|
||||
limit int64
|
||||
}
|
||||
|
||||
type entitlement struct {
|
||||
state entitlementState
|
||||
}
|
||||
|
||||
func (s entitlementState) toSDK() codersdk.Entitlement {
|
||||
switch s {
|
||||
case notEntitled:
|
||||
return codersdk.EntitlementNotEntitled
|
||||
case gracePeriod:
|
||||
return codersdk.EntitlementGracePeriod
|
||||
case entitled:
|
||||
return codersdk.EntitlementEntitled
|
||||
default:
|
||||
panic("unknown entitlementState")
|
||||
}
|
||||
}
|
||||
|
||||
type numericalEntitlement struct {
|
||||
entitlement
|
||||
entitlementLimit
|
||||
}
|
||||
|
||||
type entitlements struct {
|
||||
hasLicense bool
|
||||
activeUsers numericalEntitlement
|
||||
auditLogs entitlement
|
||||
}
|
||||
|
||||
func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) {
|
||||
licenses, err := s.database.GetUnexpiredLicenses(ctx)
|
||||
if err != nil {
|
||||
return entitlements{}, err
|
||||
}
|
||||
now := time.Now()
|
||||
e := entitlements{
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlementLimit: entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, l := range licenses {
|
||||
claims, err := validateDBLicense(l, s.keys)
|
||||
if err != nil {
|
||||
s.logger.Debug(ctx, "skipping invalid license",
|
||||
slog.F("id", l.ID), slog.Error(err))
|
||||
continue
|
||||
}
|
||||
e.hasLicense = true
|
||||
thisEntitlement := entitled
|
||||
if now.After(claims.LicenseExpires.Time) {
|
||||
// if the grace period were over, the validation fails, so if we are after
|
||||
// LicenseExpires we must be in grace period.
|
||||
thisEntitlement = gracePeriod
|
||||
}
|
||||
if claims.Features.UserLimit > 0 {
|
||||
e.activeUsers.state = thisEntitlement
|
||||
e.activeUsers.unlimited = false
|
||||
e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit)
|
||||
}
|
||||
if claims.Features.AuditLog > 0 {
|
||||
e.auditLogs.state = thisEntitlement
|
||||
}
|
||||
}
|
||||
return e, nil
|
||||
}
|
||||
|
||||
func (s *featuresService) syncEntitlements(ctx context.Context) {
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0 // retry indefinitely
|
||||
b := backoff.WithContext(eb, ctx)
|
||||
updates := make(chan struct{}, 1)
|
||||
subscribed := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// pass
|
||||
}
|
||||
if !subscribed {
|
||||
cancel, err := s.pubsub.Subscribe(PubSubEventLicenses, func(_ context.Context, _ []byte) {
|
||||
// don't block. If the channel is full, drop the event, as there is a resync
|
||||
// scheduled already.
|
||||
select {
|
||||
case updates <- struct{}{}:
|
||||
// pass
|
||||
default:
|
||||
// pass
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err))
|
||||
time.Sleep(b.NextBackOff())
|
||||
continue
|
||||
}
|
||||
// nolint: revive
|
||||
defer cancel()
|
||||
subscribed = true
|
||||
s.logger.Debug(ctx, "successfully subscribed to pubsub")
|
||||
}
|
||||
|
||||
s.logger.Info(ctx, "syncing licensed entitlements")
|
||||
ents, err := s.getEntitlements(ctx)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err))
|
||||
time.Sleep(b.NextBackOff())
|
||||
continue
|
||||
}
|
||||
b.Reset()
|
||||
|
||||
s.mu.Lock()
|
||||
s.entitlements = ents
|
||||
s.mu.Unlock()
|
||||
s.logger.Debug(ctx, "synced licensed entitlements")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(s.resyncInterval):
|
||||
continue
|
||||
case <-updates:
|
||||
s.logger.Debug(ctx, "got pubsub update")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func max(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (s *featuresService) Get(ps any) error {
|
||||
if reflect.TypeOf(ps).Kind() != reflect.Pointer {
|
||||
return xerrors.New("input must be pointer to struct")
|
||||
}
|
||||
vs := reflect.ValueOf(ps).Elem()
|
||||
if vs.Kind() != reflect.Struct {
|
||||
return xerrors.New("input must be pointer to struct")
|
||||
}
|
||||
// grab a local copy of entitlements so that we have a consistent set, but aren't keeping it
|
||||
// locked from updates while we process.
|
||||
s.mu.RLock()
|
||||
ent := s.entitlements
|
||||
s.mu.RUnlock()
|
||||
|
||||
for i := 0; i < vs.NumField(); i++ {
|
||||
vf := vs.Field(i)
|
||||
tf := vf.Type()
|
||||
if tf.Kind() != reflect.Interface {
|
||||
return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String())
|
||||
}
|
||||
|
||||
err := s.setImplementation(ent, vf, tf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *featuresService) setImplementation(ent entitlements, vf reflect.Value, tf reflect.Type) error {
|
||||
// c.f. https://stackoverflow.com/questions/7132848/how-to-get-the-reflect-type-of-an-interface
|
||||
switch tf {
|
||||
case reflect.TypeOf((*agplAudit.Auditor)(nil)).Elem():
|
||||
// Audit logging
|
||||
if !s.enablements.AuditLogs || ent.auditLogs.state == notEntitled {
|
||||
vf.Set(reflect.ValueOf(agpl.DisabledImplementations.Auditor))
|
||||
return nil
|
||||
}
|
||||
vf.Set(reflect.ValueOf(s.enabledImplementations.Auditor))
|
||||
return nil
|
||||
default:
|
||||
return xerrors.Errorf("unable to find implementation of interface %s", tf.String())
|
||||
}
|
||||
}
|
|
@ -1,545 +0,0 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
agplCoderd "github.com/coder/coder/coderd"
|
||||
agplAudit "github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/features"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/audit"
|
||||
"github.com/coder/coder/enterprise/audit/backends"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestFeaturesService_EntitlementsAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
// Note that these are not actually used because we don't run the syncEntitlements
|
||||
// routine in this test.
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
|
||||
t.Run("NoLicense", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
entitlements: entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{notEntitled},
|
||||
entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{notEntitled},
|
||||
},
|
||||
}
|
||||
result := requestEntitlements(t, uut)
|
||||
assert.False(t, result.HasLicense)
|
||||
assert.Empty(t, result.Warnings)
|
||||
assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureUserLimit].Entitlement)
|
||||
assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureAuditLog].Entitlement)
|
||||
})
|
||||
|
||||
t.Run("FullLicense", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
entitlements: entitlements{
|
||||
hasLicense: true,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{entitled},
|
||||
entitlementLimit{
|
||||
unlimited: false,
|
||||
limit: 100,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{entitled},
|
||||
},
|
||||
}
|
||||
_, err := db.InsertUser(ctx, database.InsertUserParams{
|
||||
ID: uuid.UUID{},
|
||||
Email: "",
|
||||
Username: "",
|
||||
HashedPassword: nil,
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
RBACRoles: nil,
|
||||
LoginType: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
result := requestEntitlements(t, uut)
|
||||
assert.True(t, result.HasLicense)
|
||||
ul := result.Features[codersdk.FeatureUserLimit]
|
||||
assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement)
|
||||
assert.Equal(t, int64(100), *ul.Limit)
|
||||
assert.Equal(t, int64(1), *ul.Actual)
|
||||
assert.True(t, ul.Enabled)
|
||||
al := result.Features[codersdk.FeatureAuditLog]
|
||||
assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement)
|
||||
assert.True(t, al.Enabled)
|
||||
assert.Nil(t, al.Limit)
|
||||
assert.Nil(t, al.Actual)
|
||||
assert.Empty(t, result.Warnings)
|
||||
})
|
||||
|
||||
t.Run("Warnings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
entitlements: entitlements{
|
||||
hasLicense: true,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{gracePeriod},
|
||||
entitlementLimit{
|
||||
unlimited: false,
|
||||
limit: 4,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{gracePeriod},
|
||||
},
|
||||
}
|
||||
for i := byte(0); i < 5; i++ {
|
||||
_, err := db.InsertUser(ctx, database.InsertUserParams{
|
||||
ID: uuid.UUID{i},
|
||||
Email: "",
|
||||
Username: "",
|
||||
HashedPassword: nil,
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
RBACRoles: nil,
|
||||
LoginType: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
result := requestEntitlements(t, uut)
|
||||
assert.True(t, result.HasLicense)
|
||||
ul := result.Features[codersdk.FeatureUserLimit]
|
||||
assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement)
|
||||
assert.Equal(t, int64(4), *ul.Limit)
|
||||
assert.Equal(t, int64(5), *ul.Actual)
|
||||
assert.True(t, ul.Enabled)
|
||||
al := result.Features[codersdk.FeatureAuditLog]
|
||||
assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement)
|
||||
assert.True(t, al.Enabled)
|
||||
assert.Nil(t, al.Limit)
|
||||
assert.Nil(t, al.Actual)
|
||||
assert.Len(t, result.Warnings, 2)
|
||||
assert.Contains(t, result.Warnings,
|
||||
"Your deployment has 5 active users but is only licensed for 4.")
|
||||
assert.Contains(t, result.Warnings,
|
||||
"Audit logging is enabled but your license for this feature is expired.")
|
||||
})
|
||||
}
|
||||
|
||||
func TestFeaturesServiceSyncEntitlements(t *testing.T) {
|
||||
t.Parallel()
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
|
||||
// This tests that pubsub updates work by setting the resync interval very long
|
||||
t.Run("PubSub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
resyncInterval: time.Hour, // no resyncs during test
|
||||
entitlements: entitlements{},
|
||||
}
|
||||
|
||||
_, invalidKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start of day, 3 licenses, one expired, one invalid
|
||||
_ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour)
|
||||
_ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour)
|
||||
l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour)
|
||||
|
||||
go uut.syncEntitlements(ctx)
|
||||
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
|
||||
|
||||
// New license
|
||||
l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour)
|
||||
err = pubsub.Publish(PubSubEventLicenses, []byte("add"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// User limit goes up, because 305 > 300
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast)
|
||||
|
||||
// New license with lower limit
|
||||
_ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour)
|
||||
err = pubsub.Publish(PubSubEventLicenses, []byte("add"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Need to delete the others before the limit lowers
|
||||
_, err = db.DeleteLicense(ctx, l1.ID)
|
||||
require.NoError(t, err)
|
||||
err = pubsub.Publish(PubSubEventLicenses, []byte("delete"))
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
|
||||
|
||||
_, err = db.DeleteLicense(ctx, l0.ID)
|
||||
require.NoError(t, err)
|
||||
err = pubsub.Publish(PubSubEventLicenses, []byte("delete"))
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast)
|
||||
})
|
||||
|
||||
// This tests that periodic resyncs work by setting the resync interval very fast and
|
||||
// not sending any pubsub updates.
|
||||
t.Run("Resyncs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
resyncInterval: 10 * time.Millisecond,
|
||||
entitlements: entitlements{},
|
||||
}
|
||||
|
||||
_, invalidKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start of day, 3 licenses, one expired, one invalid
|
||||
_ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour)
|
||||
_ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour)
|
||||
l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour)
|
||||
|
||||
go uut.syncEntitlements(ctx)
|
||||
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
|
||||
|
||||
// New license
|
||||
l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour)
|
||||
|
||||
// User limit goes up, because 305 > 300
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast)
|
||||
|
||||
// New license with lower limit
|
||||
_ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour)
|
||||
|
||||
// Need to delete the others before the limit lowers
|
||||
_, err = db.DeleteLicense(ctx, l1.ID)
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
|
||||
|
||||
_, err = db.DeleteLicense(ctx, l0.ID)
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast)
|
||||
})
|
||||
}
|
||||
|
||||
func requestEntitlements(t *testing.T, uut features.Service) codersdk.Entitlements {
|
||||
t.Helper()
|
||||
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
uut.EntitlementsAPI(rw, r)
|
||||
resp := rw.Result()
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
dec := json.NewDecoder(resp.Body)
|
||||
var result codersdk.Entitlements
|
||||
err := dec.Decode(&result)
|
||||
require.NoError(t, err)
|
||||
return result
|
||||
}
|
||||
|
||||
func putLicense(
|
||||
ctx context.Context, t *testing.T, db database.Store,
|
||||
k ed25519.PrivateKey, keyID string, userLimit int64,
|
||||
timeToGrace, timeToExpire time.Duration,
|
||||
) database.License {
|
||||
t.Helper()
|
||||
c := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@testing.test",
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(timeToExpire)),
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(timeToGrace)),
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: userLimit,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
j, err := makeLicense(c, k, keyID)
|
||||
require.NoError(t, err)
|
||||
l, err := db.InsertLicense(ctx, database.InsertLicenseParams{
|
||||
UploadedAt: c.IssuedAt.Time,
|
||||
JWT: j,
|
||||
Exp: c.ExpiresAt.Time,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return l
|
||||
}
|
||||
|
||||
func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool {
|
||||
return func(_ context.Context) bool {
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
return fs.entitlements.activeUsers.limit == limit
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeaturesServiceGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
// Note that these are not actually used because we don't run the syncEntitlements
|
||||
// routine in this test.
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
db := databasefake.New()
|
||||
|
||||
t.Run("AuditorOff", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
enabledImplementations: agplCoderd.FeatureInterfaces{
|
||||
Auditor: audit.NewAuditor(audit.DefaultFilter),
|
||||
},
|
||||
entitlements: entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{notEntitled},
|
||||
entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{notEntitled},
|
||||
},
|
||||
}
|
||||
target := struct {
|
||||
Auditor agplAudit.Auditor
|
||||
}{}
|
||||
err := uut.Get(&target)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, target.Auditor)
|
||||
nop := agplAudit.NewNop()
|
||||
assert.Equal(t, reflect.ValueOf(nop).Type(), reflect.ValueOf(target.Auditor).Type())
|
||||
})
|
||||
|
||||
t.Run("AuditorOn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
enabledImplementations: agplCoderd.FeatureInterfaces{
|
||||
Auditor: audit.NewAuditor(audit.DefaultFilter),
|
||||
},
|
||||
entitlements: entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{notEntitled},
|
||||
entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{entitled},
|
||||
},
|
||||
}
|
||||
target := struct {
|
||||
Auditor agplAudit.Auditor
|
||||
}{}
|
||||
err := uut.Get(&target)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, target.Auditor)
|
||||
ea := audit.NewAuditor(
|
||||
audit.DefaultFilter,
|
||||
backends.NewPostgres(db, true),
|
||||
backends.NewSlog(logger),
|
||||
)
|
||||
assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(target.Auditor).Type())
|
||||
})
|
||||
|
||||
t.Run("NotPointer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
enabledImplementations: agplCoderd.FeatureInterfaces{
|
||||
Auditor: audit.NewAuditor(audit.DefaultFilter),
|
||||
},
|
||||
entitlements: entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{notEntitled},
|
||||
entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{notEntitled},
|
||||
},
|
||||
}
|
||||
target := struct {
|
||||
Auditor agplAudit.Auditor
|
||||
}{}
|
||||
err := uut.Get(target)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, target.Auditor)
|
||||
})
|
||||
|
||||
t.Run("UnknownInterface", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
enabledImplementations: agplCoderd.FeatureInterfaces{
|
||||
Auditor: audit.NewAuditor(audit.DefaultFilter),
|
||||
},
|
||||
entitlements: entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{notEntitled},
|
||||
entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{notEntitled},
|
||||
},
|
||||
}
|
||||
target := struct {
|
||||
test testInterface
|
||||
}{}
|
||||
err := uut.Get(&target)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, target.test)
|
||||
})
|
||||
|
||||
t.Run("PointerToNonStruct", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
enabledImplementations: agplCoderd.FeatureInterfaces{
|
||||
Auditor: audit.NewAuditor(audit.DefaultFilter),
|
||||
},
|
||||
entitlements: entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{notEntitled},
|
||||
entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{notEntitled},
|
||||
},
|
||||
}
|
||||
var target agplAudit.Auditor
|
||||
err := uut.Get(&target)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, target)
|
||||
})
|
||||
|
||||
t.Run("StructWithNonInterfaces", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
enabledImplementations: agplCoderd.FeatureInterfaces{
|
||||
Auditor: audit.NewAuditor(audit.DefaultFilter),
|
||||
},
|
||||
entitlements: entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{notEntitled},
|
||||
entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{notEntitled},
|
||||
},
|
||||
}
|
||||
target := struct {
|
||||
N int64
|
||||
Auditor agplAudit.Auditor
|
||||
}{}
|
||||
err := uut.Get(&target)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, target.Auditor)
|
||||
})
|
||||
}
|
||||
|
||||
type testInterface interface {
|
||||
Test() error
|
||||
}
|
|
@ -30,7 +30,8 @@ const (
|
|||
HeaderKeyID = "kid"
|
||||
AccountTypeSalesforce = "salesforce"
|
||||
VersionClaim = "version"
|
||||
PubSubEventLicenses = "licenses"
|
||||
|
||||
PubsubEventLicenses = "licenses"
|
||||
)
|
||||
|
||||
var ValidMethods = []string{"EdDSA"}
|
||||
|
@ -41,7 +42,7 @@ var ValidMethods = []string{"EdDSA"}
|
|||
//go:embed keys/2022-08-12
|
||||
var key20220812 []byte
|
||||
|
||||
var keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)}
|
||||
var Keys = map[string]ed25519.PublicKey{"2022-08-12": ed25519.PublicKey(key20220812)}
|
||||
|
||||
type Features struct {
|
||||
UserLimit int64 `json:"user_limit"`
|
||||
|
@ -68,6 +69,193 @@ var (
|
|||
ErrMissingLicenseExpires = xerrors.New("license missing license_expires")
|
||||
)
|
||||
|
||||
// postLicense adds a new Enterprise license to the cluster. We allow multiple different licenses
|
||||
// in the cluster at one time for several reasons:
|
||||
//
|
||||
// 1. Upgrades --- if the license format changes from one version of Coder to the next, during a
|
||||
// rolling update you will have different Coder servers that need different licenses to function.
|
||||
// 2. Avoid abrupt feature breakage --- when an admin uploads a new license with different features
|
||||
// we generally don't want the old features to immediately break without warning. With a grace
|
||||
// period on the license, features will continue to work from the old license until its grace
|
||||
// period, then the users will get a warning allowing them to gracefully stop using the feature.
|
||||
func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
|
||||
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceLicense) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var addLicense codersdk.AddLicenseRequest
|
||||
if !httpapi.Read(rw, r, &addLicense) {
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := parseLicense(addLicense.License, api.Keys)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid license",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
exp, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid license",
|
||||
Detail: "exp claim missing or not parsable",
|
||||
})
|
||||
return
|
||||
}
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
|
||||
dl, err := api.Database.InsertLicense(r.Context(), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
JWT: addLicense.License,
|
||||
Exp: expTime,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Unable to add license to database",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = api.updateEntitlements(r.Context())
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update entitlements",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = api.Pubsub.Publish(PubsubEventLicenses, []byte("add"))
|
||||
if err != nil {
|
||||
api.Logger.Error(context.Background(), "failed to publish license add", slog.Error(err))
|
||||
// don't fail the HTTP request, since we did write it successfully to the database
|
||||
}
|
||||
|
||||
httpapi.Write(rw, http.StatusCreated, convertLicense(dl, claims))
|
||||
}
|
||||
|
||||
func (api *API) licenses(rw http.ResponseWriter, r *http.Request) {
|
||||
licenses, err := api.Database.GetLicenses(r.Context())
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusOK, []codersdk.License{})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
licenses, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, rbac.ActionRead, licenses)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
sdkLicenses, err := convertLicenses(licenses)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error parsing licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(rw, http.StatusOK, sdkLicenses)
|
||||
}
|
||||
|
||||
func (api *API) deleteLicense(rw http.ResponseWriter, r *http.Request) {
|
||||
if !api.AGPL.Authorize(r, rbac.ActionDelete, rbac.ResourceLicense) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 32)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "License ID must be an integer",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = api.Database.DeleteLicense(r.Context(), int32(id))
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Unknown license ID",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error deleting license",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = api.updateEntitlements(r.Context())
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update entitlements",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = api.Pubsub.Publish(PubsubEventLicenses, []byte("delete"))
|
||||
if err != nil {
|
||||
api.Logger.Error(context.Background(), "failed to publish license delete", slog.Error(err))
|
||||
// don't fail the HTTP request, since we did write it successfully to the database
|
||||
}
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License {
|
||||
return codersdk.License{
|
||||
ID: dl.ID,
|
||||
UploadedAt: dl.UploadedAt,
|
||||
Claims: c,
|
||||
}
|
||||
}
|
||||
|
||||
func convertLicenses(licenses []database.License) ([]codersdk.License, error) {
|
||||
var out []codersdk.License
|
||||
for _, l := range licenses {
|
||||
c, err := decodeClaims(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, convertLicense(l, c))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// decodeClaims decodes the JWT claims from the stored JWT. Note here we do not validate the JWT
|
||||
// and just return the claims verbatim. We want to include all licenses on the GET response, even
|
||||
// if they are expired, or signed by a key this version of Coder no longer considers valid.
|
||||
//
|
||||
// Also, we do not return the whole JWT itself because a signed JWT is a bearer token and we
|
||||
// want to limit the chance of it being accidentally leaked.
|
||||
func decodeClaims(l database.License) (jwt.MapClaims, error) {
|
||||
parts := strings.Split(l.JWT, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, xerrors.Errorf("Unable to parse license %d as JWT", l.ID)
|
||||
}
|
||||
cb, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("Unable to decode license %d claims: %w", l.ID, err)
|
||||
}
|
||||
c := make(jwt.MapClaims)
|
||||
d := json.NewDecoder(bytes.NewBuffer(cb))
|
||||
d.UseNumber()
|
||||
err = d.Decode(&c)
|
||||
return c, err
|
||||
}
|
||||
|
||||
// parseLicense parses the license and returns the claims. If the license's signature is invalid or
|
||||
// is not parsable, an error is returned.
|
||||
func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error) {
|
||||
|
@ -129,203 +317,3 @@ func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, e
|
|||
return k, nil
|
||||
}
|
||||
}
|
||||
|
||||
// licenseAPI handles enterprise licenses, and attaches to the main coderd.API via the
|
||||
// LicenseHandler option, so that it serves all routes under /api/v2/licenses
|
||||
type licenseAPI struct {
|
||||
router chi.Router
|
||||
logger slog.Logger
|
||||
database database.Store
|
||||
pubsub database.Pubsub
|
||||
auth *coderd.HTTPAuthorizer
|
||||
}
|
||||
|
||||
func newLicenseAPI(
|
||||
l slog.Logger,
|
||||
db database.Store,
|
||||
ps database.Pubsub,
|
||||
auth *coderd.HTTPAuthorizer,
|
||||
) *licenseAPI {
|
||||
r := chi.NewRouter()
|
||||
a := &licenseAPI{router: r, logger: l, database: db, pubsub: ps, auth: auth}
|
||||
r.Post("/", a.postLicense)
|
||||
r.Get("/", a.licenses)
|
||||
r.Delete("/{id}", a.delete)
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *licenseAPI) handler() http.Handler {
|
||||
return a.router
|
||||
}
|
||||
|
||||
// postLicense adds a new Enterprise license to the cluster. We allow multiple different licenses
|
||||
// in the cluster at one time for several reasons:
|
||||
//
|
||||
// 1. Upgrades --- if the license format changes from one version of Coder to the next, during a
|
||||
// rolling update you will have different Coder servers that need different licenses to function.
|
||||
// 2. Avoid abrupt feature breakage --- when an admin uploads a new license with different features
|
||||
// we generally don't want the old features to immediately break without warning. With a grace
|
||||
// period on the license, features will continue to work from the old license until its grace
|
||||
// period, then the users will get a warning allowing them to gracefully stop using the feature.
|
||||
func (a *licenseAPI) postLicense(rw http.ResponseWriter, r *http.Request) {
|
||||
if !a.auth.Authorize(r, rbac.ActionCreate, rbac.ResourceLicense) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var addLicense codersdk.AddLicenseRequest
|
||||
if !httpapi.Read(rw, r, &addLicense) {
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := parseLicense(addLicense.License, keys)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid license",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
exp, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid license",
|
||||
Detail: "exp claim missing or not parsable",
|
||||
})
|
||||
return
|
||||
}
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
|
||||
dl, err := a.database.InsertLicense(r.Context(), database.InsertLicenseParams{
|
||||
UploadedAt: database.Now(),
|
||||
JWT: addLicense.License,
|
||||
Exp: expTime,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Unable to add license to database",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = a.pubsub.Publish(PubSubEventLicenses, []byte("add"))
|
||||
if err != nil {
|
||||
a.logger.Error(context.Background(), "failed to publish license add", slog.Error(err))
|
||||
// don't fail the HTTP request, since we did write it successfully to the database
|
||||
}
|
||||
|
||||
httpapi.Write(rw, http.StatusCreated, convertLicense(dl, claims))
|
||||
}
|
||||
|
||||
func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License {
|
||||
return codersdk.License{
|
||||
ID: dl.ID,
|
||||
UploadedAt: dl.UploadedAt,
|
||||
Claims: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) {
|
||||
licenses, err := a.database.GetLicenses(r.Context())
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusOK, []codersdk.License{})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
licenses, err = coderd.AuthorizeFilter(a.auth, r, rbac.ActionRead, licenses)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
sdkLicenses, err := convertLicenses(licenses)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error parsing licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(rw, http.StatusOK, sdkLicenses)
|
||||
}
|
||||
|
||||
func convertLicenses(licenses []database.License) ([]codersdk.License, error) {
|
||||
var out []codersdk.License
|
||||
for _, l := range licenses {
|
||||
c, err := decodeClaims(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, convertLicense(l, c))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// decodeClaims decodes the JWT claims from the stored JWT. Note here we do not validate the JWT
|
||||
// and just return the claims verbatim. We want to include all licenses on the GET response, even
|
||||
// if they are expired, or signed by a key this version of Coder no longer considers valid.
|
||||
//
|
||||
// Also, we do not return the whole JWT itself because a signed JWT is a bearer token and we
|
||||
// want to limit the chance of it being accidentally leaked.
|
||||
func decodeClaims(l database.License) (jwt.MapClaims, error) {
|
||||
parts := strings.Split(l.JWT, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, xerrors.Errorf("Unable to parse license %d as JWT", l.ID)
|
||||
}
|
||||
cb, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("Unable to decode license %d claims: %w", l.ID, err)
|
||||
}
|
||||
c := make(jwt.MapClaims)
|
||||
d := json.NewDecoder(bytes.NewBuffer(cb))
|
||||
d.UseNumber()
|
||||
err = d.Decode(&c)
|
||||
return c, err
|
||||
}
|
||||
|
||||
func (a *licenseAPI) delete(rw http.ResponseWriter, r *http.Request) {
|
||||
if !a.auth.Authorize(r, rbac.ActionDelete, rbac.ResourceLicense) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 32)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "License ID must be an integer",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = a.database.DeleteLicense(r.Context(), int32(id))
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Unknown license ID",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error deleting license",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = a.pubsub.Publish(PubSubEventLicenses, []byte("delete"))
|
||||
if err != nil {
|
||||
a.logger.Error(context.Background(), "failed to publish license delete", slog.Error(err))
|
||||
// don't fail the HTTP request, since we did write it successfully to the database
|
||||
}
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
|
|
@ -1,316 +0,0 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
// these tests patch the map of license keys, so cannot be run in parallel
|
||||
// nolint:paralleltest
|
||||
func TestPostLicense(t *testing.T) {
|
||||
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
oldKeys := keys
|
||||
defer func() {
|
||||
t.Log("restoring keys")
|
||||
keys = oldKeys
|
||||
}()
|
||||
keys = map[string]ed25519.PublicKey{keyID: pubKey}
|
||||
|
||||
t.Run("POST", func(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
claims := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@coder.test",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
AccountType: AccountTypeSalesforce,
|
||||
AccountID: "testing",
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: 0,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
lic, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
respLic, err := client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, respLic.ID, int32(0))
|
||||
// just a couple spot checks for sanity
|
||||
assert.Equal(t, claims.AccountID, respLic.Claims["account_id"])
|
||||
features, ok := respLic.Claims["features"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog])
|
||||
})
|
||||
|
||||
t.Run("POST_unathorized", func(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
claims := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@coder.test",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
AccountType: AccountTypeSalesforce,
|
||||
AccountID: "testing",
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: 0,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
lic, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic,
|
||||
})
|
||||
errResp := &codersdk.Error{}
|
||||
if xerrors.As(err, &errResp) {
|
||||
assert.Equal(t, 401, errResp.StatusCode())
|
||||
} else {
|
||||
t.Error("expected to get error status 401")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("POST_corrupted", func(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
claims := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@coder.test",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
AccountType: AccountTypeSalesforce,
|
||||
AccountID: "testing",
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: 0,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
lic, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: "h" + lic,
|
||||
})
|
||||
errResp := &codersdk.Error{}
|
||||
if xerrors.As(err, &errResp) {
|
||||
assert.Equal(t, 400, errResp.StatusCode())
|
||||
} else {
|
||||
t.Error("expected to get error status 400")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// these tests patch the map of license keys, so cannot be run in parallel
|
||||
// nolint:paralleltest
|
||||
func TestGetLicense(t *testing.T) {
|
||||
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
oldKeys := keys
|
||||
defer func() {
|
||||
t.Log("restoring keys")
|
||||
keys = oldKeys
|
||||
}()
|
||||
keys = map[string]ed25519.PublicKey{keyID: pubKey}
|
||||
|
||||
t.Run("GET", func(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
claims := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@coder.test",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
AccountType: AccountTypeSalesforce,
|
||||
AccountID: "testing",
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: 0,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
lic, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2nd license
|
||||
claims.AccountID = "testing2"
|
||||
claims.Features.UserLimit = 200
|
||||
lic2, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
licenses, err := client.Licenses(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, licenses, 2)
|
||||
assert.Equal(t, int32(1), licenses[0].ID)
|
||||
assert.Equal(t, "testing", licenses[0].Claims["account_id"])
|
||||
assert.Equal(t, map[string]interface{}{
|
||||
codersdk.FeatureUserLimit: json.Number("0"),
|
||||
codersdk.FeatureAuditLog: json.Number("1"),
|
||||
}, licenses[0].Claims["features"])
|
||||
assert.Equal(t, int32(2), licenses[1].ID)
|
||||
assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
|
||||
assert.Equal(t, map[string]interface{}{
|
||||
codersdk.FeatureUserLimit: json.Number("200"),
|
||||
codersdk.FeatureAuditLog: json.Number("1"),
|
||||
}, licenses[1].Claims["features"])
|
||||
})
|
||||
}
|
||||
|
||||
// these tests patch the map of license keys, so cannot be run in parallel
|
||||
// nolint:paralleltest
|
||||
func TestDeleteLicense(t *testing.T) {
|
||||
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
oldKeys := keys
|
||||
defer func() {
|
||||
t.Log("restoring keys")
|
||||
keys = oldKeys
|
||||
}()
|
||||
keys = map[string]ed25519.PublicKey{keyID: pubKey}
|
||||
|
||||
t.Run("DELETE_empty", func(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
err := client.DeleteLicense(ctx, 1)
|
||||
errResp := &codersdk.Error{}
|
||||
if xerrors.As(err, &errResp) {
|
||||
assert.Equal(t, 404, errResp.StatusCode())
|
||||
} else {
|
||||
t.Error("expected to get error status 404")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DELETE_bad_id", func(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
resp, err := client.Request(ctx, http.MethodDelete, "/api/v2/licenses/drivers", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
})
|
||||
|
||||
t.Run("DELETE", func(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
claims := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@coder.test",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
AccountType: AccountTypeSalesforce,
|
||||
AccountID: "testing",
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: 0,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
lic, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2nd license
|
||||
claims.AccountID = "testing2"
|
||||
claims.Features.UserLimit = 200
|
||||
lic2, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
licenses, err := client.Licenses(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, licenses, 2)
|
||||
for _, l := range licenses {
|
||||
err = client.DeleteLicense(ctx, l.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
licenses, err = client.Licenses(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, licenses, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func makeLicense(c *Claims, privateKey ed25519.PrivateKey, keyID string) (string, error) {
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
|
||||
tok.Header[HeaderKeyID] = keyID
|
||||
signedTok, err := tok.SignedString(privateKey)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("sign license: %w", err)
|
||||
}
|
||||
return signedTok, nil
|
||||
}
|
|
@ -0,0 +1,168 @@
|
|||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/coderd"
|
||||
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestPostLicense(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
AccountType: coderd.AccountTypeSalesforce,
|
||||
AccountID: "testing",
|
||||
AuditLog: true,
|
||||
})
|
||||
assert.GreaterOrEqual(t, respLic.ID, int32(0))
|
||||
// just a couple spot checks for sanity
|
||||
assert.Equal(t, "testing", respLic.Claims["account_id"])
|
||||
features, ok := respLic.Claims["features"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, json.Number("1"), features[codersdk.FeatureAuditLog])
|
||||
})
|
||||
|
||||
t.Run("Unauthorized", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
|
||||
License: "content",
|
||||
})
|
||||
errResp := &codersdk.Error{}
|
||||
if xerrors.As(err, &errResp) {
|
||||
assert.Equal(t, 401, errResp.StatusCode())
|
||||
} else {
|
||||
t.Error("expected to get error status 401")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Corrupted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{})
|
||||
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
|
||||
License: "invalid",
|
||||
})
|
||||
errResp := &codersdk.Error{}
|
||||
if xerrors.As(err, &errResp) {
|
||||
assert.Equal(t, 400, errResp.StatusCode())
|
||||
} else {
|
||||
t.Error("expected to get error status 400")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetLicense(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
AccountID: "testing",
|
||||
AuditLog: true,
|
||||
})
|
||||
|
||||
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
AccountID: "testing2",
|
||||
AuditLog: true,
|
||||
UserLimit: 200,
|
||||
})
|
||||
|
||||
licenses, err := client.Licenses(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, licenses, 2)
|
||||
assert.Equal(t, int32(1), licenses[0].ID)
|
||||
assert.Equal(t, "testing", licenses[0].Claims["account_id"])
|
||||
assert.Equal(t, map[string]interface{}{
|
||||
codersdk.FeatureUserLimit: json.Number("0"),
|
||||
codersdk.FeatureAuditLog: json.Number("1"),
|
||||
}, licenses[0].Claims["features"])
|
||||
assert.Equal(t, int32(2), licenses[1].ID)
|
||||
assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
|
||||
assert.Equal(t, map[string]interface{}{
|
||||
codersdk.FeatureUserLimit: json.Number("200"),
|
||||
codersdk.FeatureAuditLog: json.Number("1"),
|
||||
}, licenses[1].Claims["features"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteLicense(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
err := client.DeleteLicense(ctx, 1)
|
||||
errResp := &codersdk.Error{}
|
||||
if xerrors.As(err, &errResp) {
|
||||
assert.Equal(t, 404, errResp.StatusCode())
|
||||
} else {
|
||||
t.Error("expected to get error status 404")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BadID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
resp, err := client.Request(ctx, http.MethodDelete, "/api/v2/licenses/drivers", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
})
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdenttest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
AccountID: "testing",
|
||||
AuditLog: true,
|
||||
})
|
||||
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
|
||||
AccountID: "testing2",
|
||||
AuditLog: true,
|
||||
UserLimit: 200,
|
||||
})
|
||||
|
||||
licenses, err := client.Licenses(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, licenses, 2)
|
||||
for _, l := range licenses {
|
||||
err = client.DeleteLicense(ctx, l.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
licenses, err = client.Licenses(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, licenses, 0)
|
||||
})
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
formatter = pkgs.nixpkgs-fmt;
|
||||
devShells.default = pkgs.mkShell {
|
||||
buildInputs = with pkgs; [
|
||||
bash
|
||||
bat
|
||||
drpc.defaultPackage.${system}
|
||||
exa
|
||||
|
|
|
@ -16,6 +16,22 @@ export const hardCodedCSRFCookie = (): string => {
|
|||
return csrfToken
|
||||
}
|
||||
|
||||
// defaultEntitlements has a default set of disabled functionality.
|
||||
export const defaultEntitlements = (): TypesGen.Entitlements => {
|
||||
const features: TypesGen.Entitlements["features"] = {}
|
||||
for (const feature in Types.FeatureNames) {
|
||||
features[feature] = {
|
||||
enabled: false,
|
||||
entitlement: "not_entitled",
|
||||
}
|
||||
}
|
||||
return {
|
||||
features: features,
|
||||
has_license: false,
|
||||
warnings: [],
|
||||
}
|
||||
}
|
||||
|
||||
// Always attach CSRF token to all requests.
|
||||
// In puppeteer the document is undefined. In those cases, just
|
||||
// do nothing.
|
||||
|
@ -424,8 +440,15 @@ export const putWorkspaceExtension = async (
|
|||
}
|
||||
|
||||
export const getEntitlements = async (): Promise<TypesGen.Entitlements> => {
|
||||
const response = await axios.get("/api/v2/entitlements")
|
||||
return response.data
|
||||
try {
|
||||
const response = await axios.get("/api/v2/entitlements")
|
||||
return response.data
|
||||
} catch (error) {
|
||||
if (axios.isAxiosError(error) && error.response?.status === 404) {
|
||||
return defaultEntitlements()
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
export const getAuditLogs = async (
|
||||
|
|
|
@ -84,7 +84,7 @@ export const entitlementsMachine = createMachine(
|
|||
}),
|
||||
},
|
||||
services: {
|
||||
getEntitlements: () => API.getEntitlements(),
|
||||
getEntitlements: API.getEntitlements,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue