feat: Implement RBAC checks on /templates endpoints (#1678)

* feat: Generic Filter method for rbac objects
This commit is contained in:
Steven Masley 2022-05-24 08:43:34 -05:00 committed by GitHub
parent fcd610ee7b
commit c7ca86d374
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 221 additions and 73 deletions

View File

@ -12,9 +12,14 @@ import (
"github.com/coder/coder/coderd/rbac"
)
func (api *api) Authorize(rw http.ResponseWriter, r *http.Request, action rbac.Action, object rbac.Object) bool {
func AuthorizeFilter[O rbac.Objecter](api *api, r *http.Request, action rbac.Action, objects []O) []O {
roles := httpmw.UserRoles(r)
err := api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, action, object)
return rbac.Filter(r.Context(), api.Authorizer, roles.ID.String(), roles.Roles, action, objects)
}
func (api *api) Authorize(rw http.ResponseWriter, r *http.Request, action rbac.Action, object rbac.Objecter) bool {
roles := httpmw.UserRoles(r)
err := api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, action, object.RBACObject())
if err != nil {
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
Message: err.Error(),

View File

@ -186,6 +186,7 @@ func newRouter(options *Options, a *api) chi.Router {
r.Route("/templates/{template}", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
authRolesMiddleware,
httpmw.ExtractTemplateParam(options.Database),
)

View File

@ -100,8 +100,6 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
"PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true},
"GET:/api/v2/organizations/{organization}/provisionerdaemons": {NoAuthorize: true},
"POST:/api/v2/organizations/{organization}/templates": {NoAuthorize: true},
"GET:/api/v2/organizations/{organization}/templates": {NoAuthorize: true},
"GET:/api/v2/organizations/{organization}/templates/{templatename}": {NoAuthorize: true},
"POST:/api/v2/organizations/{organization}/templateversions": {NoAuthorize: true},
"POST:/api/v2/organizations/{organization}/workspaces": {NoAuthorize: true},
@ -110,8 +108,6 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
"GET:/api/v2/parameters/{scope}/{id}": {NoAuthorize: true},
"DELETE:/api/v2/parameters/{scope}/{id}/{name}": {NoAuthorize: true},
"DELETE:/api/v2/templates/{template}": {NoAuthorize: true},
"GET:/api/v2/templates/{template}": {NoAuthorize: true},
"GET:/api/v2/templates/{template}/versions": {NoAuthorize: true},
"PATCH:/api/v2/templates/{template}/versions": {NoAuthorize: true},
"GET:/api/v2/templates/{template}/versions/{templateversionname}": {NoAuthorize: true},
@ -185,7 +181,23 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
AssertAction: rbac.ActionRead,
AssertObject: workspaceRBACObj,
},
"GET:/api/v2/organizations/{organization}/templates": {
StatusCode: http.StatusOK,
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceTemplate.InOrg(template.OrganizationID).WithID(template.ID.String()),
},
"POST:/api/v2/organizations/{organization}/templates": {
AssertAction: rbac.ActionCreate,
AssertObject: rbac.ResourceTemplate.InOrg(organization.ID),
},
"DELETE:/api/v2/templates/{template}": {
AssertAction: rbac.ActionDelete,
AssertObject: rbac.ResourceTemplate.InOrg(template.OrganizationID).WithID(template.ID.String()),
},
"GET:/api/v2/templates/{template}": {
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceTemplate.InOrg(template.OrganizationID).WithID(template.ID.String()),
},
"POST:/api/v2/files": {AssertAction: rbac.ActionCreate, AssertObject: rbac.ResourceFile},
"GET:/api/v2/files/{fileHash}": {AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceFile.WithOwner(admin.UserID.String()).WithID(file.Hash)},
@ -226,6 +238,7 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
route = strings.ReplaceAll(route, "{workspacebuild}", workspace.LatestBuild.ID.String())
route = strings.ReplaceAll(route, "{workspacename}", workspace.Name)
route = strings.ReplaceAll(route, "{workspacebuildname}", workspace.LatestBuild.Name)
route = strings.ReplaceAll(route, "{template}", template.ID.String())
route = strings.ReplaceAll(route, "{hash}", file.Hash)
resp, err := client.Request(context.Background(), method, route, nil)

View File

@ -0,0 +1,19 @@
package database
import "github.com/coder/coder/coderd/rbac"
func (t Template) RBACObject() rbac.Object {
return rbac.ResourceTemplate.InOrg(t.OrganizationID).WithID(t.ID.String())
}
func (w Workspace) RBACObject() rbac.Object {
return rbac.ResourceWorkspace.InOrg(w.OrganizationID).WithID(w.ID.String()).WithOwner(w.OwnerID.String())
}
func (m OrganizationMember) RBACObject() rbac.Object {
return rbac.ResourceOrganizationMember.InOrg(m.OrganizationID).WithID(m.UserID.String())
}
func (o Organization) RBACObject() rbac.Object {
return rbac.ResourceOrganization.InOrg(o.ID).WithID(o.ID.String())
}

View File

@ -3,7 +3,6 @@ package rbac
import (
"context"
_ "embed"
"golang.org/x/xerrors"
"github.com/open-policy-agent/opa/rego"
@ -13,6 +12,24 @@ type Authorizer interface {
ByRoleName(ctx context.Context, subjectID string, roleNames []string, action Action, object Object) error
}
// Filter takes in a list of objects, and will filter the list removing all
// the elements the subject does not have permission for.
// Filter does not allocate a new slice, and will use the existing one
// passed in. This can cause memory leaks if the slice is held for a prolonged
// period of time.
func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, subjRoles []string, action Action, objects []O) []O {
filtered := make([]O, 0)
for i := range objects {
object := objects[i]
err := auth.ByRoleName(ctx, subjID, subjRoles, action, object.RBACObject())
if err == nil {
filtered = append(filtered, object)
}
}
return filtered
}
// RegoAuthorizer will use a prepared rego query for performing authorize()
type RegoAuthorizer struct {
query rego.PreparedEvalQuery

View File

@ -4,13 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"strconv"
"testing"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac"
)
@ -24,6 +23,94 @@ type subject struct {
Roles []rbac.Role `json:"roles"`
}
func TestFilter(t *testing.T) {
t.Parallel()
objectList := make([]rbac.Object, 0)
workspaceList := make([]rbac.Object, 0)
fileList := make([]rbac.Object, 0)
for i := 0; i < 10; i++ {
idxStr := strconv.Itoa(i)
workspace := rbac.ResourceWorkspace.WithID(idxStr).WithOwner("me")
file := rbac.ResourceFile.WithID(idxStr).WithOwner("me")
workspaceList = append(workspaceList, workspace)
fileList = append(fileList, file)
objectList = append(objectList, workspace)
objectList = append(objectList, file)
}
// copyList is to prevent tests from sharing the same slice
copyList := func(list []rbac.Object) []rbac.Object {
tmp := make([]rbac.Object, len(list))
copy(tmp, list)
return tmp
}
testCases := []struct {
Name string
List []rbac.Object
Expected []rbac.Object
Auth func(o rbac.Object) error
}{
{
Name: "FilterWorkspaceType",
List: copyList(objectList),
Expected: copyList(workspaceList),
Auth: func(o rbac.Object) error {
if o.Type != rbac.ResourceWorkspace.Type {
return xerrors.New("only workspace")
}
return nil
},
},
{
Name: "FilterFileType",
List: copyList(objectList),
Expected: copyList(fileList),
Auth: func(o rbac.Object) error {
if o.Type != rbac.ResourceFile.Type {
return xerrors.New("only file")
}
return nil
},
},
{
Name: "FilterAll",
List: copyList(objectList),
Expected: []rbac.Object{},
Auth: func(o rbac.Object) error {
return xerrors.New("always fail")
},
},
{
Name: "FilterNone",
List: copyList(objectList),
Expected: copyList(objectList),
Auth: func(o rbac.Object) error {
return nil
},
},
}
for _, c := range testCases {
c := c
t.Run(c.Name, func(t *testing.T) {
t.Parallel()
authorizer := fakeAuthorizer{
AuthFunc: func(_ context.Context, _ string, _ []string, _ rbac.Action, object rbac.Object) error {
return c.Auth(object)
},
}
filtered := rbac.Filter(context.Background(), authorizer, "me", []string{}, rbac.ActionRead, c.List)
require.ElementsMatch(t, c.Expected, filtered, "expect same list")
require.Equal(t, len(c.Expected), len(filtered), "same length list")
})
}
}
// TestAuthorizeDomain test the very basic roles that are commonly used.
func TestAuthorizeDomain(t *testing.T) {
t.Parallel()

15
coderd/rbac/fake_test.go Normal file
View File

@ -0,0 +1,15 @@
package rbac_test
import (
"context"
"github.com/coder/coder/coderd/rbac"
)
type fakeAuthorizer struct {
AuthFunc func(ctx context.Context, subjectID string, roleNames []string, action rbac.Action, object rbac.Object) error
}
func (f fakeAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, action rbac.Action, object rbac.Object) error {
return f.AuthFunc(ctx, subjectID, roleNames, action, object)
}

View File

@ -6,6 +6,11 @@ import (
const WildcardSymbol = "*"
// Objecter returns the RBAC object for itself.
type Objecter interface {
RBACObject() Object
}
// Resources are just typed objects. Making resources this way allows directly
// passing them into an Authorize function and use the chaining api.
var (
@ -99,6 +104,10 @@ type Object struct {
// TODO: SharedUsers?
}
func (z Object) RBACObject() Object {
return z
}
// All returns an object matching all resources of the same type.
func (z Object) All() Object {
return Object{

View File

@ -13,6 +13,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
)
@ -30,6 +31,11 @@ func (api *api) template(rw http.ResponseWriter, r *http.Request) {
})
return
}
if !api.Authorize(rw, r, rbac.ActionRead, template) {
return
}
count := uint32(0)
if len(workspaceCounts) > 0 {
count = uint32(workspaceCounts[0].Count)
@ -40,6 +46,9 @@ func (api *api) template(rw http.ResponseWriter, r *http.Request) {
func (api *api) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
template := httpmw.TemplateParam(r)
if !api.Authorize(rw, r, rbac.ActionDelete, template) {
return
}
workspaces, err := api.Database.GetWorkspacesByTemplateID(r.Context(), database.GetWorkspacesByTemplateIDParams{
TemplateID: template.ID,
@ -77,10 +86,14 @@ func (api *api) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
// Create a new template in an organization.
func (api *api) postTemplateByOrganization(rw http.ResponseWriter, r *http.Request) {
var createTemplate codersdk.CreateTemplateRequest
organization := httpmw.OrganizationParam(r)
if !api.Authorize(rw, r, rbac.ActionCreate, rbac.ResourceTemplate.InOrg(organization.ID)) {
return
}
if !httpapi.Read(rw, r, &createTemplate) {
return
}
organization := httpmw.OrganizationParam(r)
_, err := api.Database.GetTemplateByOrganizationAndName(r.Context(), database.GetTemplateByOrganizationAndNameParams{
OrganizationID: organization.ID,
Name: createTemplate.Name,
@ -194,7 +207,12 @@ func (api *api) templatesByOrganization(rw http.ResponseWriter, r *http.Request)
})
return
}
// Filter templates based on rbac permissions
templates = AuthorizeFilter(api, r, rbac.ActionRead, templates)
templateIDs := make([]uuid.UUID, 0, len(templates))
for _, template := range templates {
templateIDs = append(templateIDs, template.ID)
}
@ -233,6 +251,10 @@ func (api *api) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re
return
}
if !api.Authorize(rw, r, rbac.ActionRead, template) {
return
}
workspaceCounts, err := api.Database.GetWorkspaceOwnerCountsByTemplateIDs(r.Context(), []uuid.UUID{template.ID})
if errors.Is(err, sql.ErrNoRows) {
err = nil

View File

@ -245,7 +245,7 @@ func (api *api) userByName(rw http.ResponseWriter, r *http.Request) {
func (api *api) putUserProfile(rw http.ResponseWriter, r *http.Request) {
user := httpmw.UserParam(r)
if !api.Authorize(rw, r, rbac.ActionUpdate, rbac.ResourceUser.WithOwner(user.ID.String())) {
if !api.Authorize(rw, r, rbac.ActionUpdate, rbac.ResourceUser.WithID(user.ID.String())) {
return
}
@ -389,7 +389,6 @@ func (api *api) putUserPassword(rw http.ResponseWriter, r *http.Request) {
func (api *api) userRoles(rw http.ResponseWriter, r *http.Request) {
user := httpmw.UserParam(r)
roles := httpmw.UserRoles(r)
if !api.Authorize(rw, r, rbac.ActionRead, rbac.ResourceUserData.
WithOwner(user.ID.String())) {
@ -409,13 +408,10 @@ func (api *api) userRoles(rw http.ResponseWriter, r *http.Request) {
return
}
for _, mem := range memberships {
err := api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, rbac.ActionRead,
rbac.ResourceOrganizationMember.
WithID(user.ID.String()).
InOrg(mem.OrganizationID),
)
// Only include ones we can read from RBAC
memberships = AuthorizeFilter(api, r, rbac.ActionRead, memberships)
for _, mem := range memberships {
// If we can read the org member, include the roles
if err == nil {
resp.OrganizationRoles[mem.OrganizationID] = mem.Roles
@ -508,7 +504,6 @@ func (api *api) updateSiteUserRoles(ctx context.Context, args database.UpdateUse
// Returns organizations the parameterized user has access to.
func (api *api) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
user := httpmw.UserParam(r)
roles := httpmw.UserRoles(r)
organizations, err := api.Database.GetOrganizationsByUserID(r.Context(), user.ID)
if errors.Is(err, sql.ErrNoRows) {
@ -522,17 +517,12 @@ func (api *api) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
return
}
// Only return orgs the user can read
organizations = AuthorizeFilter(api, r, rbac.ActionRead, organizations)
publicOrganizations := make([]codersdk.Organization, 0, len(organizations))
for _, organization := range organizations {
err := api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, rbac.ActionRead,
rbac.ResourceOrganization.
WithID(organization.ID.String()).
InOrg(organization.ID),
)
if err == nil {
// Only return orgs the user can read
publicOrganizations = append(publicOrganizations, convertOrganization(organization))
}
publicOrganizations = append(publicOrganizations, convertOrganization(organization))
}
httpapi.Write(rw, http.StatusOK, publicOrganizations)

View File

@ -30,13 +30,7 @@ import (
func (api *api) workspace(rw http.ResponseWriter, r *http.Request) {
workspace := httpmw.WorkspaceParam(r)
if !api.Authorize(rw, r, rbac.ActionRead,
rbac.ResourceWorkspace.InOrg(workspace.OrganizationID).WithOwner(workspace.OwnerID.String()).WithID(workspace.ID.String())) {
return
}
if !api.Authorize(rw, r, rbac.ActionRead,
rbac.ResourceWorkspace.InOrg(workspace.OrganizationID).WithOwner(workspace.OwnerID.String()).WithID(workspace.ID.String())) {
if !api.Authorize(rw, r, rbac.ActionRead, workspace) {
return
}
@ -108,7 +102,6 @@ func (api *api) workspace(rw http.ResponseWriter, r *http.Request) {
func (api *api) workspacesByOrganization(rw http.ResponseWriter, r *http.Request) {
organization := httpmw.OrganizationParam(r)
roles := httpmw.UserRoles(r)
workspaces, err := api.Database.GetWorkspacesWithFilter(r.Context(), database.GetWorkspacesWithFilterParams{
OrganizationID: organization.ID,
Deleted: false,
@ -123,17 +116,10 @@ func (api *api) workspacesByOrganization(rw http.ResponseWriter, r *http.Request
return
}
allowedWorkspaces := make([]database.Workspace, 0)
for _, ws := range workspaces {
ws := ws
err = api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, rbac.ActionRead,
rbac.ResourceWorkspace.InOrg(ws.OrganizationID).WithOwner(ws.OwnerID.String()).WithID(ws.ID.String()))
if err == nil {
allowedWorkspaces = append(allowedWorkspaces, ws)
}
}
// Rbac filter
workspaces = AuthorizeFilter(api, r, rbac.ActionRead, workspaces)
apiWorkspaces, err := convertWorkspaces(r.Context(), api.Database, allowedWorkspaces)
apiWorkspaces, err := convertWorkspaces(r.Context(), api.Database, workspaces)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspaces: %s", err),
@ -146,7 +132,6 @@ func (api *api) workspacesByOrganization(rw http.ResponseWriter, r *http.Request
// workspaces returns all workspaces a user can read.
// Optional filters with query params
func (api *api) workspaces(rw http.ResponseWriter, r *http.Request) {
roles := httpmw.UserRoles(r)
apiKey := httpmw.APIKey(r)
// Empty strings mean no filter
@ -186,24 +171,18 @@ func (api *api) workspaces(rw http.ResponseWriter, r *http.Request) {
filter.OwnerID = userID
}
allowedWorkspaces := make([]database.Workspace, 0)
allWorkspaces, err := api.Database.GetWorkspacesWithFilter(r.Context(), filter)
workspaces, err := api.Database.GetWorkspacesWithFilter(r.Context(), filter)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspaces for user: %s", err),
})
return
}
for _, ws := range allWorkspaces {
ws := ws
err = api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, rbac.ActionRead,
rbac.ResourceWorkspace.InOrg(ws.OrganizationID).WithOwner(ws.OwnerID.String()).WithID(ws.ID.String()))
if err == nil {
allowedWorkspaces = append(allowedWorkspaces, ws)
}
}
apiWorkspaces, err := convertWorkspaces(r.Context(), api.Database, allowedWorkspaces)
// Only return workspaces the user can read
workspaces = AuthorizeFilter(api, r, rbac.ActionRead, workspaces)
apiWorkspaces, err := convertWorkspaces(r.Context(), api.Database, workspaces)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspaces: %s", err),
@ -215,7 +194,6 @@ func (api *api) workspaces(rw http.ResponseWriter, r *http.Request) {
func (api *api) workspacesByOwner(rw http.ResponseWriter, r *http.Request) {
owner := httpmw.UserParam(r)
roles := httpmw.UserRoles(r)
workspaces, err := api.Database.GetWorkspacesWithFilter(r.Context(), database.GetWorkspacesWithFilterParams{
OwnerID: owner.ID,
Deleted: false,
@ -230,17 +208,10 @@ func (api *api) workspacesByOwner(rw http.ResponseWriter, r *http.Request) {
return
}
allowedWorkspaces := make([]database.Workspace, 0)
for _, ws := range workspaces {
ws := ws
err = api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, rbac.ActionRead,
rbac.ResourceWorkspace.InOrg(ws.OrganizationID).WithOwner(ws.OwnerID.String()).WithID(ws.ID.String()))
if err == nil {
allowedWorkspaces = append(allowedWorkspaces, ws)
}
}
// Only return workspaces the user can read
workspaces = AuthorizeFilter(api, r, rbac.ActionRead, workspaces)
apiWorkspaces, err := convertWorkspaces(r.Context(), api.Database, allowedWorkspaces)
apiWorkspaces, err := convertWorkspaces(r.Context(), api.Database, workspaces)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspaces: %s", err),
@ -278,8 +249,7 @@ func (api *api) workspaceByOwnerAndName(rw http.ResponseWriter, r *http.Request)
return
}
if !api.Authorize(rw, r, rbac.ActionRead,
rbac.ResourceWorkspace.InOrg(workspace.OrganizationID).WithOwner(workspace.OwnerID.String()).WithID(workspace.ID.String())) {
if !api.Authorize(rw, r, rbac.ActionRead, workspace) {
return
}