feat: set groupsync to use default org (#12146)

* fix: assign new oauth users to default org

This is not a final solution, as we eventually want to be able
to map to different orgs. This makes it so multi-org does not break oauth/oidc.
This commit is contained in:
Steven Masley 2024-02-16 11:09:19 -06:00 committed by GitHub
parent dbaafc863c
commit f17149c59d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 136 additions and 123 deletions

View File

@ -134,7 +134,7 @@ type Options struct {
BaseDERPMap *tailcfg.DERPMap
DERPMapUpdateFrequency time.Duration
SwaggerEndpoint bool
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
@ -301,9 +301,11 @@ func New(options *Options) *API {
options.TracerProvider = trace.NewNoopTracerProvider()
}
if options.SetUserGroups == nil {
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, groups []string, createMissingGroups bool) error {
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
slog.F("user_id", userID), slog.F("groups", groups), slog.F("create_missing_groups", createMissingGroups),
slog.F("user_id", userID),
slog.F("groups", orgGroupNames),
slog.F("create_missing_groups", createMissingGroups),
)
return nil
}

View File

@ -793,16 +793,6 @@ func (q *querier) DeleteGroupMemberFromGroup(ctx context.Context, arg database.D
return update(q.log, q.auth, fetch, q.db.DeleteGroupMemberFromGroup)(ctx, arg)
}
func (q *querier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error {
// This will remove the user from all groups in the org. This counts as updating a group.
// NOTE: instead of fetching all groups in the org with arg.UserID as a member, we instead
// check if the caller has permission to update any group in the org.
fetch := func(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) (rbac.Objecter, error) {
return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil
}
return update(q.log, q.auth, fetch, q.db.DeleteGroupMembersByOrgAndUser)(ctx, arg)
}
func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) {
err := deleteQ(q.log, q.auth, q.db.GetLicenseByID, func(ctx context.Context, id int32) error {
_, err := q.db.DeleteLicense(ctx, id)
@ -2555,6 +2545,14 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
}
func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
// This is a system function to clear user groups in group sync.
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.RemoveUserFromAllGroups(ctx, userID)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
return err

View File

@ -344,17 +344,14 @@ func (s *MethodTestSuite) TestGroup() {
GroupNames: slice.New(g1.Name, g2.Name),
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns()
}))
s.Run("DeleteGroupMembersByOrgAndUser", s.Subtest(func(db database.Store, check *expects) {
s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u1 := dbgen.User(s.T(), db, database.User{})
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g1.ID, UserID: u1.ID})
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g2.ID, UserID: u1.ID})
check.Args(database.DeleteGroupMembersByOrgAndUserParams{
OrganizationID: o.ID,
UserID: u1.ID,
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), rbac.ActionUpdate).Returns()
check.Args(u1.ID).Asserts(rbac.ResourceSystem, rbac.ActionUpdate).Returns()
}))
s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) {
g := dbgen.Group(s.T(), db, database.Group{})

View File

@ -1135,36 +1135,6 @@ func (q *FakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database
return nil
}
func (q *FakeQuerier) DeleteGroupMembersByOrgAndUser(_ context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()
newMembers := q.groupMembers[:0]
for _, member := range q.groupMembers {
if member.UserID != arg.UserID {
// Do not delete the other members
newMembers = append(newMembers, member)
} else if member.UserID == arg.UserID {
// We only want to delete from groups in the organization in the args.
for _, group := range q.groups {
// Find the group that the member is apartof.
if group.ID == member.GroupID {
// Only add back the member if the organization ID does not match
// the arg organization ID. Since the arg is saying which
// org to delete.
if group.OrganizationID != arg.OrganizationID {
newMembers = append(newMembers, member)
}
break
}
}
}
}
q.groupMembers = newMembers
return nil
}
func (q *FakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
@ -6096,6 +6066,22 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
return database.WorkspaceProxy{}, sql.ErrNoRows
}
func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error {
q.mutex.Lock()
defer q.mutex.Unlock()
newMembers := q.groupMembers[:0]
for _, member := range q.groupMembers {
if member.UserID == userID {
continue
}
newMembers = append(newMembers, member)
}
q.groupMembers = newMembers
return nil
}
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
q.mutex.Lock()
defer q.mutex.Unlock()

View File

@ -211,13 +211,6 @@ func (m metricsStore) DeleteGroupMemberFromGroup(ctx context.Context, arg databa
return err
}
func (m metricsStore) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error {
start := time.Now()
err := m.s.DeleteGroupMembersByOrgAndUser(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteGroupMembersByOrgAndUser").Observe(time.Since(start).Seconds())
return err
}
func (m metricsStore) DeleteLicense(ctx context.Context, id int32) (int32, error) {
start := time.Now()
licenseID, err := m.s.DeleteLicense(ctx, id)
@ -1642,6 +1635,13 @@ func (m metricsStore) RegisterWorkspaceProxy(ctx context.Context, arg database.R
return proxy, err
}
func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
start := time.Now()
r0 := m.s.RemoveUserFromAllGroups(ctx, userID)
m.queryLatencies.WithLabelValues("RemoveUserFromAllGroups").Observe(time.Since(start).Seconds())
return r0
}
func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
start := time.Now()
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)

View File

@ -313,20 +313,6 @@ func (mr *MockStoreMockRecorder) DeleteGroupMemberFromGroup(arg0, arg1 any) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroupMemberFromGroup", reflect.TypeOf((*MockStore)(nil).DeleteGroupMemberFromGroup), arg0, arg1)
}
// DeleteGroupMembersByOrgAndUser mocks base method.
func (m *MockStore) DeleteGroupMembersByOrgAndUser(arg0 context.Context, arg1 database.DeleteGroupMembersByOrgAndUserParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteGroupMembersByOrgAndUser", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteGroupMembersByOrgAndUser indicates an expected call of DeleteGroupMembersByOrgAndUser.
func (mr *MockStoreMockRecorder) DeleteGroupMembersByOrgAndUser(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroupMembersByOrgAndUser", reflect.TypeOf((*MockStore)(nil).DeleteGroupMembersByOrgAndUser), arg0, arg1)
}
// DeleteLicense mocks base method.
func (m *MockStore) DeleteLicense(arg0 context.Context, arg1 int32) (int32, error) {
m.ctrl.T.Helper()
@ -3470,6 +3456,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1)
}
// RemoveUserFromAllGroups mocks base method.
func (m *MockStore) RemoveUserFromAllGroups(arg0 context.Context, arg1 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveUserFromAllGroups", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveUserFromAllGroups indicates an expected call of RemoveUserFromAllGroups.
func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1)
}
// RevokeDBCryptKey mocks base method.
func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()

View File

@ -58,7 +58,6 @@ type sqlcQuerier interface {
DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error
DeleteGroupByID(ctx context.Context, id uuid.UUID) error
DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error
DeleteGroupMembersByOrgAndUser(ctx context.Context, arg DeleteGroupMembersByOrgAndUserParams) error
DeleteLicense(ctx context.Context, id int32) (int32, error)
DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error
DeleteOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) error
@ -322,6 +321,7 @@ type sqlcQuerier interface {
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
//

View File

@ -1288,24 +1288,6 @@ func (q *sqlQuerier) DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteG
return err
}
const deleteGroupMembersByOrgAndUser = `-- name: DeleteGroupMembersByOrgAndUser :exec
DELETE FROM
group_members
WHERE
group_members.user_id = $1
AND group_id = ANY(SELECT id FROM groups WHERE organization_id = $2)
`
type DeleteGroupMembersByOrgAndUserParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
}
func (q *sqlQuerier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg DeleteGroupMembersByOrgAndUserParams) error {
_, err := q.db.ExecContext(ctx, deleteGroupMembersByOrgAndUser, arg.UserID, arg.OrganizationID)
return err
}
const getGroupMembers = `-- name: GetGroupMembers :many
SELECT
users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at, users.quiet_hours_schedule, users.theme_preference, users.name
@ -1419,6 +1401,18 @@ func (q *sqlQuerier) InsertUserGroupsByName(ctx context.Context, arg InsertUserG
return err
}
const removeUserFromAllGroups = `-- name: RemoveUserFromAllGroups :exec
DELETE FROM
group_members
WHERE
user_id = $1
`
func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
_, err := q.db.ExecContext(ctx, removeUserFromAllGroups, userID)
return err
}
const deleteGroupByID = `-- name: DeleteGroupByID :exec
DELETE FROM
groups

View File

@ -42,12 +42,11 @@ SELECT
FROM
groups;
-- name: DeleteGroupMembersByOrgAndUser :exec
-- name: RemoveUserFromAllGroups :exec
DELETE FROM
group_members
WHERE
group_members.user_id = @user_id
AND group_id = ANY(SELECT id FROM groups WHERE organization_id = @organization_id);
user_id = @user_id;
-- name: InsertGroupMember :exec
INSERT INTO

View File

@ -20,6 +20,7 @@ import (
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
@ -1217,8 +1218,10 @@ type oauthLoginParams struct {
// to the Groups provided.
UsingGroups bool
CreateMissingGroups bool
Groups []string
GroupFilter *regexp.Regexp
// These are the group names from the IDP. Internally, they will map to
// some organization groups.
Groups []string
GroupFilter *regexp.Regexp
// Is UsingRoles is true, then the user will be assigned
// the roles provided.
UsingRoles bool
@ -1301,7 +1304,6 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
link database.UserLink
err error
)
user = params.User
link = params.Link
@ -1460,6 +1462,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}
// Ensure groups are correct.
// This places all groups into the default organization.
// To go multi-org, we need to add a mapping feature here to know which
// groups go to which orgs.
if params.UsingGroups {
filtered := params.Groups
if params.GroupFilter != nil {
@ -1471,8 +1476,32 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}
}
//nolint:gocritic // No user present in the context.
defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
if err != nil {
// If there is no default org, then we can't assign groups.
// By default, we assume all groups belong to the default org.
return xerrors.Errorf("get default organization: %w", err)
}
//nolint:gocritic // No user present in the context.
memberships, err := tx.GetOrganizationMembershipsByUserID(dbauthz.AsSystemRestricted(ctx), user.ID)
if err != nil {
return xerrors.Errorf("get organization memberships: %w", err)
}
// If the user is not in the default organization, then we can't assign groups.
// A user cannot be in groups to an org they are not a member of.
if !slices.ContainsFunc(memberships, func(member database.OrganizationMember) bool {
return member.OrganizationID == defaultOrganization.ID
}) {
return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID)
}
//nolint:gocritic
err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered, params.CreateMissingGroups)
err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{
defaultOrganization.ID: filtered,
}, params.CreateMissingGroups)
if err != nil {
return xerrors.Errorf("set user groups: %w", err)
}

View File

@ -14,7 +14,7 @@ import (
)
// nolint: revive
func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error {
func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
api.entitlementsMu.RLock()
enabled := api.entitlements.Features[codersdk.FeatureTemplateRBAC].Enabled
api.entitlementsMu.RUnlock()
@ -24,6 +24,8 @@ func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db databa
}
return db.InTx(func(tx database.Store) error {
// When setting the user's groups, it's easier to just clear their groups and re-add them.
// This ensures that the user's groups are always in sync with the auth provider.
orgs, err := tx.GetOrganizationsByUserID(ctx, userID)
if err != nil {
return xerrors.Errorf("get user orgs: %w", err)
@ -33,41 +35,47 @@ func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db databa
}
// Delete all groups the user belongs to.
err = tx.DeleteGroupMembersByOrgAndUser(ctx, database.DeleteGroupMembersByOrgAndUserParams{
UserID: userID,
OrganizationID: orgs[0].ID,
})
// nolint:gocritic // Requires system context to remove user from all groups.
err = tx.RemoveUserFromAllGroups(dbauthz.AsSystemRestricted(ctx), userID)
if err != nil {
return xerrors.Errorf("delete user groups: %w", err)
}
if createMissingGroups {
// This is the system creating these additional groups, so we use the system restricted context.
// nolint:gocritic
created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{
OrganizationID: orgs[0].ID,
// TODO: This could likely be improved by making these single queries.
// Either by batching or some other means. This for loop could be really
// inefficient if there are a lot of organizations. There was deployments
// on v1 with >100 orgs.
for orgID, groupNames := range orgGroupNames {
// Create the missing groups for each organization.
if createMissingGroups {
// This is the system creating these additional groups, so we use the system restricted context.
// nolint:gocritic
created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{
OrganizationID: orgID,
GroupNames: groupNames,
Source: database.GroupSourceOidc,
})
if err != nil {
return xerrors.Errorf("insert missing groups: %w", err)
}
if len(created) > 0 {
logger.Debug(ctx, "auto created missing groups",
slog.F("org_id", orgID.ID),
slog.F("created", created),
slog.F("num", len(created)),
)
}
}
// Re-add the user to all groups returned by the auth provider.
err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{
UserID: userID,
OrganizationID: orgID,
GroupNames: groupNames,
Source: database.GroupSourceOidc,
})
if err != nil {
return xerrors.Errorf("insert missing groups: %w", err)
return xerrors.Errorf("insert user groups: %w", err)
}
if len(created) > 0 {
logger.Debug(ctx, "auto created missing groups",
slog.F("org_id", orgs[0].ID),
slog.F("created", created),
)
}
}
// Re-add the user to all groups returned by the auth provider.
err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{
UserID: userID,
OrganizationID: orgs[0].ID,
GroupNames: groupNames,
})
if err != nil {
return xerrors.Errorf("insert user groups: %w", err)
}
return nil