mirror of https://github.com/coder/coder.git
chore: push GetUsers authorization filter to SQL (#8497)
* feat: push GetUsers filter to SQL * Remove GetAuthorizedUserFilter * Remove GetFilteredUserCount * remove GetUsersWithCount
This commit is contained in:
parent
dfac0745f3
commit
67494a3012
|
@ -586,32 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro
|
|||
return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) {
|
||||
// TODO Implement this with a SQL filter. The count is incorrect without it.
|
||||
rowUsers, err := q.db.GetUsers(ctx, arg)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
}
|
||||
|
||||
if len(rowUsers) == 0 {
|
||||
return []database.User{}, 0, nil
|
||||
}
|
||||
|
||||
act, ok := ActorFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, -1, NoActorError
|
||||
}
|
||||
|
||||
// TODO: Is this correct? Should we return a restricted user?
|
||||
users := database.ConvertUserRows(rowUsers)
|
||||
users, err = rbac.Filter(ctx, q.auth, act, rbac.ActionRead, users)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
}
|
||||
|
||||
return users, rowUsers[0].Count, nil
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error {
|
||||
deleteF := func(ctx context.Context, id uuid.UUID) error {
|
||||
return q.db.UpdateUserDeletedByID(ctx, database.UpdateUserDeletedByIDParams{
|
||||
|
@ -904,15 +878,6 @@ func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]dat
|
|||
return q.db.GetFileTemplates(ctx, fileID)
|
||||
}
|
||||
|
||||
func (q *querier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
// TODO: This should be the only implementation.
|
||||
return q.GetAuthorizedUserCount(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg)
|
||||
}
|
||||
|
@ -1389,8 +1354,12 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database
|
|||
}
|
||||
|
||||
func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
|
||||
// TODO: We should use GetUsersWithCount with a better method signature.
|
||||
return fetchWithPostFilter(q.auth, q.db.GetUsers)(ctx, arg)
|
||||
// This does the filtering in SQL.
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.GetAuthorizedUsers(ctx, arg, prep)
|
||||
}
|
||||
|
||||
// GetUsersByIDs is only used for usernames on workspace return data.
|
||||
|
@ -2639,6 +2608,9 @@ func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetW
|
|||
return q.GetWorkspaces(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
return q.db.GetAuthorizedUserCount(ctx, arg, prepared)
|
||||
// GetAuthorizedUsers is not required for dbauthz since GetUsers is already
|
||||
// authenticated.
|
||||
func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
|
||||
// GetUsers is authenticated.
|
||||
return q.GetUsers(ctx, arg)
|
||||
}
|
||||
|
|
|
@ -869,24 +869,12 @@ func (s *MethodTestSuite) TestUser() {
|
|||
Asserts(a, rbac.ActionRead, b, rbac.ActionRead).
|
||||
Returns(slice.New(a, b))
|
||||
}))
|
||||
s.Run("GetAuthorizedUserCount", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.User(s.T(), db, database.User{})
|
||||
check.Args(database.GetFilteredUserCountParams{}, emptyPreparedAuthorized{}).Asserts().Returns(int64(1))
|
||||
}))
|
||||
s.Run("GetFilteredUserCount", s.Subtest(func(db database.Store, check *expects) {
|
||||
_ = dbgen.User(s.T(), db, database.User{})
|
||||
check.Args(database.GetFilteredUserCountParams{}).Asserts().Returns(int64(1))
|
||||
}))
|
||||
s.Run("GetUsers", s.Subtest(func(db database.Store, check *expects) {
|
||||
a := dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"})
|
||||
b := dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"})
|
||||
dbgen.User(s.T(), db, database.User{Username: "GetUsers-a-user"})
|
||||
dbgen.User(s.T(), db, database.User{Username: "GetUsers-b-user"})
|
||||
check.Args(database.GetUsersParams{}).
|
||||
Asserts(a, rbac.ActionRead, b, rbac.ActionRead)
|
||||
}))
|
||||
s.Run("GetUsersWithCount", s.Subtest(func(db database.Store, check *expects) {
|
||||
a := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-a-user"})
|
||||
b := dbgen.User(s.T(), db, database.User{Username: "GetUsersWithCount-b-user"})
|
||||
check.Args(database.GetUsersParams{}).Asserts(a, rbac.ActionRead, b, rbac.ActionRead)
|
||||
// Asserts are done in a SQL filter
|
||||
Asserts()
|
||||
}))
|
||||
s.Run("InsertUser", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertUserParams{
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/coder/coder/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/coderd/util/slice"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
@ -1207,14 +1208,6 @@ func (q *FakeQuerier) GetFileTemplates(_ context.Context, id uuid.UUID) ([]datab
|
|||
return rows, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count, err := q.GetAuthorizedUserCount(ctx, arg, nil)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
|
@ -5365,76 +5358,37 @@ func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
|||
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
if err := validateDatabaseType(params); err != nil {
|
||||
return 0, err
|
||||
func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
if prepared != nil {
|
||||
_, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.UserConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
users, err := q.GetUsers(ctx, arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
if prepared != nil {
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
users := make([]database.User, 0, len(q.users))
|
||||
|
||||
for _, user := range q.users {
|
||||
filteredUsers := make([]database.GetUsersRow, 0, len(users))
|
||||
for _, user := range users {
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, user)
|
||||
filteredUsers = append(filteredUsers, user)
|
||||
}
|
||||
|
||||
// Filter out deleted since they should never be returned..
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
for _, user := range users {
|
||||
if !user.Deleted {
|
||||
tmp = append(tmp, user)
|
||||
}
|
||||
}
|
||||
users = tmp
|
||||
|
||||
if params.Search != "" {
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
|
||||
tmp = append(tmp, users[i])
|
||||
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
|
||||
tmp = append(tmp, users[i])
|
||||
}
|
||||
}
|
||||
users = tmp
|
||||
}
|
||||
|
||||
if len(params.Status) > 0 {
|
||||
usersFilteredByStatus := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
|
||||
return strings.EqualFold(string(a), string(b))
|
||||
}) {
|
||||
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
|
||||
}
|
||||
}
|
||||
users = usersFilteredByStatus
|
||||
}
|
||||
|
||||
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
|
||||
usersFilteredByRole := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
|
||||
usersFilteredByRole = append(usersFilteredByRole, users[i])
|
||||
}
|
||||
}
|
||||
|
||||
users = usersFilteredByRole
|
||||
}
|
||||
|
||||
return int64(len(users)), nil
|
||||
return filteredUsers, nil
|
||||
}
|
||||
|
|
|
@ -321,13 +321,6 @@ func (m metricsStore) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([
|
|||
return rows, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
|
||||
start := time.Now()
|
||||
count, err := m.s.GetFilteredUserCount(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetFilteredUserCount").Observe(time.Since(start).Seconds())
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
|
||||
start := time.Now()
|
||||
link, err := m.s.GetGitAuthLink(ctx, arg)
|
||||
|
@ -1639,9 +1632,9 @@ func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
|||
return workspaces, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
func (m metricsStore) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
|
||||
start := time.Now()
|
||||
count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds())
|
||||
return count, err
|
||||
r0, r1 := m.s.GetAuthorizedUsers(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedUsers").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
|
|
@ -431,19 +431,19 @@ func (mr *MockStoreMockRecorder) GetAuthorizedTemplates(arg0, arg1, arg2 interfa
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedTemplates", reflect.TypeOf((*MockStore)(nil).GetAuthorizedTemplates), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetAuthorizedUserCount mocks base method.
|
||||
func (m *MockStore) GetAuthorizedUserCount(arg0 context.Context, arg1 database.GetFilteredUserCountParams, arg2 rbac.PreparedAuthorized) (int64, error) {
|
||||
// GetAuthorizedUsers mocks base method.
|
||||
func (m *MockStore) GetAuthorizedUsers(arg0 context.Context, arg1 database.GetUsersParams, arg2 rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizedUserCount", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret := m.ctrl.Call(m, "GetAuthorizedUsers", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].([]database.GetUsersRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthorizedUserCount indicates an expected call of GetAuthorizedUserCount.
|
||||
func (mr *MockStoreMockRecorder) GetAuthorizedUserCount(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
// GetAuthorizedUsers indicates an expected call of GetAuthorizedUsers.
|
||||
func (mr *MockStoreMockRecorder) GetAuthorizedUsers(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUserCount", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUserCount), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUsers", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUsers), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetAuthorizedWorkspaces mocks base method.
|
||||
|
@ -596,21 +596,6 @@ func (mr *MockStoreMockRecorder) GetFileTemplates(arg0, arg1 interface{}) *gomoc
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFileTemplates", reflect.TypeOf((*MockStore)(nil).GetFileTemplates), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetFilteredUserCount mocks base method.
|
||||
func (m *MockStore) GetFilteredUserCount(arg0 context.Context, arg1 database.GetFilteredUserCountParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetFilteredUserCount", arg0, arg1)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetFilteredUserCount indicates an expected call of GetFilteredUserCount.
|
||||
func (mr *MockStoreMockRecorder) GetFilteredUserCount(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredUserCount", reflect.TypeOf((*MockStore)(nil).GetFilteredUserCount), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetGitAuthLink mocks base method.
|
||||
func (m *MockStore) GetGitAuthLink(arg0 context.Context, arg1 database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -255,29 +255,66 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
|||
}
|
||||
|
||||
type userQuerier interface {
|
||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
|
||||
func (q *sqlQuerier) GetAuthorizedUsers(ctx context.Context, arg GetUsersParams, prepared rbac.PreparedAuthorized) ([]GetUsersRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.UserConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
filtered, err := insertAuthorizedFilter(getUsers, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
|
||||
row := q.db.QueryRowContext(ctx, query,
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUsers :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.AfterID,
|
||||
arg.Search,
|
||||
pq.Array(arg.Status),
|
||||
pq.Array(arg.RbacRole),
|
||||
arg.LastSeenBefore,
|
||||
arg.LastSeenAfter,
|
||||
arg.OffsetOpt,
|
||||
arg.LimitOpt,
|
||||
)
|
||||
var count int64
|
||||
err = row.Scan(&count)
|
||||
return count, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetUsersRow
|
||||
for rows.Next() {
|
||||
var i GetUsersRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.Username,
|
||||
&i.HashedPassword,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
&i.RBACRoles,
|
||||
&i.LoginType,
|
||||
&i.AvatarURL,
|
||||
&i.Deleted,
|
||||
&i.LastSeenAt,
|
||||
&i.Count,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||
|
|
|
@ -65,8 +65,6 @@ type sqlcQuerier interface {
|
|||
GetFileByID(ctx context.Context, id uuid.UUID) (File, error)
|
||||
// Get all templates that use a file.
|
||||
GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]GetFileTemplatesRow, error)
|
||||
// This will never count deleted users.
|
||||
GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error)
|
||||
GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error)
|
||||
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
|
||||
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
|
||||
|
|
|
@ -5108,55 +5108,6 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.
|
|||
return i, err
|
||||
}
|
||||
|
||||
const getFilteredUserCount = `-- name: GetFilteredUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
users.deleted = false
|
||||
-- Start filters
|
||||
-- Filter by name, email or username
|
||||
AND CASE
|
||||
WHEN $1 :: text != '' THEN (
|
||||
email ILIKE concat('%', $1, '%')
|
||||
OR username ILIKE concat('%', $1, '%')
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by status
|
||||
AND CASE
|
||||
-- @status needs to be a text because it can be empty, If it was
|
||||
-- user_status enum, it would not.
|
||||
WHEN cardinality($2 :: user_status[]) > 0 THEN
|
||||
status = ANY($2 :: user_status[])
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by rbac_roles
|
||||
AND CASE
|
||||
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
|
||||
WHEN cardinality($3 :: text[]) > 0 AND 'member' != ANY($3 :: text[])
|
||||
THEN rbac_roles && $3 :: text[]
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
|
||||
-- @authorize_filter
|
||||
`
|
||||
|
||||
type GetFilteredUserCountParams struct {
|
||||
Search string `db:"search" json:"search"`
|
||||
Status []UserStatus `db:"status" json:"status"`
|
||||
RbacRole []string `db:"rbac_role" json:"rbac_role"`
|
||||
}
|
||||
|
||||
// This will never count deleted users.
|
||||
func (q *sqlQuerier) GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error) {
|
||||
row := q.db.QueryRowContext(ctx, getFilteredUserCount, arg.Search, pq.Array(arg.Status), pq.Array(arg.RbacRole))
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
|
||||
SELECT
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at
|
||||
|
@ -5304,6 +5255,9 @@ WHERE
|
|||
ELSE true
|
||||
END
|
||||
-- End of filters
|
||||
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedUsers
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all users. This is to ensure consistent pagination.
|
||||
LOWER(username) ASC OFFSET $7
|
||||
|
|
|
@ -56,42 +56,6 @@ FROM
|
|||
WHERE
|
||||
status = 'active'::user_status AND deleted = false;
|
||||
|
||||
-- name: GetFilteredUserCount :one
|
||||
-- This will never count deleted users.
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
users.deleted = false
|
||||
-- Start filters
|
||||
-- Filter by name, email or username
|
||||
AND CASE
|
||||
WHEN @search :: text != '' THEN (
|
||||
email ILIKE concat('%', @search, '%')
|
||||
OR username ILIKE concat('%', @search, '%')
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by status
|
||||
AND CASE
|
||||
-- @status needs to be a text because it can be empty, If it was
|
||||
-- user_status enum, it would not.
|
||||
WHEN cardinality(@status :: user_status[]) > 0 THEN
|
||||
status = ANY(@status :: user_status[])
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by rbac_roles
|
||||
AND CASE
|
||||
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
|
||||
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[])
|
||||
THEN rbac_roles && @rbac_role :: text[]
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
|
||||
-- @authorize_filter
|
||||
;
|
||||
|
||||
-- name: InsertUser :one
|
||||
INSERT INTO
|
||||
users (
|
||||
|
@ -208,6 +172,9 @@ WHERE
|
|||
ELSE true
|
||||
END
|
||||
-- End of filters
|
||||
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedUsers
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all users. This is to ensure consistent pagination.
|
||||
LOWER(username) ASC OFFSET @offset_opt
|
||||
|
|
|
@ -242,6 +242,26 @@ neq(input.object.owner, "");
|
|||
p("false")),
|
||||
VariableConverter: regosql.TemplateConverter(),
|
||||
},
|
||||
{
|
||||
Name: "UserNoOrgOwner",
|
||||
Queries: []string{
|
||||
`input.object.org_owner != ""`,
|
||||
},
|
||||
ExpectedSQL: p("'' != ''"),
|
||||
VariableConverter: regosql.UserConverter(),
|
||||
},
|
||||
{
|
||||
Name: "UserOwnsSelf",
|
||||
Queries: []string{
|
||||
`"10d03e62-7703-4df5-a358-4f76577d4e2f" = input.object.owner;
|
||||
input.object.owner != "";
|
||||
input.object.org_owner = ""`,
|
||||
},
|
||||
VariableConverter: regosql.UserConverter(),
|
||||
ExpectedSQL: p(
|
||||
p("'10d03e62-7703-4df5-a358-4f76577d4e2f' = ''") + " AND " + p("'' != ''") + " AND " + p("'' = ''"),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
|
|
@ -36,6 +36,23 @@ func TemplateConverter() *sqltypes.VariableConverter {
|
|||
return matcher
|
||||
}
|
||||
|
||||
func UserConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
// Users are never owned by an organization, so always return the empty string
|
||||
// for the org owner.
|
||||
sqltypes.StringVarMatcher("''", []string{"input", "object", "org_owner"}),
|
||||
// Users never have an owner, and are only owned site wide.
|
||||
sqltypes.StringVarMatcher("''", []string{"input", "object", "owner"}),
|
||||
)
|
||||
matcher.RegisterMatcher(
|
||||
// No ACLs on the user type
|
||||
sqltypes.AlwaysFalse(groupACLMatcher(matcher)),
|
||||
sqltypes.AlwaysFalse(userACLMatcher(matcher)),
|
||||
)
|
||||
return matcher
|
||||
}
|
||||
|
||||
// NoACLConverter should be used when the target SQL table does not contain
|
||||
// group or user ACL columns.
|
||||
func NoACLConverter() *sqltypes.VariableConverter {
|
||||
|
|
Loading…
Reference in New Issue