chore: push GetUsers authorization filter to SQL (#8497)

* feat: push GetUsers filter to SQL
* Remove GetAuthorizedUserFilter
* Remove GetFilteredUserCount
* remove GetUsersWithCount
This commit is contained in:
Steven Masley 2023-07-17 09:44:58 -04:00 committed by GitHub
parent dfac0745f3
commit 67494a3012
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 140 additions and 255 deletions

View File

@ -586,32 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro
return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id)
}
func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) {
// TODO Implement this with a SQL filter. The count is incorrect without it.
rowUsers, err := q.db.GetUsers(ctx, arg)
if err != nil {
return nil, -1, err
}
if len(rowUsers) == 0 {
return []database.User{}, 0, nil
}
act, ok := ActorFromContext(ctx)
if !ok {
return nil, -1, NoActorError
}
// TODO: Is this correct? Should we return a restricted user?
users := database.ConvertUserRows(rowUsers)
users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users)
if err != nil {
return nil, -1, err
}
return users, rowUsers[0].Count, nil
}
func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error {
deleteF := func(ctx context.Context, id uuid.UUID) error {
return q.db.UpdateUserDeletedByID(ctx, database.UpdateUserDeletedByIDParams{
@ -904,15 +878,6 @@ func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]dat
return q.db.GetFileTemplates(ctx, fileID)
}
func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
if err != nil {
return -1, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
// TODO: This should be the only implementation.
return q.GetAuthorizedUserCount(ctx, arg, prep)
}
func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg)
}
@ -1389,8 +1354,12 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database
}
func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
// TODO: We should use GetUsersWithCount with a better method signature.
return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg)
// This does the filtering in SQL.
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.GetAuthorizedUsers(ctx, arg, prep)
}
// GetUsersByIDs is only used for usernames on workspace return data.
@ -2639,6 +2608,9 @@ func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetW
return q.GetWorkspaces(ctx, arg)
}
func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
return q.db.GetAuthorizedUserCount(ctx, arg, prepared)
// GetAuthorizedUsers is not required for dbauthz since GetUsers is already
// authenticated.
func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
// GetUsers is authenticated.
return q.GetUsers(ctx, arg)
}

View File

@ -869,24 +869,12 @@ func (s *MethodTestSuite) TestUser() {
Asserts(a, rbac.ActionRead, b, rbac.ActionRead).
Returns(slice.New(a, b))
}))
s.Run("GetAuthorizedUserCount", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.User(s.T(), db, database.User{})
check.Args(database.GetFilteredUserCountParams{}, emptyPreparedAuthorized{}).Asserts().Returns(int64(1))
}))
s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.User(s.T(), db, database.User{})
check.Args(database.GetFilteredUserCountParams{}).Asserts().Returns(int64(1))
}))
s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) {
a := dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"})
b := dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"})
dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"})
dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"})
check.Args(database.GetUsersParams{}).
Asserts(a, rbac.ActionRead, b, rbac.ActionRead)
}))
s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) {
a := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"})
b := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"})
check.Args(database.GetUsersParams{}).Asserts(a, rbac.ActionRead, b, rbac.ActionRead)
// Asserts are done in a SQL filter
Asserts()
}))
s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertUserParams{

View File

@ -23,6 +23,7 @@ import (
"github.com/coder/coder/coderd/database/db2sdk"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/coderd/util/slice"
"github.com/coder/coder/codersdk"
)
@ -1207,14 +1208,6 @@ func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]datab
return rows, nil
}
func (q *FakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
if err := validateDatabaseType(arg); err != nil {
return 0, err
}
count, err := q.GetAuthorizedUserCount(ctx, arg, nil)
return count, err
}
func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
if err := validateDatabaseType(arg); err != nil {
return database.GitAuthLink{}, err
@ -5365,76 +5358,37 @@ func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
}
func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
if err := validateDatabaseType(params); err != nil {
return 0, err
func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
if err := validateDatabaseType(arg); err != nil {
return nil, err
}
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
VariableConverter: regosql.UserConverter(),
})
if err != nil {
return nil, err
}
}
users, err := q.GetUsers(ctx, arg)
if err != nil {
return nil, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return -1, err
}
}
users := make([]database.User, 0, len(q.users))
for _, user := range q.users {
filteredUsers := make([]database.GetUsersRow, 0, len(users))
for _, user := range users {
// If the filter exists, ensure the object is authorized.
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
continue
}
users = append(users, user)
filteredUsers = append(filteredUsers, user)
}
// Filter out deleted since they should never be returned..
tmp := make([]database.User, 0, len(users))
for _, user := range users {
if !user.Deleted {
tmp = append(tmp, user)
}
}
users = tmp
if params.Search != "" {
tmp := make([]database.User, 0, len(users))
for i, user := range users {
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
tmp = append(tmp, users[i])
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
tmp = append(tmp, users[i])
}
}
users = tmp
}
if len(params.Status) > 0 {
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
return strings.EqualFold(string(a), string(b))
}) {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
}
}
users = usersFilteredByStatus
}
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
usersFilteredByRole := make([]database.User, 0, len(users))
for i, user := range users {
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
usersFilteredByRole = append(usersFilteredByRole, users[i])
}
}
users = usersFilteredByRole
}
return int64(len(users)), nil
return filteredUsers, nil
}

View File

@ -321,13 +321,6 @@ func (m metricsStore) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([
return rows, err
}
func (m metricsStore) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
start := time.Now()
count, err := m.s.GetFilteredUserCount(ctx, arg)
m.queryLatencies.WithLabelValues("GetFilteredUserCount").Observe(time.Since(start).Seconds())
return count, err
}
func (m metricsStore) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
start := time.Now()
link, err := m.s.GetGitAuthLink(ctx, arg)
@ -1639,9 +1632,9 @@ func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.
return workspaces, err
}
func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
func (m metricsStore) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
start := time.Now()
count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds())
return count, err
r0, r1 := m.s.GetAuthorizedUsers(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedUsers").Observe(time.Since(start).Seconds())
return r0, r1
}

View File

@ -431,19 +431,19 @@ func (mr *MockStoreMockRecorder) GetAuthorizedTemplates(arg0, arg1, arg2 interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedTemplates", reflect.TypeOf((*MockStore)(nil).GetAuthorizedTemplates), arg0, arg1, arg2)
}
// GetAuthorizedUserCount mocks base method.
func (m *MockStore) GetAuthorizedUserCount(arg0 context.Context, arg1 database.GetFilteredUserCountParams, arg2 rbac.PreparedAuthorized) (int64, error) {
// GetAuthorizedUsers mocks base method.
func (m *MockStore) GetAuthorizedUsers(arg0 context.Context, arg1 database.GetUsersParams, arg2 rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthorizedUserCount", arg0, arg1, arg2)
ret0, _ := ret[0].(int64)
ret := m.ctrl.Call(m, "GetAuthorizedUsers", arg0, arg1, arg2)
ret0, _ := ret[0].([]database.GetUsersRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthorizedUserCount indicates an expected call of GetAuthorizedUserCount.
func (mr *MockStoreMockRecorder) GetAuthorizedUserCount(arg0, arg1, arg2 interface{}) *gomock.Call {
// GetAuthorizedUsers indicates an expected call of GetAuthorizedUsers.
func (mr *MockStoreMockRecorder) GetAuthorizedUsers(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUserCount", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUserCount), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUsers", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUsers), arg0, arg1, arg2)
}
// GetAuthorizedWorkspaces mocks base method.
@ -596,21 +596,6 @@ func (mr *MockStoreMockRecorder) GetFileTemplates(arg0, arg1 interface{}) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFileTemplates", reflect.TypeOf((*MockStore)(nil).GetFileTemplates), arg0, arg1)
}
// GetFilteredUserCount mocks base method.
func (m *MockStore) GetFilteredUserCount(arg0 context.Context, arg1 database.GetFilteredUserCountParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetFilteredUserCount", arg0, arg1)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetFilteredUserCount indicates an expected call of GetFilteredUserCount.
func (mr *MockStoreMockRecorder) GetFilteredUserCount(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredUserCount", reflect.TypeOf((*MockStore)(nil).GetFilteredUserCount), arg0, arg1)
}
// GetGitAuthLink mocks base method.
func (m *MockStore) GetGitAuthLink(arg0 context.Context, arg1 database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
m.ctrl.T.Helper()

View File

@ -255,29 +255,66 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
}
type userQuerier interface {
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error)
}
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) {
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
VariableConverter: regosql.UserConverter(),
})
if err != nil {
return -1, xerrors.Errorf("compile authorized filter: %w", err)
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
filtered, err := insertAuthorizedFilter(getUsers, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return -1, xerrors.Errorf("insert authorized filter: %w", err)
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
row := q.db.QueryRowContext(ctx, query,
query := fmt.Sprintf("-- name: GetAuthorizedUsers :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.AfterID,
arg.Search,
pq.Array(arg.Status),
pq.Array(arg.RbacRole),
arg.LastSeenBefore,
arg.LastSeenAfter,
arg.OffsetOpt,
arg.LimitOpt,
)
var count int64
err = row.Scan(&count)
return count, err
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetUsersRow
for rows.Next() {
var i GetUsersRow
if err := rows.Scan(
&i.ID,
&i.Email,
&i.Username,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
&i.LastSeenAt,
&i.Count,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {

View File

@ -65,8 +65,6 @@ type sqlcQuerier interface {
GetFileByID(ctx context.Context, id uuid.UUID) (File, error)
// Get all templates that use a file.
GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]GetFileTemplatesRow, error)
// This will never count deleted users.
GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error)
GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error)
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)

View File

@ -5108,55 +5108,6 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.
return i, err
}
const getFilteredUserCount = `-- name: GetFilteredUserCount :one
SELECT
COUNT(*)
FROM
users
WHERE
users.deleted = false
-- Start filters
-- Filter by name, email or username
AND CASE
WHEN $1 :: text != '' THEN (
email ILIKE concat('%', $1, '%')
OR username ILIKE concat('%', $1, '%')
)
ELSE true
END
-- Filter by status
AND CASE
-- @status needs to be a text because it can be empty, If it was
-- user_status enum, it would not.
WHEN cardinality($2 :: user_status[]) > 0 THEN
status = ANY($2 :: user_status[])
ELSE true
END
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
WHEN cardinality($3 :: text[]) > 0 AND 'member' != ANY($3 :: text[])
THEN rbac_roles && $3 :: text[]
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
-- @authorize_filter
`
type GetFilteredUserCountParams struct {
Search string `db:"search" json:"search"`
Status []UserStatus `db:"status" json:"status"`
RbacRole []string `db:"rbac_role" json:"rbac_role"`
}
// This will never count deleted users.
func (q *sqlQuerier) GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error) {
row := q.db.QueryRowContext(ctx, getFilteredUserCount, arg.Search, pq.Array(arg.Status), pq.Array(arg.RbacRole))
var count int64
err := row.Scan(&count)
return count, err
}
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
SELECT
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at
@ -5304,6 +5255,9 @@ WHERE
ELSE true
END
-- End of filters
-- Authorize Filter clause will be injected below in GetAuthorizedUsers
-- @authorize_filter
ORDER BY
-- Deterministic and consistent ordering of all users. This is to ensure consistent pagination.
LOWER(username) ASC OFFSET $7

View File

@ -56,42 +56,6 @@ FROM
WHERE
status = 'active'::user_status AND deleted = false;
-- name: GetFilteredUserCount :one
-- This will never count deleted users.
SELECT
COUNT(*)
FROM
users
WHERE
users.deleted = false
-- Start filters
-- Filter by name, email or username
AND CASE
WHEN @search :: text != '' THEN (
email ILIKE concat('%', @search, '%')
OR username ILIKE concat('%', @search, '%')
)
ELSE true
END
-- Filter by status
AND CASE
-- @status needs to be a text because it can be empty, If it was
-- user_status enum, it would not.
WHEN cardinality(@status :: user_status[]) > 0 THEN
status = ANY(@status :: user_status[])
ELSE true
END
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[])
THEN rbac_roles && @rbac_role :: text[]
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
-- @authorize_filter
;
-- name: InsertUser :one
INSERT INTO
users (
@ -208,6 +172,9 @@ WHERE
ELSE true
END
-- End of filters
-- Authorize Filter clause will be injected below in GetAuthorizedUsers
-- @authorize_filter
ORDER BY
-- Deterministic and consistent ordering of all users. This is to ensure consistent pagination.
LOWER(username) ASC OFFSET @offset_opt

View File

@ -242,6 +242,26 @@ neq(input.object.owner, "");
p("false")),
VariableConverter: regosql.TemplateConverter(),
},
{
Name: "UserNoOrgOwner",
Queries: []string{
`input.object.org_owner != ""`,
},
ExpectedSQL: p("'' != ''"),
VariableConverter: regosql.UserConverter(),
},
{
Name: "UserOwnsSelf",
Queries: []string{
`"10d03e62-7703-4df5-a358-4f76577d4e2f" = input.object.owner;
input.object.owner != "";
input.object.org_owner = ""`,
},
VariableConverter: regosql.UserConverter(),
ExpectedSQL: p(
p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = ''") + " AND " + p("'' != ''") + " AND " + p("'' = ''"),
),
},
}
for _, tc := range testCases {

View File

@ -36,6 +36,23 @@ func TemplateConverter() *sqltypes.VariableConverter {
return matcher
}
func UserConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
// Users are never owned by an organization, so always return the empty string
// for the org owner.
sqltypes.StringVarMatcher("''", []string{"input", "object", "org_owner"}),
// Users never have an owner, and are only owned site wide.
sqltypes.StringVarMatcher("''", []string{"input", "object", "owner"}),
)
matcher.RegisterMatcher(
// No ACLs on the user type
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
)
return matcher
}
// NoACLConverter should be used when the target SQL table does not contain
// group or user ACL columns.
func NoACLConverter() *sqltypes.VariableConverter {