mirror of https://github.com/coder/coder.git
GET license endpoint (#3651)
* GET license endpoint Signed-off-by: Spike Curtis <spike@coder.com> * SDK GetLicenses -> Licenses Signed-off-by: Spike Curtis <spike@coder.com> Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
parent
da54874958
commit
c9bce19d88
|
@ -10,12 +10,12 @@ import (
|
|||
"github.com/coder/coder/coderd/rbac"
|
||||
)
|
||||
|
||||
func AuthorizeFilter[O rbac.Objecter](api *API, r *http.Request, action rbac.Action, objects []O) ([]O, error) {
|
||||
func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) {
|
||||
roles := httpmw.AuthorizationUserRoles(r)
|
||||
objects, err := rbac.Filter(r.Context(), api.Authorizer, roles.ID.String(), roles.Roles, action, objects)
|
||||
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, action, objects)
|
||||
if err != nil {
|
||||
// Log the error as Filter should not be erroring.
|
||||
api.Logger.Error(r.Context(), "filter failed",
|
||||
h.Logger.Error(r.Context(), "filter failed",
|
||||
slog.Error(err),
|
||||
slog.F("user_id", roles.ID),
|
||||
slog.F("username", roles.Username),
|
||||
|
|
|
@ -2278,8 +2278,8 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) {
|
|||
|
||||
func (q *fakeQuerier) InsertLicense(
|
||||
_ context.Context, arg database.InsertLicenseParams) (database.License, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
l := database.License{
|
||||
ID: q.lastLicenseID + 1,
|
||||
|
@ -2292,6 +2292,15 @@ func (q *fakeQuerier) InsertLicense(
|
|||
return l, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
results := append([]database.License{}, q.licenses...)
|
||||
sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID })
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
|
|
@ -43,3 +43,7 @@ func (f File) RBACObject() rbac.Object {
|
|||
func (User) RBACObject() rbac.Object {
|
||||
return rbac.ResourceUser
|
||||
}
|
||||
|
||||
func (License) RBACObject() rbac.Object {
|
||||
return rbac.ResourceLicense
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ type querier interface {
|
|||
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
|
||||
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
|
||||
GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error)
|
||||
GetLicenses(ctx context.Context) ([]License, error)
|
||||
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
|
||||
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
|
||||
GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error)
|
||||
|
|
|
@ -475,6 +475,40 @@ func (q *sqlQuerier) UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyPar
|
|||
return err
|
||||
}
|
||||
|
||||
const getLicenses = `-- name: GetLicenses :many
|
||||
SELECT id, uploaded_at, jwt, exp
|
||||
FROM licenses
|
||||
ORDER BY (id)
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getLicenses)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []License
|
||||
for rows.Next() {
|
||||
var i License
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UploadedAt,
|
||||
&i.JWT,
|
||||
&i.Exp,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertLicense = `-- name: InsertLicense :one
|
||||
INSERT INTO
|
||||
licenses (
|
||||
|
|
|
@ -7,3 +7,9 @@ INSERT INTO
|
|||
)
|
||||
VALUES
|
||||
($1, $2, $3) RETURNING *;
|
||||
|
||||
|
||||
-- name: GetLicenses :many
|
||||
SELECT *
|
||||
FROM licenses
|
||||
ORDER BY (id);
|
||||
|
|
|
@ -50,7 +50,7 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
|
|||
if daemons == nil {
|
||||
daemons = []database.ProvisionerDaemon{}
|
||||
}
|
||||
daemons, err = AuthorizeFilter(api, r, rbac.ActionRead, daemons)
|
||||
daemons, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, daemons)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner daemons.",
|
||||
|
|
|
@ -292,7 +292,7 @@ func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
// Filter templates based on rbac permissions
|
||||
templates, err = AuthorizeFilter(api, r, rbac.ActionRead, templates)
|
||||
templates, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, templates)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching templates.",
|
||||
|
|
|
@ -158,7 +158,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
users, err = AuthorizeFilter(api, r, rbac.ActionRead, users)
|
||||
users, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, users)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching users.",
|
||||
|
@ -503,7 +503,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Only include ones we can read from RBAC.
|
||||
memberships, err = AuthorizeFilter(api, r, rbac.ActionRead, memberships)
|
||||
memberships, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, memberships)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching memberships.",
|
||||
|
@ -631,7 +631,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Only return orgs the user can read.
|
||||
organizations, err = AuthorizeFilter(api, r, rbac.ActionRead, organizations)
|
||||
organizations, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, organizations)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching organizations.",
|
||||
|
|
|
@ -143,7 +143,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Only return workspaces the user can read
|
||||
workspaces, err = AuthorizeFilter(api, r, rbac.ActionRead, workspaces)
|
||||
workspaces, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, workspaces)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspaces.",
|
||||
|
|
|
@ -35,3 +35,18 @@ func (c *Client) AddLicense(ctx context.Context, r AddLicenseRequest) (License,
|
|||
d.UseNumber()
|
||||
return l, d.Decode(&l)
|
||||
}
|
||||
|
||||
func (c *Client) Licenses(ctx context.Context) ([]License, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/v2/licenses", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
var licenses []License
|
||||
d := json.NewDecoder(res.Body)
|
||||
d.UseNumber()
|
||||
return licenses, d.Decode(&licenses)
|
||||
}
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -119,6 +124,7 @@ func newLicenseAPI(
|
|||
r := chi.NewRouter()
|
||||
a := &licenseAPI{router: r, logger: l, database: db, pubsub: ps, auth: auth}
|
||||
r.Post("/", a.postLicense)
|
||||
r.Get("/", a.licenses)
|
||||
return a
|
||||
}
|
||||
|
||||
|
@ -192,3 +198,70 @@ func convertLicense(dl database.License, c jwt.MapClaims) codersdk.License {
|
|||
Claims: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *licenseAPI) licenses(rw http.ResponseWriter, r *http.Request) {
|
||||
licenses, err := a.database.GetLicenses(r.Context())
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusOK, []codersdk.License{})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
licenses, err = coderd.AuthorizeFilter(a.auth, r, rbac.ActionRead, licenses)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
sdkLicenses, err := convertLicenses(licenses)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error parsing licenses.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(rw, http.StatusOK, sdkLicenses)
|
||||
}
|
||||
|
||||
func convertLicenses(licenses []database.License) ([]codersdk.License, error) {
|
||||
var out []codersdk.License
|
||||
for _, l := range licenses {
|
||||
c, err := decodeClaims(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, convertLicense(l, c))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// decodeClaims decodes the JWT claims from the stored JWT. Note here we do not validate the JWT
|
||||
// and just return the claims verbatim. We want to include all licenses on the GET response, even
|
||||
// if they are expired, or signed by a key this version of Coder no longer considers valid.
|
||||
//
|
||||
// Also, we do not return the whole JWT itself because a signed JWT is a bearer token and we
|
||||
// want to limit the chance of it being accidentally leaked.
|
||||
func decodeClaims(l database.License) (jwt.MapClaims, error) {
|
||||
parts := strings.Split(l.JWT, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, xerrors.Errorf("Unable to parse license %d as JWT", l.ID)
|
||||
}
|
||||
cb, err := base64.URLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("Unable to decode license %d claims: %w", l.ID, err)
|
||||
}
|
||||
c := make(jwt.MapClaims)
|
||||
d := json.NewDecoder(bytes.NewBuffer(cb))
|
||||
d.UseNumber()
|
||||
err = d.Decode(&c)
|
||||
return c, err
|
||||
}
|
||||
|
|
|
@ -142,6 +142,76 @@ func TestPostLicense(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// these tests patch the map of license keys, so cannot be run in parallel
|
||||
// nolint:paralleltest
|
||||
func TestGetLicense(t *testing.T) {
|
||||
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
oldKeys := keys
|
||||
defer func() {
|
||||
t.Log("restoring keys")
|
||||
keys = oldKeys
|
||||
}()
|
||||
keys = map[string]ed25519.PublicKey{keyID: pubKey}
|
||||
|
||||
t.Run("GET", func(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIBuilder: NewEnterprise})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
claims := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@coder.test",
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(2 * time.Hour)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
AccountType: AccountTypeSalesforce,
|
||||
AccountID: "testing",
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: 0,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
lic, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2nd license
|
||||
claims.AccountID = "testing2"
|
||||
claims.Features.UserLimit = 200
|
||||
lic2, err := makeLicense(claims, privKey, keyID)
|
||||
require.NoError(t, err)
|
||||
_, err = client.AddLicense(ctx, codersdk.AddLicenseRequest{
|
||||
License: lic2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
licenses, err := client.Licenses(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, licenses, 2)
|
||||
assert.Equal(t, int32(1), licenses[0].ID)
|
||||
assert.Equal(t, "testing", licenses[0].Claims["account_id"])
|
||||
assert.Equal(t, map[string]interface{}{
|
||||
codersdk.FeatureUserLimit: json.Number("0"),
|
||||
codersdk.FeatureAuditLog: json.Number("1"),
|
||||
}, licenses[0].Claims["features"])
|
||||
assert.Equal(t, int32(2), licenses[1].ID)
|
||||
assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
|
||||
assert.Equal(t, map[string]interface{}{
|
||||
codersdk.FeatureUserLimit: json.Number("200"),
|
||||
codersdk.FeatureAuditLog: json.Number("1"),
|
||||
}, licenses[1].Claims["features"])
|
||||
})
|
||||
}
|
||||
|
||||
func makeLicense(c *Claims, privateKey ed25519.PrivateKey, keyID string) (string, error) {
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
|
||||
tok.Header[HeaderKeyID] = keyID
|
||||
|
|
Loading…
Reference in New Issue