mirror of https://github.com/coder/coder.git
fix: use unique ID for linked accounts (#3441)
- move OAuth-related fields off of api_keys into a new user_links table - restrict users to single form of login - process updates to user email/usernames for OIDC - added a login_type column to users
This commit is contained in:
parent
53d1fb36db
commit
c3eea98db0
|
@ -94,6 +94,7 @@ var AuditableResources = auditMap(map[any]map[string]Action{
|
|||
"updated_at": ActionIgnore, // Changes, but is implicit and not helpful in a diff.
|
||||
"status": ActionTrack,
|
||||
"rbac_roles": ActionTrack,
|
||||
"login_type": ActionIgnore,
|
||||
},
|
||||
&database.Workspace{}: {
|
||||
"id": ActionTrack,
|
||||
|
|
|
@ -73,6 +73,7 @@ type data struct {
|
|||
organizations []database.Organization
|
||||
organizationMembers []database.OrganizationMember
|
||||
users []database.User
|
||||
userLinks []database.UserLink
|
||||
|
||||
// New tables
|
||||
auditLogs []database.AuditLog
|
||||
|
@ -1454,20 +1455,16 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
|
|||
|
||||
//nolint:gosimple
|
||||
key := database.APIKey{
|
||||
ID: arg.ID,
|
||||
LifetimeSeconds: arg.LifetimeSeconds,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
IPAddress: arg.IPAddress,
|
||||
UserID: arg.UserID,
|
||||
ExpiresAt: arg.ExpiresAt,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
LastUsed: arg.LastUsed,
|
||||
LoginType: arg.LoginType,
|
||||
OAuthAccessToken: arg.OAuthAccessToken,
|
||||
OAuthRefreshToken: arg.OAuthRefreshToken,
|
||||
OAuthIDToken: arg.OAuthIDToken,
|
||||
OAuthExpiry: arg.OAuthExpiry,
|
||||
ID: arg.ID,
|
||||
LifetimeSeconds: arg.LifetimeSeconds,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
IPAddress: arg.IPAddress,
|
||||
UserID: arg.UserID,
|
||||
ExpiresAt: arg.ExpiresAt,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
LastUsed: arg.LastUsed,
|
||||
LoginType: arg.LoginType,
|
||||
}
|
||||
q.apiKeys = append(q.apiKeys, key)
|
||||
return key, nil
|
||||
|
@ -1744,6 +1741,7 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
|
|||
Username: arg.Username,
|
||||
Status: database.UserStatusActive,
|
||||
RBACRoles: arg.RBACRoles,
|
||||
LoginType: arg.LoginType,
|
||||
}
|
||||
q.users = append(q.users, user)
|
||||
return user, nil
|
||||
|
@ -1899,9 +1897,6 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI
|
|||
apiKey.LastUsed = arg.LastUsed
|
||||
apiKey.ExpiresAt = arg.ExpiresAt
|
||||
apiKey.IPAddress = arg.IPAddress
|
||||
apiKey.OAuthAccessToken = arg.OAuthAccessToken
|
||||
apiKey.OAuthRefreshToken = arg.OAuthRefreshToken
|
||||
apiKey.OAuthExpiry = arg.OAuthExpiry
|
||||
q.apiKeys[index] = apiKey
|
||||
return nil
|
||||
}
|
||||
|
@ -2260,3 +2255,80 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) {
|
|||
|
||||
return q.deploymentID, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, link := range q.userLinks {
|
||||
if link.LinkedID == id {
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
//nolint:gosimple
|
||||
link := database.UserLink{
|
||||
UserID: args.UserID,
|
||||
LoginType: args.LoginType,
|
||||
LinkedID: args.LinkedID,
|
||||
OAuthAccessToken: args.OAuthAccessToken,
|
||||
OAuthRefreshToken: args.OAuthRefreshToken,
|
||||
OAuthExpiry: args.OAuthExpiry,
|
||||
}
|
||||
|
||||
q.userLinks = append(q.userLinks, link)
|
||||
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for i, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
link.LinkedID = params.LinkedID
|
||||
|
||||
q.userLinks[i] = link
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for i, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
link.OAuthAccessToken = params.OAuthAccessToken
|
||||
link.OAuthRefreshToken = params.OAuthRefreshToken
|
||||
link.OAuthExpiry = params.OAuthExpiry
|
||||
|
||||
q.userLinks[i] = link
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ func TestNestedInTx(t *testing.T) {
|
|||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
RBACRoles: []string{},
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
return err
|
||||
})
|
||||
|
|
|
@ -96,10 +96,6 @@ CREATE TABLE api_keys (
|
|||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
oauth_access_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_id_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
lifetime_seconds bigint DEFAULT 86400 NOT NULL,
|
||||
ip_address inet DEFAULT '0.0.0.0'::inet NOT NULL
|
||||
);
|
||||
|
@ -267,6 +263,15 @@ CREATE TABLE templates (
|
|||
created_by uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE user_links (
|
||||
user_id uuid NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
linked_id text DEFAULT ''::text NOT NULL,
|
||||
oauth_access_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE users (
|
||||
id uuid NOT NULL,
|
||||
email text NOT NULL,
|
||||
|
@ -275,7 +280,8 @@ CREATE TABLE users (
|
|||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
status user_status DEFAULT 'active'::public.user_status NOT NULL,
|
||||
rbac_roles text[] DEFAULT '{}'::text[] NOT NULL
|
||||
rbac_roles text[] DEFAULT '{}'::text[] NOT NULL,
|
||||
login_type login_type DEFAULT 'password'::public.login_type NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE workspace_agents (
|
||||
|
@ -416,6 +422,9 @@ ALTER TABLE ONLY template_versions
|
|||
ALTER TABLE ONLY templates
|
||||
ADD CONSTRAINT templates_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY user_links
|
||||
ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
|
||||
|
||||
ALTER TABLE ONLY users
|
||||
ADD CONSTRAINT users_pkey PRIMARY KEY (id);
|
||||
|
||||
|
@ -513,6 +522,9 @@ ALTER TABLE ONLY templates
|
|||
ALTER TABLE ONLY templates
|
||||
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_links
|
||||
ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY workspace_agents
|
||||
ADD CONSTRAINT workspace_agents_resource_id_fkey FOREIGN KEY (resource_id) REFERENCES workspace_resources(id) ON DELETE CASCADE;
|
||||
|
||||
|
|
|
@ -13,6 +13,8 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
|
|||
(
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Dump the updated schema.
|
||||
go run dump/main.go
|
||||
# The logic below depends on the exact version being correct :(
|
||||
go run github.com/kyleconroy/sqlc/cmd/sqlc@v1.13.0 generate
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
-- This migration makes no attempt to try to populate
|
||||
-- the oauth_access_token, oauth_refresh_token, and oauth_expiry
|
||||
-- columns of api_key rows with the values from the dropped user_links
|
||||
-- table.
|
||||
BEGIN;
|
||||
|
||||
DROP TABLE IF EXISTS user_links;
|
||||
|
||||
ALTER TABLE
|
||||
api_keys
|
||||
ADD COLUMN oauth_access_token text DEFAULT ''::text NOT NULL;
|
||||
|
||||
ALTER TABLE
|
||||
api_keys
|
||||
ADD COLUMN oauth_refresh_token text DEFAULT ''::text NOT NULL;
|
||||
|
||||
ALTER TABLE
|
||||
api_keys
|
||||
ADD COLUMN oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL;
|
||||
|
||||
ALTER TABLE users DROP COLUMN login_type;
|
||||
|
||||
COMMIT;
|
|
@ -0,0 +1,74 @@
|
|||
BEGIN;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_links (
|
||||
user_id uuid NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
linked_id text DEFAULT ''::text NOT NULL,
|
||||
oauth_access_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
PRIMARY KEY(user_id, login_type),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- This migrates columns on api_keys to the new user_links table.
|
||||
-- It does this by finding all the API keys for each user, choosing
|
||||
-- the most recently updated for each one and then assigning its relevant
|
||||
-- values to the user_links table.
|
||||
-- A user should at most have a row for an OIDC account and a Github account.
|
||||
-- 'password' login types are ignored.
|
||||
|
||||
INSERT INTO user_links
|
||||
(
|
||||
user_id,
|
||||
login_type,
|
||||
linked_id,
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_expiry
|
||||
)
|
||||
SELECT
|
||||
keys.user_id,
|
||||
keys.login_type,
|
||||
'',
|
||||
keys.oauth_access_token,
|
||||
keys.oauth_refresh_token,
|
||||
keys.oauth_expiry
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
row_number() OVER (partition by user_id, login_type ORDER BY last_used DESC) AS x,
|
||||
api_keys.* FROM api_keys
|
||||
) as keys
|
||||
WHERE x=1 AND keys.login_type != 'password';
|
||||
|
||||
-- Drop columns that have been migrated to user_links.
|
||||
-- It appears the 'oauth_id_token' was unused and so it has
|
||||
-- been dropped here as well to avoid future confusion.
|
||||
ALTER TABLE api_keys
|
||||
DROP COLUMN oauth_access_token,
|
||||
DROP COLUMN oauth_refresh_token,
|
||||
DROP COLUMN oauth_id_token,
|
||||
DROP COLUMN oauth_expiry;
|
||||
|
||||
ALTER TABLE users ADD COLUMN login_type login_type NOT NULL DEFAULT 'password';
|
||||
|
||||
UPDATE
|
||||
users
|
||||
SET
|
||||
login_type = (
|
||||
SELECT
|
||||
login_type
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
user_links.user_id = users.id
|
||||
ORDER BY oauth_expiry DESC
|
||||
LIMIT 1
|
||||
)
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
user_links.user_id = users.id;
|
||||
|
||||
COMMIT;
|
|
@ -313,20 +313,16 @@ func (e *WorkspaceTransition) Scan(src interface{}) error {
|
|||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"`
|
||||
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"`
|
||||
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
|
||||
}
|
||||
|
||||
type AuditLog struct {
|
||||
|
@ -491,6 +487,16 @@ type User struct {
|
|||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Status UserStatus `db:"status" json:"status"`
|
||||
RBACRoles []string `db:"rbac_roles" json:"rbac_roles"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
}
|
||||
|
||||
type UserLink struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
LinkedID string `db:"linked_id" json:"linked_id"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
type Workspace struct {
|
||||
|
|
|
@ -64,6 +64,8 @@ type querier interface {
|
|||
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserCount(ctx context.Context) (int64, error)
|
||||
GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error)
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error)
|
||||
GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error)
|
||||
GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error)
|
||||
|
@ -106,6 +108,7 @@ type querier interface {
|
|||
InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error)
|
||||
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error)
|
||||
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
|
||||
InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error)
|
||||
InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error)
|
||||
InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error)
|
||||
InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error)
|
||||
|
@ -127,6 +130,8 @@ type querier interface {
|
|||
UpdateTemplateVersionByID(ctx context.Context, arg UpdateTemplateVersionByIDParams) error
|
||||
UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg UpdateTemplateVersionDescriptionByJobIDParams) error
|
||||
UpdateUserHashedPassword(ctx context.Context, arg UpdateUserHashedPasswordParams) error
|
||||
UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error)
|
||||
UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error)
|
||||
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
|
||||
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
|
||||
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
|
||||
|
|
|
@ -31,7 +31,7 @@ func (q *sqlQuerier) DeleteAPIKeyByID(ctx context.Context, id string) error {
|
|||
|
||||
const getAPIKeyByID = `-- name: GetAPIKeyByID :one
|
||||
SELECT
|
||||
id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address
|
||||
id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address
|
||||
FROM
|
||||
api_keys
|
||||
WHERE
|
||||
|
@ -52,10 +52,6 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
|
|||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.LoginType,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthIDToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.LifetimeSeconds,
|
||||
&i.IPAddress,
|
||||
)
|
||||
|
@ -63,7 +59,7 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
|
|||
}
|
||||
|
||||
const getAPIKeysLastUsedAfter = `-- name: GetAPIKeysLastUsedAfter :many
|
||||
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1
|
||||
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) {
|
||||
|
@ -84,10 +80,6 @@ func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.
|
|||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.LoginType,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthIDToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.LifetimeSeconds,
|
||||
&i.IPAddress,
|
||||
); err != nil {
|
||||
|
@ -116,11 +108,7 @@ INSERT INTO
|
|||
expires_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
login_type,
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_id_token,
|
||||
oauth_expiry
|
||||
login_type
|
||||
)
|
||||
VALUES
|
||||
($1,
|
||||
|
@ -129,24 +117,20 @@ VALUES
|
|||
WHEN 0 THEN 86400
|
||||
ELSE $2::bigint
|
||||
END
|
||||
, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address
|
||||
, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address
|
||||
`
|
||||
|
||||
type InsertAPIKeyParams struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
ID string `db:"id" json:"id"`
|
||||
LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) {
|
||||
|
@ -161,10 +145,6 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
|
|||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.LoginType,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthIDToken,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
var i APIKey
|
||||
err := row.Scan(
|
||||
|
@ -176,10 +156,6 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
|
|||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.LoginType,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthIDToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.LifetimeSeconds,
|
||||
&i.IPAddress,
|
||||
)
|
||||
|
@ -192,22 +168,16 @@ UPDATE
|
|||
SET
|
||||
last_used = $2,
|
||||
expires_at = $3,
|
||||
ip_address = $4,
|
||||
oauth_access_token = $5,
|
||||
oauth_refresh_token = $6,
|
||||
oauth_expiry = $7
|
||||
ip_address = $4
|
||||
WHERE
|
||||
id = $1
|
||||
`
|
||||
|
||||
type UpdateAPIKeyByIDParams struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
ID string `db:"id" json:"id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error {
|
||||
|
@ -216,9 +186,6 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
|
|||
arg.LastUsed,
|
||||
arg.ExpiresAt,
|
||||
arg.IPAddress,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
@ -2447,6 +2414,169 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context
|
|||
return err
|
||||
}
|
||||
|
||||
const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one
|
||||
SELECT
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
linked_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserLinkByLinkedID, linkedID)
|
||||
var i UserLink
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.LoginType,
|
||||
&i.LinkedID,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one
|
||||
SELECT
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
user_id = $1 AND login_type = $2
|
||||
`
|
||||
|
||||
type GetUserLinkByUserIDLoginTypeParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserLinkByUserIDLoginType, arg.UserID, arg.LoginType)
|
||||
var i UserLink
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.LoginType,
|
||||
&i.LinkedID,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertUserLink = `-- name: InsertUserLink :one
|
||||
INSERT INTO
|
||||
user_links (
|
||||
user_id,
|
||||
login_type,
|
||||
linked_id,
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_expiry
|
||||
)
|
||||
VALUES
|
||||
( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
`
|
||||
|
||||
type InsertUserLinkParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
LinkedID string `db:"linked_id" json:"linked_id"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertUserLink,
|
||||
arg.UserID,
|
||||
arg.LoginType,
|
||||
arg.LinkedID,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
var i UserLink
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.LoginType,
|
||||
&i.LinkedID,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateUserLink = `-- name: UpdateUserLink :one
|
||||
UPDATE
|
||||
user_links
|
||||
SET
|
||||
oauth_access_token = $1,
|
||||
oauth_refresh_token = $2,
|
||||
oauth_expiry = $3
|
||||
WHERE
|
||||
user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
`
|
||||
|
||||
type UpdateUserLinkParams struct {
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserLink,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthExpiry,
|
||||
arg.UserID,
|
||||
arg.LoginType,
|
||||
)
|
||||
var i UserLink
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.LoginType,
|
||||
&i.LinkedID,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateUserLinkedID = `-- name: UpdateUserLinkedID :one
|
||||
UPDATE
|
||||
user_links
|
||||
SET
|
||||
linked_id = $1
|
||||
WHERE
|
||||
user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
`
|
||||
|
||||
type UpdateUserLinkedIDParams struct {
|
||||
LinkedID string `db:"linked_id" json:"linked_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserLinkedID, arg.LinkedID, arg.UserID, arg.LoginType)
|
||||
var i UserLink
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.LoginType,
|
||||
&i.LinkedID,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one
|
||||
SELECT
|
||||
-- username is returned just to help for logging purposes
|
||||
|
@ -2458,13 +2588,13 @@ SELECT
|
|||
array_append(users.rbac_roles, 'member'),
|
||||
-- All org_members get the org-member role for their orgs
|
||||
array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[]
|
||||
AS roles
|
||||
AS roles
|
||||
FROM
|
||||
users
|
||||
LEFT JOIN organization_members
|
||||
ON id = user_id
|
||||
WHERE
|
||||
id = $1
|
||||
id = $1
|
||||
`
|
||||
|
||||
type GetAuthorizationUserRolesRow struct {
|
||||
|
@ -2490,7 +2620,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.
|
|||
|
||||
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
|
||||
SELECT
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
|
@ -2517,13 +2647,14 @@ func (q *sqlQuerier) GetUserByEmailOrUsername(ctx context.Context, arg GetUserBy
|
|||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
pq.Array(&i.RBACRoles),
|
||||
&i.LoginType,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
|
@ -2544,6 +2675,7 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error
|
|||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
pq.Array(&i.RBACRoles),
|
||||
&i.LoginType,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -2564,7 +2696,7 @@ func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) {
|
|||
|
||||
const getUsers = `-- name: GetUsers :many
|
||||
SELECT
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
|
@ -2614,8 +2746,8 @@ WHERE
|
|||
END
|
||||
-- End of filters
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all users, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
-- Deterministic and consistent ordering of all users, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(created_at, id) ASC OFFSET $5
|
||||
LIMIT
|
||||
-- A null limit means "no limit", so 0 means return all
|
||||
|
@ -2656,6 +2788,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
|
|||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
pq.Array(&i.RBACRoles),
|
||||
&i.LoginType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -2671,7 +2804,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
|
|||
}
|
||||
|
||||
const getUsersByIDs = `-- name: GetUsersByIDs :many
|
||||
SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE id = ANY($1 :: uuid [ ])
|
||||
SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE id = ANY($1 :: uuid [ ])
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) {
|
||||
|
@ -2692,6 +2825,7 @@ func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User
|
|||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
pq.Array(&i.RBACRoles),
|
||||
&i.LoginType,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -2715,10 +2849,11 @@ INSERT INTO
|
|||
hashed_password,
|
||||
created_at,
|
||||
updated_at,
|
||||
rbac_roles
|
||||
rbac_roles,
|
||||
login_type
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
|
||||
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
|
||||
`
|
||||
|
||||
type InsertUserParams struct {
|
||||
|
@ -2729,6 +2864,7 @@ type InsertUserParams struct {
|
|||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
RBACRoles []string `db:"rbac_roles" json:"rbac_roles"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
|
||||
|
@ -2740,6 +2876,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
|
|||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
pq.Array(arg.RBACRoles),
|
||||
arg.LoginType,
|
||||
)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
|
@ -2751,6 +2888,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
|
|||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
pq.Array(&i.RBACRoles),
|
||||
&i.LoginType,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -2782,7 +2920,7 @@ SET
|
|||
username = $3,
|
||||
updated_at = $4
|
||||
WHERE
|
||||
id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
|
||||
id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
|
||||
`
|
||||
|
||||
type UpdateUserProfileParams struct {
|
||||
|
@ -2809,19 +2947,20 @@ func (q *sqlQuerier) UpdateUserProfile(ctx context.Context, arg UpdateUserProfil
|
|||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
pq.Array(&i.RBACRoles),
|
||||
&i.LoginType,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateUserRoles = `-- name: UpdateUserRoles :one
|
||||
UPDATE
|
||||
users
|
||||
users
|
||||
SET
|
||||
-- Remove all duplicates from the roles.
|
||||
rbac_roles = ARRAY(SELECT DISTINCT UNNEST($1 :: text[]))
|
||||
WHERE
|
||||
id = $2
|
||||
RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
|
||||
id = $2
|
||||
RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
|
||||
`
|
||||
|
||||
type UpdateUserRolesParams struct {
|
||||
|
@ -2841,6 +2980,7 @@ func (q *sqlQuerier) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesPar
|
|||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
pq.Array(&i.RBACRoles),
|
||||
&i.LoginType,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -2852,7 +2992,7 @@ SET
|
|||
status = $2,
|
||||
updated_at = $3
|
||||
WHERE
|
||||
id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
|
||||
id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
|
||||
`
|
||||
|
||||
type UpdateUserStatusParams struct {
|
||||
|
@ -2873,6 +3013,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP
|
|||
&i.UpdatedAt,
|
||||
&i.Status,
|
||||
pq.Array(&i.RBACRoles),
|
||||
&i.LoginType,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
|
|
@ -23,11 +23,7 @@ INSERT INTO
|
|||
expires_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
login_type,
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_id_token,
|
||||
oauth_expiry
|
||||
login_type
|
||||
)
|
||||
VALUES
|
||||
(@id,
|
||||
|
@ -36,7 +32,7 @@ VALUES
|
|||
WHEN 0 THEN 86400
|
||||
ELSE @lifetime_seconds::bigint
|
||||
END
|
||||
, @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type, @oauth_access_token, @oauth_refresh_token, @oauth_id_token, @oauth_expiry) RETURNING *;
|
||||
, @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type) RETURNING *;
|
||||
|
||||
-- name: UpdateAPIKeyByID :exec
|
||||
UPDATE
|
||||
|
@ -44,10 +40,7 @@ UPDATE
|
|||
SET
|
||||
last_used = $2,
|
||||
expires_at = $3,
|
||||
ip_address = $4,
|
||||
oauth_access_token = $5,
|
||||
oauth_refresh_token = $6,
|
||||
oauth_expiry = $7
|
||||
ip_address = $4
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
-- name: GetUserLinkByLinkedID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
linked_id = $1;
|
||||
|
||||
-- name: GetUserLinkByUserIDLoginType :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
user_id = $1 AND login_type = $2;
|
||||
|
||||
-- name: InsertUserLink :one
|
||||
INSERT INTO
|
||||
user_links (
|
||||
user_id,
|
||||
login_type,
|
||||
linked_id,
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_expiry
|
||||
)
|
||||
VALUES
|
||||
( $1, $2, $3, $4, $5, $6 ) RETURNING *;
|
||||
|
||||
-- name: UpdateUserLinkedID :one
|
||||
UPDATE
|
||||
user_links
|
||||
SET
|
||||
linked_id = $1
|
||||
WHERE
|
||||
user_id = $2 AND login_type = $3 RETURNING *;
|
||||
|
||||
-- name: UpdateUserLink :one
|
||||
UPDATE
|
||||
user_links
|
||||
SET
|
||||
oauth_access_token = $1,
|
||||
oauth_refresh_token = $2,
|
||||
oauth_expiry = $3
|
||||
WHERE
|
||||
user_id = $4 AND login_type = $5 RETURNING *;
|
|
@ -37,10 +37,11 @@ INSERT INTO
|
|||
hashed_password,
|
||||
created_at,
|
||||
updated_at,
|
||||
rbac_roles
|
||||
rbac_roles,
|
||||
login_type
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7) RETURNING *;
|
||||
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
|
||||
|
||||
-- name: UpdateUserProfile :one
|
||||
UPDATE
|
||||
|
@ -54,12 +55,12 @@ WHERE
|
|||
|
||||
-- name: UpdateUserRoles :one
|
||||
UPDATE
|
||||
users
|
||||
users
|
||||
SET
|
||||
-- Remove all duplicates from the roles.
|
||||
rbac_roles = ARRAY(SELECT DISTINCT UNNEST(@granted_roles :: text[]))
|
||||
WHERE
|
||||
id = @id
|
||||
id = @id
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateUserHashedPassword :exec
|
||||
|
@ -122,8 +123,8 @@ WHERE
|
|||
END
|
||||
-- End of filters
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all users, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
-- Deterministic and consistent ordering of all users, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(created_at, id) ASC OFFSET @offset_opt
|
||||
LIMIT
|
||||
-- A null limit means "no limit", so 0 means return all
|
||||
|
@ -152,10 +153,10 @@ SELECT
|
|||
array_append(users.rbac_roles, 'member'),
|
||||
-- All org_members get the org-member role for their orgs
|
||||
array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[]
|
||||
AS roles
|
||||
AS roles
|
||||
FROM
|
||||
users
|
||||
LEFT JOIN organization_members
|
||||
ON id = user_id
|
||||
WHERE
|
||||
id = @user_id;
|
||||
id = @user_id;
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tabbed/pqtype"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
|
@ -149,9 +150,21 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
|||
// Tracks if the API key has properties updated!
|
||||
changed := false
|
||||
|
||||
var link database.UserLink
|
||||
if key.LoginType != database.LoginTypePassword {
|
||||
link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
if err != nil {
|
||||
write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred",
|
||||
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Check if the OAuth token is expired!
|
||||
if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() {
|
||||
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() {
|
||||
var oauthConfig OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
|
@ -167,9 +180,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
|||
}
|
||||
// If it is, let's refresh it from the provided config!
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: key.OAuthAccessToken,
|
||||
RefreshToken: key.OAuthRefreshToken,
|
||||
Expiry: key.OAuthExpiry,
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
|
@ -178,9 +191,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
|||
})
|
||||
return
|
||||
}
|
||||
key.OAuthAccessToken = token.AccessToken
|
||||
key.OAuthRefreshToken = token.RefreshToken
|
||||
key.OAuthExpiry = token.Expiry
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
key.ExpiresAt = token.Expiry
|
||||
changed = true
|
||||
}
|
||||
|
@ -222,13 +235,10 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
|||
}
|
||||
if changed {
|
||||
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
OAuthAccessToken: key.OAuthAccessToken,
|
||||
OAuthRefreshToken: key.OAuthRefreshToken,
|
||||
OAuthExpiry: key.OAuthExpiry,
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
})
|
||||
if err != nil {
|
||||
write(http.StatusInternalServerError, codersdk.Response{
|
||||
|
@ -237,6 +247,24 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
|||
})
|
||||
return
|
||||
}
|
||||
// If the API Key is associated with a user_link (e.g. Github/OIDC)
|
||||
// then we want to update the relevant oauth fields.
|
||||
if link.UserID != uuid.Nil {
|
||||
link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
})
|
||||
if err != nil {
|
||||
write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the key is valid, we also fetch the user roles and status.
|
||||
|
|
|
@ -187,6 +187,7 @@ func TestAPIKey(t *testing.T) {
|
|||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
|
@ -215,6 +216,7 @@ func TestAPIKey(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -253,6 +255,7 @@ func TestAPIKey(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -288,6 +291,7 @@ func TestAPIKey(t *testing.T) {
|
|||
LastUsed: database.Now().AddDate(0, 0, -1),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
|
@ -323,6 +327,7 @@ func TestAPIKey(t *testing.T) {
|
|||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
|
@ -361,6 +366,13 @@ func TestAPIKey(t *testing.T) {
|
|||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
|
@ -393,10 +405,16 @@ func TestAPIKey(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeGithub,
|
||||
LastUsed: database.Now(),
|
||||
OAuthExpiry: database.Now().AddDate(0, 0, -1),
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
OAuthExpiry: database.Now().AddDate(0, 0, -1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "wow",
|
||||
RefreshToken: "moo",
|
||||
|
@ -418,7 +436,6 @@ func TestAPIKey(t *testing.T) {
|
|||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
|
||||
require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken)
|
||||
})
|
||||
|
||||
t.Run("RemoteIPUpdates", func(t *testing.T) {
|
||||
|
@ -443,6 +460,7 @@ func TestAPIKey(t *testing.T) {
|
|||
LastUsed: database.Now().AddDate(0, 0, -1),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
|
|
|
@ -124,6 +124,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s
|
|||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ func TestOrganizationParam(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||
|
|
|
@ -53,6 +53,7 @@ func TestTemplateParam(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ func TestTemplateVersionParam(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ func TestUserParam(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ func TestWorkspaceParam(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -73,6 +73,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) {
|
|||
HashedSecret: hashed[:],
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().Add(5 * time.Hour),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = fDB.InsertUser(ctx, database.InsertUserParams{
|
||||
|
|
|
@ -6,12 +6,14 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
|
@ -47,10 +49,13 @@ func (api *API) userAuthMethods(rw http.ResponseWriter, _ *http.Request) {
|
|||
}
|
||||
|
||||
func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
var (
|
||||
ctx = r.Context()
|
||||
state = httpmw.OAuth2(r)
|
||||
)
|
||||
|
||||
oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token))
|
||||
memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient)
|
||||
oauthClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(state.Token))
|
||||
memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(ctx, oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching authenticated Github user organizations.",
|
||||
|
@ -75,7 +80,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient)
|
||||
ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(ctx, oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching authenticated Github user.",
|
||||
|
@ -94,7 +99,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
|||
continue
|
||||
}
|
||||
|
||||
allowedTeam, err = api.GithubOAuth2Config.TeamMembership(r.Context(), oauthClient, allowTeam.Organization, allowTeam.Slug, *ghUser.Login)
|
||||
allowedTeam, err = api.GithubOAuth2Config.TeamMembership(ctx, oauthClient, allowTeam.Organization, allowTeam.Slug, *ghUser.Login)
|
||||
// The calling user may not have permission to the requested team!
|
||||
if err != nil {
|
||||
continue
|
||||
|
@ -108,7 +113,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient)
|
||||
emails, err := api.GithubOAuth2Config.ListEmails(ctx, oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching personal Github user.",
|
||||
|
@ -117,33 +122,35 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var user database.User
|
||||
// Search for existing users with matching and verified emails.
|
||||
// If a verified GitHub email matches a Coder user, we will return.
|
||||
verifiedEmails := make([]string, 0, len(emails))
|
||||
for _, email := range emails {
|
||||
if !email.GetVerified() {
|
||||
continue
|
||||
}
|
||||
user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
Email: *email.Email,
|
||||
verifiedEmails = append(verifiedEmails, email.GetEmail())
|
||||
}
|
||||
|
||||
if len(verifiedEmails) == 0 {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Verify an email address on Github to authenticate!",
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: fmt.Sprintf("Internal error fetching user by email %q.", *email.Email),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !*email.Verified {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
user, link, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "An internal error occurred.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if user.ID != uuid.Nil && user.LoginType != database.LoginTypeGithub {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeGithub, user.LoginType),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// If the user doesn't exist, create a new one!
|
||||
|
@ -177,10 +184,13 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
return
|
||||
}
|
||||
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
|
||||
Email: *verifiedEmail.Email,
|
||||
Username: *ghUser.Login,
|
||||
OrganizationID: organizationID,
|
||||
user, _, err = api.createUser(ctx, createUserRequest{
|
||||
CreateUserRequest: codersdk.CreateUserRequest{
|
||||
Email: *verifiedEmail.Email,
|
||||
Username: *ghUser.Login,
|
||||
OrganizationID: organizationID,
|
||||
},
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
|
@ -191,12 +201,49 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
_, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
OAuthAccessToken: state.Token.AccessToken,
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
// This can happen if a user is a built-in user but is signing in
|
||||
// with Github for the first time.
|
||||
if link.UserID == uuid.Nil {
|
||||
link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
LinkedID: githubLinkedID(ghUser),
|
||||
OAuthAccessToken: state.Token.AccessToken,
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred.",
|
||||
Detail: fmt.Sprintf("insert user link: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// LEGACY: Remove 10/2022.
|
||||
// We started tracking linked IDs later so it's possible for a user to be a
|
||||
// pre-existing Github user and not have a linked ID. The migration
|
||||
// to user_links did not populate this field as it requires calling out
|
||||
// to Github to query the user's ID.
|
||||
if link.LinkedID == "" {
|
||||
link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{
|
||||
UserID: user.ID,
|
||||
LinkedID: githubLinkedID(ghUser),
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred.",
|
||||
Detail: xerrors.Errorf("update user link: %w", err.Error).Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, created := api.createAPIKey(rw, r, createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
|
@ -219,7 +266,10 @@ type OIDCConfig struct {
|
|||
}
|
||||
|
||||
func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
var (
|
||||
ctx = r.Context()
|
||||
state = httpmw.OAuth2(r)
|
||||
)
|
||||
|
||||
// See the example here: https://github.com/coreos/go-oidc
|
||||
rawIDToken, ok := state.Token.Extra("id_token").(string)
|
||||
|
@ -230,7 +280,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
idToken, err := api.OIDCConfig.Verifier.Verify(r.Context(), rawIDToken)
|
||||
idToken, err := api.OIDCConfig.Verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to verify OIDC token.",
|
||||
|
@ -285,29 +335,48 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
var user database.User
|
||||
user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
Email: claims.Email,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if !api.OIDCConfig.AllowSignups {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Signups are disabled for OIDC authentication!",
|
||||
})
|
||||
return
|
||||
}
|
||||
user, link, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to find user.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if user.ID == uuid.Nil && !api.OIDCConfig.AllowSignups {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Signups are disabled for OIDC authentication!",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if user.ID != uuid.Nil && user.LoginType != database.LoginTypeOIDC {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// This can happen if a user is a built-in user but is signing in
|
||||
// with OIDC for the first time.
|
||||
if user.ID == uuid.Nil {
|
||||
var organizationID uuid.UUID
|
||||
organizations, _ := api.Database.GetOrganizations(r.Context())
|
||||
organizations, _ := api.Database.GetOrganizations(ctx)
|
||||
if len(organizations) > 0 {
|
||||
// Add the user to the first organization. Once multi-organization
|
||||
// support is added, we should enable a configuration map of user
|
||||
// email to organization.
|
||||
organizationID = organizations[0].ID
|
||||
}
|
||||
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
|
||||
Email: claims.Email,
|
||||
Username: claims.Username,
|
||||
OrganizationID: organizationID,
|
||||
|
||||
user, _, err = api.createUser(ctx, createUserRequest{
|
||||
CreateUserRequest: codersdk.CreateUserRequest{
|
||||
Email: claims.Email,
|
||||
Username: claims.Username,
|
||||
OrganizationID: organizationID,
|
||||
},
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
|
@ -316,21 +385,81 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get user by email.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to insert user auth metadata.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthAccessToken: state.Token.AccessToken,
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
if link.UserID == uuid.Nil {
|
||||
link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
LinkedID: oidcLinkedID(idToken),
|
||||
OAuthAccessToken: state.Token.AccessToken,
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred.",
|
||||
Detail: fmt.Sprintf("insert user link: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// LEGACY: Remove 10/2022.
|
||||
// We started tracking linked IDs later so it's possible for a user to be a
|
||||
// pre-existing OIDC user and not have a linked ID.
|
||||
// The migration that added the user_links table could not populate
|
||||
// the 'linked_id' field since it requires fields off the access token.
|
||||
if link.LinkedID == "" {
|
||||
link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{
|
||||
UserID: user.ID,
|
||||
LinkedID: oidcLinkedID(idToken),
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred.",
|
||||
Detail: xerrors.Errorf("update user link: %w", err.Error).Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If the upstream email or username has changed we should mirror
|
||||
// that in Coder. Many enterprises use a user's email/username as
|
||||
// security auditing fields so they need to stay synced.
|
||||
if user.Email != claims.Email || user.Username != claims.Username {
|
||||
// TODO(JonA): Since we're processing updates to a user's upstream
|
||||
// email/username, it's possible for a different built-in user to
|
||||
// have already claimed the username.
|
||||
// In such cases in the current implementation this user can now no
|
||||
// longer sign in until an administrator finds the offending built-in
|
||||
// user and changes their username.
|
||||
user, err = api.Database.UpdateUserProfile(ctx, database.UpdateUserProfileParams{
|
||||
ID: user.ID,
|
||||
Email: claims.Email,
|
||||
Username: claims.Username,
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update user profile.",
|
||||
Detail: fmt.Sprintf("update user profile: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, created := api.createAPIKey(rw, r, createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
|
@ -342,3 +471,66 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// githubLinkedID returns the unique ID for a GitHub user.
|
||||
func githubLinkedID(u *github.User) string {
|
||||
return strconv.FormatInt(u.GetID(), 10)
|
||||
}
|
||||
|
||||
// oidcLinkedID returns the uniqued ID for an OIDC user.
|
||||
// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability .
|
||||
func oidcLinkedID(tok *oidc.IDToken) string {
|
||||
return strings.Join([]string{tok.Issuer, tok.Subject}, "||")
|
||||
}
|
||||
|
||||
// findLinkedUser tries to find a user by their unique OAuth-linked ID.
|
||||
// If it doesn't not find it, it returns the user by their email.
|
||||
func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, database.UserLink, error) {
|
||||
var (
|
||||
user database.User
|
||||
link database.UserLink
|
||||
)
|
||||
link, err := db.GetUserLinkByLinkedID(ctx, linkedID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return user, link, xerrors.Errorf("get user auth by linked ID: %w", err)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
user, err = db.GetUserByID(ctx, link.UserID)
|
||||
if err != nil {
|
||||
return database.User{}, database.UserLink{}, xerrors.Errorf("get user by id: %w", err)
|
||||
}
|
||||
return user, link, nil
|
||||
}
|
||||
|
||||
for _, email := range emails {
|
||||
user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
|
||||
Email: email,
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return user, link, xerrors.Errorf("get user by email: %w", err)
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if user.ID == uuid.Nil {
|
||||
// No user found.
|
||||
return database.User{}, database.UserLink{}, nil
|
||||
}
|
||||
|
||||
// LEGACY: This is annoying but we have to search for the user_link
|
||||
// again except this time we search by user_id and login_type. It's
|
||||
// possible that a user_link exists without a populated 'linked_id'.
|
||||
link, err = db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: user.ID,
|
||||
LoginType: user.LoginType,
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return database.User{}, database.UserLink{}, xerrors.Errorf("get user link by user id and login type: %w", err)
|
||||
}
|
||||
|
||||
return user, link, nil
|
||||
}
|
||||
|
|
|
@ -175,38 +175,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
|||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
})
|
||||
t.Run("Signup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
AllowSignups: true,
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{
|
||||
Login: github.String("kyle"),
|
||||
}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{{
|
||||
Email: github.String("kyle@coder.com"),
|
||||
Verified: github.Bool(true),
|
||||
Primary: github.Bool(true),
|
||||
}}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
})
|
||||
t.Run("Login", func(t *testing.T) {
|
||||
t.Run("MultiLoginNotAllowed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
|
@ -230,9 +199,50 @@ func TestUserOAuth2Github(t *testing.T) {
|
|||
},
|
||||
},
|
||||
})
|
||||
// Creates the first user with login_type 'password'.
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
// Attempting to login should give us a 403 since the user
|
||||
// already has a login_type of 'password'.
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
})
|
||||
t.Run("Signup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
AllowSignups: true,
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{
|
||||
Login: github.String("kyle"),
|
||||
ID: i64ptr(1234),
|
||||
}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{{
|
||||
Email: github.String("kyle@coder.com"),
|
||||
Verified: github.Bool(true),
|
||||
Primary: github.Bool(true),
|
||||
}}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
client.SessionToken = resp.Cookies()[0].Value
|
||||
user, err := client.User(context.Background(), "me")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "kyle@coder.com", user.Email)
|
||||
require.Equal(t, "kyle", user.Username)
|
||||
})
|
||||
t.Run("SignupAllowedTeam", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
@ -415,11 +425,13 @@ func createOIDCConfig(t *testing.T, claims jwt.MapClaims) *coderd.OIDCConfig {
|
|||
|
||||
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
claims["iss"] = "https://coder.com"
|
||||
claims["sub"] = "hello"
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifier := oidc.NewVerifier("", &oidc.StaticKeySet{
|
||||
verifier := oidc.NewVerifier("https://coder.com", &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
|
@ -480,3 +492,7 @@ func oidcCallback(t *testing.T, client *codersdk.Client) *http.Response {
|
|||
t.Log(string(data))
|
||||
return res
|
||||
}
|
||||
|
||||
func i64ptr(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
|
|
|
@ -77,10 +77,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{
|
||||
Email: createUser.Email,
|
||||
Username: createUser.Username,
|
||||
Password: createUser.Password,
|
||||
user, organizationID, err := api.createUser(r.Context(), createUserRequest{
|
||||
CreateUserRequest: codersdk.CreateUserRequest{
|
||||
Email: createUser.Email,
|
||||
Username: createUser.Username,
|
||||
Password: createUser.Password,
|
||||
},
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
|
@ -196,14 +199,14 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var createUser codersdk.CreateUserRequest
|
||||
if !httpapi.Read(rw, r, &createUser) {
|
||||
var req codersdk.CreateUserRequest
|
||||
if !httpapi.Read(rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create the organization member in the org.
|
||||
if !api.Authorize(r, rbac.ActionCreate,
|
||||
rbac.ResourceOrganizationMember.InOrg(createUser.OrganizationID)) {
|
||||
rbac.ResourceOrganizationMember.InOrg(req.OrganizationID)) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
@ -211,8 +214,8 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
|||
// TODO: @emyrk Authorize the organization create if the createUser will do that.
|
||||
|
||||
_, err := api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
Username: createUser.Username,
|
||||
Email: createUser.Email,
|
||||
Username: req.Username,
|
||||
Email: req.Email,
|
||||
})
|
||||
if err == nil {
|
||||
httpapi.Write(rw, http.StatusConflict, codersdk.Response{
|
||||
|
@ -228,10 +231,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
_, err = api.Database.GetOrganizationByID(r.Context(), createUser.OrganizationID)
|
||||
_, err = api.Database.GetOrganizationByID(r.Context(), req.OrganizationID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Organization does not exist with the provided id %q.", createUser.OrganizationID),
|
||||
Message: fmt.Sprintf("Organization does not exist with the provided id %q.", req.OrganizationID),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -243,7 +246,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
user, _, err := api.createUser(r.Context(), createUser)
|
||||
user, _, err := api.createUser(r.Context(), createUserRequest{
|
||||
CreateUserRequest: req,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error creating user.",
|
||||
|
@ -257,7 +263,7 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
|||
Users: []telemetry.User{telemetry.ConvertUser(user)},
|
||||
})
|
||||
|
||||
httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{createUser.OrganizationID}))
|
||||
httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{req.OrganizationID}))
|
||||
}
|
||||
|
||||
// Returns the parameterized user requested. All validation
|
||||
|
@ -701,6 +707,13 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if user.LoginType != database.LoginTypePassword {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypePassword, user.LoginType),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// If the user logged into a suspended account, reject the login request.
|
||||
if user.Status != database.UserStatusActive {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, codersdk.Response{
|
||||
|
@ -709,7 +722,7 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
sessionToken, created := api.createAPIKey(rw, r, createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
|
@ -732,7 +745,7 @@ func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
lifeTime := time.Hour * 24 * 7
|
||||
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
sessionToken, created := api.createAPIKey(rw, r, createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
// All api generated keys will last 1 week. Browser login tokens have
|
||||
|
@ -818,7 +831,16 @@ func generateAPIKeyIDSecret() (id string, secret string, err error) {
|
|||
return id, secret, nil
|
||||
}
|
||||
|
||||
func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) {
|
||||
type createAPIKeyParams struct {
|
||||
UserID uuid.UUID
|
||||
LoginType database.LoginType
|
||||
|
||||
// Optional.
|
||||
ExpiresAt time.Time
|
||||
LifetimeSeconds int64
|
||||
}
|
||||
|
||||
func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params createAPIKeyParams) (string, bool) {
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
|
@ -856,15 +878,11 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat
|
|||
Valid: true,
|
||||
},
|
||||
// Make sure in UTC time for common time zone
|
||||
ExpiresAt: params.ExpiresAt.UTC(),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: params.LoginType,
|
||||
OAuthAccessToken: params.OAuthAccessToken,
|
||||
OAuthRefreshToken: params.OAuthRefreshToken,
|
||||
OAuthIDToken: params.OAuthIDToken,
|
||||
OAuthExpiry: params.OAuthExpiry,
|
||||
ExpiresAt: params.ExpiresAt.UTC(),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: params.LoginType,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
|
@ -891,7 +909,12 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat
|
|||
return sessionToken, true
|
||||
}
|
||||
|
||||
func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) {
|
||||
type createUserRequest struct {
|
||||
codersdk.CreateUserRequest
|
||||
LoginType database.LoginType
|
||||
}
|
||||
|
||||
func (api *API) createUser(ctx context.Context, req createUserRequest) (database.User, uuid.UUID, error) {
|
||||
var user database.User
|
||||
return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error {
|
||||
orgRoles := make([]string, 0)
|
||||
|
@ -918,6 +941,7 @@ func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest)
|
|||
UpdatedAt: database.Now(),
|
||||
// All new users are defaulted to members of the site.
|
||||
RBACRoles: []string{},
|
||||
LoginType: req.LoginType,
|
||||
}
|
||||
// If a user signs up with OAuth, they can have no password!
|
||||
if req.Password != "" {
|
||||
|
|
|
@ -27,6 +27,7 @@ type LoginType string
|
|||
const (
|
||||
LoginTypePassword LoginType = "password"
|
||||
LoginTypeGithub LoginType = "github"
|
||||
LoginTypeOIDC LoginType = "oidc"
|
||||
)
|
||||
|
||||
type UsersRequest struct {
|
||||
|
|
|
@ -561,7 +561,7 @@ export type LogLevel = "debug" | "error" | "info" | "trace" | "warn"
|
|||
export type LogSource = "provisioner" | "provisioner_daemon"
|
||||
|
||||
// From codersdk/users.go
|
||||
export type LoginType = "github" | "password"
|
||||
export type LoginType = "github" | "oidc" | "password"
|
||||
|
||||
// From codersdk/parameters.go
|
||||
export type ParameterDestinationScheme = "environment_variable" | "none" | "provisioner_variable"
|
||||
|
|
Loading…
Reference in New Issue