fix: use unique ID for linked accounts (#3441)

- move OAuth-related fields off of api_keys into a new user_links table
- restrict users to single form of login
- process updates to user email/usernames for OIDC
- added a login_type column to users
This commit is contained in:
Jon Ayers 2022-08-17 18:00:53 -05:00 committed by GitHub
parent 53d1fb36db
commit c3eea98db0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 931 additions and 266 deletions

View File

@ -94,6 +94,7 @@ var AuditableResources = auditMap(map[any]map[string]Action{
"updated_at": ActionIgnore, // Changes, but is implicit and not helpful in a diff.
"status": ActionTrack,
"rbac_roles": ActionTrack,
"login_type": ActionIgnore,
},
&database.Workspace{}: {
"id": ActionTrack,

View File

@ -73,6 +73,7 @@ type data struct {
organizations []database.Organization
organizationMembers []database.OrganizationMember
users []database.User
userLinks []database.UserLink
// New tables
auditLogs []database.AuditLog
@ -1454,20 +1455,16 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
//nolint:gosimple
key := database.APIKey{
ID: arg.ID,
LifetimeSeconds: arg.LifetimeSeconds,
HashedSecret: arg.HashedSecret,
IPAddress: arg.IPAddress,
UserID: arg.UserID,
ExpiresAt: arg.ExpiresAt,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
LastUsed: arg.LastUsed,
LoginType: arg.LoginType,
OAuthAccessToken: arg.OAuthAccessToken,
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthIDToken: arg.OAuthIDToken,
OAuthExpiry: arg.OAuthExpiry,
ID: arg.ID,
LifetimeSeconds: arg.LifetimeSeconds,
HashedSecret: arg.HashedSecret,
IPAddress: arg.IPAddress,
UserID: arg.UserID,
ExpiresAt: arg.ExpiresAt,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
LastUsed: arg.LastUsed,
LoginType: arg.LoginType,
}
q.apiKeys = append(q.apiKeys, key)
return key, nil
@ -1744,6 +1741,7 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
Username: arg.Username,
Status: database.UserStatusActive,
RBACRoles: arg.RBACRoles,
LoginType: arg.LoginType,
}
q.users = append(q.users, user)
return user, nil
@ -1899,9 +1897,6 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI
apiKey.LastUsed = arg.LastUsed
apiKey.ExpiresAt = arg.ExpiresAt
apiKey.IPAddress = arg.IPAddress
apiKey.OAuthAccessToken = arg.OAuthAccessToken
apiKey.OAuthRefreshToken = arg.OAuthRefreshToken
apiKey.OAuthExpiry = arg.OAuthExpiry
q.apiKeys[index] = apiKey
return nil
}
@ -2260,3 +2255,80 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) {
return q.deploymentID, nil
}
func (q *fakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, link := range q.userLinks {
if link.LinkedID == id {
return link, nil
}
}
return database.UserLink{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, link := range q.userLinks {
if link.UserID == params.UserID && link.LoginType == params.LoginType {
return link, nil
}
}
return database.UserLink{}, sql.ErrNoRows
}
func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
//nolint:gosimple
link := database.UserLink{
UserID: args.UserID,
LoginType: args.LoginType,
LinkedID: args.LinkedID,
OAuthAccessToken: args.OAuthAccessToken,
OAuthRefreshToken: args.OAuthRefreshToken,
OAuthExpiry: args.OAuthExpiry,
}
q.userLinks = append(q.userLinks, link)
return link, nil
}
func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for i, link := range q.userLinks {
if link.UserID == params.UserID && link.LoginType == params.LoginType {
link.LinkedID = params.LinkedID
q.userLinks[i] = link
return link, nil
}
}
return database.UserLink{}, sql.ErrNoRows
}
func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for i, link := range q.userLinks {
if link.UserID == params.UserID && link.LoginType == params.LoginType {
link.OAuthAccessToken = params.OAuthAccessToken
link.OAuthRefreshToken = params.OAuthRefreshToken
link.OAuthExpiry = params.OAuthExpiry
q.userLinks[i] = link
return link, nil
}
}
return database.UserLink{}, sql.ErrNoRows
}

View File

@ -37,6 +37,7 @@ func TestNestedInTx(t *testing.T) {
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
RBACRoles: []string{},
LoginType: database.LoginTypeGithub,
})
return err
})

View File

@ -96,10 +96,6 @@ CREATE TABLE api_keys (
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
login_type login_type NOT NULL,
oauth_access_token text DEFAULT ''::text NOT NULL,
oauth_refresh_token text DEFAULT ''::text NOT NULL,
oauth_id_token text DEFAULT ''::text NOT NULL,
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
lifetime_seconds bigint DEFAULT 86400 NOT NULL,
ip_address inet DEFAULT '0.0.0.0'::inet NOT NULL
);
@ -267,6 +263,15 @@ CREATE TABLE templates (
created_by uuid NOT NULL
);
CREATE TABLE user_links (
user_id uuid NOT NULL,
login_type login_type NOT NULL,
linked_id text DEFAULT ''::text NOT NULL,
oauth_access_token text DEFAULT ''::text NOT NULL,
oauth_refresh_token text DEFAULT ''::text NOT NULL,
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL
);
CREATE TABLE users (
id uuid NOT NULL,
email text NOT NULL,
@ -275,7 +280,8 @@ CREATE TABLE users (
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
status user_status DEFAULT 'active'::public.user_status NOT NULL,
rbac_roles text[] DEFAULT '{}'::text[] NOT NULL
rbac_roles text[] DEFAULT '{}'::text[] NOT NULL,
login_type login_type DEFAULT 'password'::public.login_type NOT NULL
);
CREATE TABLE workspace_agents (
@ -416,6 +422,9 @@ ALTER TABLE ONLY template_versions
ALTER TABLE ONLY templates
ADD CONSTRAINT templates_pkey PRIMARY KEY (id);
ALTER TABLE ONLY user_links
ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type);
ALTER TABLE ONLY users
ADD CONSTRAINT users_pkey PRIMARY KEY (id);
@ -513,6 +522,9 @@ ALTER TABLE ONLY templates
ALTER TABLE ONLY templates
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_links
ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_agents
ADD CONSTRAINT workspace_agents_resource_id_fkey FOREIGN KEY (resource_id) REFERENCES workspace_resources(id) ON DELETE CASCADE;

View File

@ -13,6 +13,8 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
(
cd "$SCRIPT_DIR"
# Dump the updated schema.
go run dump/main.go
# The logic below depends on the exact version being correct :(
go run github.com/kyleconroy/sqlc/cmd/sqlc@v1.13.0 generate

View File

@ -0,0 +1,23 @@
-- This migration makes no attempt to try to populate
-- the oauth_access_token, oauth_refresh_token, and oauth_expiry
-- columns of api_key rows with the values from the dropped user_links
-- table.
BEGIN;
DROP TABLE IF EXISTS user_links;
ALTER TABLE
api_keys
ADD COLUMN oauth_access_token text DEFAULT ''::text NOT NULL;
ALTER TABLE
api_keys
ADD COLUMN oauth_refresh_token text DEFAULT ''::text NOT NULL;
ALTER TABLE
api_keys
ADD COLUMN oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL;
ALTER TABLE users DROP COLUMN login_type;
COMMIT;

View File

@ -0,0 +1,74 @@
BEGIN;
CREATE TABLE IF NOT EXISTS user_links (
user_id uuid NOT NULL,
login_type login_type NOT NULL,
linked_id text DEFAULT ''::text NOT NULL,
oauth_access_token text DEFAULT ''::text NOT NULL,
oauth_refresh_token text DEFAULT ''::text NOT NULL,
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
PRIMARY KEY(user_id, login_type),
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
);
-- This migrates columns on api_keys to the new user_links table.
-- It does this by finding all the API keys for each user, choosing
-- the most recently updated for each one and then assigning its relevant
-- values to the user_links table.
-- A user should at most have a row for an OIDC account and a Github account.
-- 'password' login types are ignored.
INSERT INTO user_links
(
user_id,
login_type,
linked_id,
oauth_access_token,
oauth_refresh_token,
oauth_expiry
)
SELECT
keys.user_id,
keys.login_type,
'',
keys.oauth_access_token,
keys.oauth_refresh_token,
keys.oauth_expiry
FROM
(
SELECT
row_number() OVER (partition by user_id, login_type ORDER BY last_used DESC) AS x,
api_keys.* FROM api_keys
) as keys
WHERE x=1 AND keys.login_type != 'password';
-- Drop columns that have been migrated to user_links.
-- It appears the 'oauth_id_token' was unused and so it has
-- been dropped here as well to avoid future confusion.
ALTER TABLE api_keys
DROP COLUMN oauth_access_token,
DROP COLUMN oauth_refresh_token,
DROP COLUMN oauth_id_token,
DROP COLUMN oauth_expiry;
ALTER TABLE users ADD COLUMN login_type login_type NOT NULL DEFAULT 'password';
UPDATE
users
SET
login_type = (
SELECT
login_type
FROM
user_links
WHERE
user_links.user_id = users.id
ORDER BY oauth_expiry DESC
LIMIT 1
)
FROM
user_links
WHERE
user_links.user_id = users.id;
COMMIT;

View File

@ -313,20 +313,16 @@ func (e *WorkspaceTransition) Scan(src interface{}) error {
}
type APIKey struct {
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"`
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"`
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
}
type AuditLog struct {
@ -491,6 +487,16 @@ type User struct {
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Status UserStatus `db:"status" json:"status"`
RBACRoles []string `db:"rbac_roles" json:"rbac_roles"`
LoginType LoginType `db:"login_type" json:"login_type"`
}
type UserLink struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
LinkedID string `db:"linked_id" json:"linked_id"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
}
type Workspace struct {

View File

@ -64,6 +64,8 @@ type querier interface {
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserCount(ctx context.Context) (int64, error)
GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error)
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error)
GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error)
GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error)
@ -106,6 +108,7 @@ type querier interface {
InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error)
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error)
InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error)
InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error)
InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error)
@ -127,6 +130,8 @@ type querier interface {
UpdateTemplateVersionByID(ctx context.Context, arg UpdateTemplateVersionByIDParams) error
UpdateTemplateVersionDescriptionByJobID(ctx context.Context, arg UpdateTemplateVersionDescriptionByJobIDParams) error
UpdateUserHashedPassword(ctx context.Context, arg UpdateUserHashedPasswordParams) error
UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error)
UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error)
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)

View File

@ -31,7 +31,7 @@ func (q *sqlQuerier) DeleteAPIKeyByID(ctx context.Context, id string) error {
const getAPIKeyByID = `-- name: GetAPIKeyByID :one
SELECT
id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address
id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address
FROM
api_keys
WHERE
@ -52,10 +52,6 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
&i.CreatedAt,
&i.UpdatedAt,
&i.LoginType,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthIDToken,
&i.OAuthExpiry,
&i.LifetimeSeconds,
&i.IPAddress,
)
@ -63,7 +59,7 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
}
const getAPIKeysLastUsedAfter = `-- name: GetAPIKeysLastUsedAfter :many
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1
SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1
`
func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) {
@ -84,10 +80,6 @@ func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.
&i.CreatedAt,
&i.UpdatedAt,
&i.LoginType,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthIDToken,
&i.OAuthExpiry,
&i.LifetimeSeconds,
&i.IPAddress,
); err != nil {
@ -116,11 +108,7 @@ INSERT INTO
expires_at,
created_at,
updated_at,
login_type,
oauth_access_token,
oauth_refresh_token,
oauth_id_token,
oauth_expiry
login_type
)
VALUES
($1,
@ -129,24 +117,20 @@ VALUES
WHEN 0 THEN 86400
ELSE $2::bigint
END
, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry, lifetime_seconds, ip_address
, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address
`
type InsertAPIKeyParams struct {
ID string `db:"id" json:"id"`
LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
ID string `db:"id" json:"id"`
LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
}
func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) {
@ -161,10 +145,6 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
arg.CreatedAt,
arg.UpdatedAt,
arg.LoginType,
arg.OAuthAccessToken,
arg.OAuthRefreshToken,
arg.OAuthIDToken,
arg.OAuthExpiry,
)
var i APIKey
err := row.Scan(
@ -176,10 +156,6 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
&i.CreatedAt,
&i.UpdatedAt,
&i.LoginType,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthIDToken,
&i.OAuthExpiry,
&i.LifetimeSeconds,
&i.IPAddress,
)
@ -192,22 +168,16 @@ UPDATE
SET
last_used = $2,
expires_at = $3,
ip_address = $4,
oauth_access_token = $5,
oauth_refresh_token = $6,
oauth_expiry = $7
ip_address = $4
WHERE
id = $1
`
type UpdateAPIKeyByIDParams struct {
ID string `db:"id" json:"id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
ID string `db:"id" json:"id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
}
func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error {
@ -216,9 +186,6 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
arg.LastUsed,
arg.ExpiresAt,
arg.IPAddress,
arg.OAuthAccessToken,
arg.OAuthRefreshToken,
arg.OAuthExpiry,
)
return err
}
@ -2447,6 +2414,169 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context
return err
}
const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one
SELECT
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
FROM
user_links
WHERE
linked_id = $1
`
func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error) {
row := q.db.QueryRowContext(ctx, getUserLinkByLinkedID, linkedID)
var i UserLink
err := row.Scan(
&i.UserID,
&i.LoginType,
&i.LinkedID,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
)
return i, err
}
const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one
SELECT
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
FROM
user_links
WHERE
user_id = $1 AND login_type = $2
`
type GetUserLinkByUserIDLoginTypeParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
}
func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error) {
row := q.db.QueryRowContext(ctx, getUserLinkByUserIDLoginType, arg.UserID, arg.LoginType)
var i UserLink
err := row.Scan(
&i.UserID,
&i.LoginType,
&i.LinkedID,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
)
return i, err
}
const insertUserLink = `-- name: InsertUserLink :one
INSERT INTO
user_links (
user_id,
login_type,
linked_id,
oauth_access_token,
oauth_refresh_token,
oauth_expiry
)
VALUES
( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
`
type InsertUserLinkParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
LinkedID string `db:"linked_id" json:"linked_id"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
}
func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) {
row := q.db.QueryRowContext(ctx, insertUserLink,
arg.UserID,
arg.LoginType,
arg.LinkedID,
arg.OAuthAccessToken,
arg.OAuthRefreshToken,
arg.OAuthExpiry,
)
var i UserLink
err := row.Scan(
&i.UserID,
&i.LoginType,
&i.LinkedID,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
)
return i, err
}
const updateUserLink = `-- name: UpdateUserLink :one
UPDATE
user_links
SET
oauth_access_token = $1,
oauth_refresh_token = $2,
oauth_expiry = $3
WHERE
user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
`
type UpdateUserLinkParams struct {
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
}
func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) {
row := q.db.QueryRowContext(ctx, updateUserLink,
arg.OAuthAccessToken,
arg.OAuthRefreshToken,
arg.OAuthExpiry,
arg.UserID,
arg.LoginType,
)
var i UserLink
err := row.Scan(
&i.UserID,
&i.LoginType,
&i.LinkedID,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
)
return i, err
}
const updateUserLinkedID = `-- name: UpdateUserLinkedID :one
UPDATE
user_links
SET
linked_id = $1
WHERE
user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
`
type UpdateUserLinkedIDParams struct {
LinkedID string `db:"linked_id" json:"linked_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
}
func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error) {
row := q.db.QueryRowContext(ctx, updateUserLinkedID, arg.LinkedID, arg.UserID, arg.LoginType)
var i UserLink
err := row.Scan(
&i.UserID,
&i.LoginType,
&i.LinkedID,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
)
return i, err
}
const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one
SELECT
-- username is returned just to help for logging purposes
@ -2458,13 +2588,13 @@ SELECT
array_append(users.rbac_roles, 'member'),
-- All org_members get the org-member role for their orgs
array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[]
AS roles
AS roles
FROM
users
LEFT JOIN organization_members
ON id = user_id
WHERE
id = $1
id = $1
`
type GetAuthorizationUserRolesRow struct {
@ -2490,7 +2620,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
SELECT
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
FROM
users
WHERE
@ -2517,13 +2647,14 @@ func (q *sqlQuerier) GetUserByEmailOrUsername(ctx context.Context, arg GetUserBy
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.LoginType,
)
return i, err
}
const getUserByID = `-- name: GetUserByID :one
SELECT
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
FROM
users
WHERE
@ -2544,6 +2675,7 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.LoginType,
)
return i, err
}
@ -2564,7 +2696,7 @@ func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) {
const getUsers = `-- name: GetUsers :many
SELECT
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
FROM
users
WHERE
@ -2614,8 +2746,8 @@ WHERE
END
-- End of filters
ORDER BY
-- Deterministic and consistent ordering of all users, even if they share
-- a timestamp. This is to ensure consistent pagination.
-- Deterministic and consistent ordering of all users, even if they share
-- a timestamp. This is to ensure consistent pagination.
(created_at, id) ASC OFFSET $5
LIMIT
-- A null limit means "no limit", so 0 means return all
@ -2656,6 +2788,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.LoginType,
); err != nil {
return nil, err
}
@ -2671,7 +2804,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
}
const getUsersByIDs = `-- name: GetUsersByIDs :many
SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles FROM users WHERE id = ANY($1 :: uuid [ ])
SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type FROM users WHERE id = ANY($1 :: uuid [ ])
`
func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) {
@ -2692,6 +2825,7 @@ func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.LoginType,
); err != nil {
return nil, err
}
@ -2715,10 +2849,11 @@ INSERT INTO
hashed_password,
created_at,
updated_at,
rbac_roles
rbac_roles,
login_type
)
VALUES
($1, $2, $3, $4, $5, $6, $7) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
`
type InsertUserParams struct {
@ -2729,6 +2864,7 @@ type InsertUserParams struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
RBACRoles []string `db:"rbac_roles" json:"rbac_roles"`
LoginType LoginType `db:"login_type" json:"login_type"`
}
func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
@ -2740,6 +2876,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
arg.CreatedAt,
arg.UpdatedAt,
pq.Array(arg.RBACRoles),
arg.LoginType,
)
var i User
err := row.Scan(
@ -2751,6 +2888,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.LoginType,
)
return i, err
}
@ -2782,7 +2920,7 @@ SET
username = $3,
updated_at = $4
WHERE
id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
`
type UpdateUserProfileParams struct {
@ -2809,19 +2947,20 @@ func (q *sqlQuerier) UpdateUserProfile(ctx context.Context, arg UpdateUserProfil
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.LoginType,
)
return i, err
}
const updateUserRoles = `-- name: UpdateUserRoles :one
UPDATE
users
users
SET
-- Remove all duplicates from the roles.
rbac_roles = ARRAY(SELECT DISTINCT UNNEST($1 :: text[]))
WHERE
id = $2
RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
id = $2
RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
`
type UpdateUserRolesParams struct {
@ -2841,6 +2980,7 @@ func (q *sqlQuerier) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesPar
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.LoginType,
)
return i, err
}
@ -2852,7 +2992,7 @@ SET
status = $2,
updated_at = $3
WHERE
id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles
id = $1 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type
`
type UpdateUserStatusParams struct {
@ -2873,6 +3013,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.LoginType,
)
return i, err
}

View File

@ -23,11 +23,7 @@ INSERT INTO
expires_at,
created_at,
updated_at,
login_type,
oauth_access_token,
oauth_refresh_token,
oauth_id_token,
oauth_expiry
login_type
)
VALUES
(@id,
@ -36,7 +32,7 @@ VALUES
WHEN 0 THEN 86400
ELSE @lifetime_seconds::bigint
END
, @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type, @oauth_access_token, @oauth_refresh_token, @oauth_id_token, @oauth_expiry) RETURNING *;
, @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type) RETURNING *;
-- name: UpdateAPIKeyByID :exec
UPDATE
@ -44,10 +40,7 @@ UPDATE
SET
last_used = $2,
expires_at = $3,
ip_address = $4,
oauth_access_token = $5,
oauth_refresh_token = $6,
oauth_expiry = $7
ip_address = $4
WHERE
id = $1;

View File

@ -0,0 +1,46 @@
-- name: GetUserLinkByLinkedID :one
SELECT
*
FROM
user_links
WHERE
linked_id = $1;
-- name: GetUserLinkByUserIDLoginType :one
SELECT
*
FROM
user_links
WHERE
user_id = $1 AND login_type = $2;
-- name: InsertUserLink :one
INSERT INTO
user_links (
user_id,
login_type,
linked_id,
oauth_access_token,
oauth_refresh_token,
oauth_expiry
)
VALUES
( $1, $2, $3, $4, $5, $6 ) RETURNING *;
-- name: UpdateUserLinkedID :one
UPDATE
user_links
SET
linked_id = $1
WHERE
user_id = $2 AND login_type = $3 RETURNING *;
-- name: UpdateUserLink :one
UPDATE
user_links
SET
oauth_access_token = $1,
oauth_refresh_token = $2,
oauth_expiry = $3
WHERE
user_id = $4 AND login_type = $5 RETURNING *;

View File

@ -37,10 +37,11 @@ INSERT INTO
hashed_password,
created_at,
updated_at,
rbac_roles
rbac_roles,
login_type
)
VALUES
($1, $2, $3, $4, $5, $6, $7) RETURNING *;
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
-- name: UpdateUserProfile :one
UPDATE
@ -54,12 +55,12 @@ WHERE
-- name: UpdateUserRoles :one
UPDATE
users
users
SET
-- Remove all duplicates from the roles.
rbac_roles = ARRAY(SELECT DISTINCT UNNEST(@granted_roles :: text[]))
WHERE
id = @id
id = @id
RETURNING *;
-- name: UpdateUserHashedPassword :exec
@ -122,8 +123,8 @@ WHERE
END
-- End of filters
ORDER BY
-- Deterministic and consistent ordering of all users, even if they share
-- a timestamp. This is to ensure consistent pagination.
-- Deterministic and consistent ordering of all users, even if they share
-- a timestamp. This is to ensure consistent pagination.
(created_at, id) ASC OFFSET @offset_opt
LIMIT
-- A null limit means "no limit", so 0 means return all
@ -152,10 +153,10 @@ SELECT
array_append(users.rbac_roles, 'member'),
-- All org_members get the org-member role for their orgs
array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[]
AS roles
AS roles
FROM
users
LEFT JOIN organization_members
ON id = user_id
WHERE
id = @user_id;
id = @user_id;

View File

@ -14,6 +14,7 @@ import (
"golang.org/x/oauth2"
"github.com/google/uuid"
"github.com/tabbed/pqtype"
"github.com/coder/coder/coderd/database"
@ -149,9 +150,21 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
// Tracks if the API key has properties updated!
changed := false
var link database.UserLink
if key.LoginType != database.LoginTypePassword {
link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
if err != nil {
write(http.StatusInternalServerError, codersdk.Response{
Message: "A database error occurred",
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
})
return
}
// Check if the OAuth token is expired!
if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() {
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() {
var oauthConfig OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
@ -167,9 +180,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
}
// If it is, let's refresh it from the provided config!
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: key.OAuthAccessToken,
RefreshToken: key.OAuthRefreshToken,
Expiry: key.OAuthExpiry,
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
Expiry: link.OAuthExpiry,
}).Token()
if err != nil {
write(http.StatusUnauthorized, codersdk.Response{
@ -178,9 +191,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
})
return
}
key.OAuthAccessToken = token.AccessToken
key.OAuthRefreshToken = token.RefreshToken
key.OAuthExpiry = token.Expiry
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
key.ExpiresAt = token.Expiry
changed = true
}
@ -222,13 +235,10 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
}
if changed {
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
IPAddress: key.IPAddress,
OAuthAccessToken: key.OAuthAccessToken,
OAuthRefreshToken: key.OAuthRefreshToken,
OAuthExpiry: key.OAuthExpiry,
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
IPAddress: key.IPAddress,
})
if err != nil {
write(http.StatusInternalServerError, codersdk.Response{
@ -237,6 +247,24 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
})
return
}
// If the API Key is associated with a user_link (e.g. Github/OIDC)
// then we want to update the relevant oauth fields.
if link.UserID != uuid.Nil {
link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthExpiry: link.OAuthExpiry,
})
if err != nil {
write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
return
}
}
}
// If the key is valid, we also fetch the user roles and status.

View File

@ -187,6 +187,7 @@ func TestAPIKey(t *testing.T) {
ID: id,
HashedSecret: hashed[:],
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
@ -215,6 +216,7 @@ func TestAPIKey(t *testing.T) {
HashedSecret: hashed[:],
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -253,6 +255,7 @@ func TestAPIKey(t *testing.T) {
HashedSecret: hashed[:],
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -288,6 +291,7 @@ func TestAPIKey(t *testing.T) {
LastUsed: database.Now().AddDate(0, 0, -1),
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
@ -323,6 +327,7 @@ func TestAPIKey(t *testing.T) {
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
@ -361,6 +366,13 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
_, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
@ -393,10 +405,16 @@ func TestAPIKey(t *testing.T) {
HashedSecret: hashed[:],
LoginType: database.LoginTypeGithub,
LastUsed: database.Now(),
OAuthExpiry: database.Now().AddDate(0, 0, -1),
UserID: user.ID,
})
require.NoError(t, err)
_, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
OAuthExpiry: database.Now().AddDate(0, 0, -1),
})
require.NoError(t, err)
token := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
@ -418,7 +436,6 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken)
})
t.Run("RemoteIPUpdates", func(t *testing.T) {
@ -443,6 +460,7 @@ func TestAPIKey(t *testing.T) {
LastUsed: database.Now().AddDate(0, 0, -1),
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)

View File

@ -124,6 +124,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestOrganizationParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext()))

View File

@ -53,6 +53,7 @@ func TestTemplateParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestTemplateVersionParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -47,6 +47,7 @@ func TestUserParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestWorkspaceParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -73,6 +73,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) {
HashedSecret: hashed[:],
UserID: userID,
ExpiresAt: time.Now().Add(5 * time.Hour),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
_, err = fDB.InsertUser(ctx, database.InsertUserParams{

View File

@ -6,12 +6,14 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
@ -47,10 +49,13 @@ func (api *API) userAuthMethods(rw http.ResponseWriter, _ *http.Request) {
}
func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
var (
ctx = r.Context()
state = httpmw.OAuth2(r)
)
oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token))
memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient)
oauthClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(state.Token))
memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(ctx, oauthClient)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching authenticated Github user organizations.",
@ -75,7 +80,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
return
}
ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient)
ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(ctx, oauthClient)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching authenticated Github user.",
@ -94,7 +99,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
continue
}
allowedTeam, err = api.GithubOAuth2Config.TeamMembership(r.Context(), oauthClient, allowTeam.Organization, allowTeam.Slug, *ghUser.Login)
allowedTeam, err = api.GithubOAuth2Config.TeamMembership(ctx, oauthClient, allowTeam.Organization, allowTeam.Slug, *ghUser.Login)
// The calling user may not have permission to the requested team!
if err != nil {
continue
@ -108,7 +113,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
}
}
emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient)
emails, err := api.GithubOAuth2Config.ListEmails(ctx, oauthClient)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching personal Github user.",
@ -117,33 +122,35 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
return
}
var user database.User
// Search for existing users with matching and verified emails.
// If a verified GitHub email matches a Coder user, we will return.
verifiedEmails := make([]string, 0, len(emails))
for _, email := range emails {
if !email.GetVerified() {
continue
}
user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
Email: *email.Email,
verifiedEmails = append(verifiedEmails, email.GetEmail())
}
if len(verifiedEmails) == 0 {
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: "Verify an email address on Github to authenticate!",
})
if errors.Is(err, sql.ErrNoRows) {
continue
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: fmt.Sprintf("Internal error fetching user by email %q.", *email.Email),
Detail: err.Error(),
})
return
}
if !*email.Verified {
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email),
})
return
}
break
return
}
user, link, err := findLinkedUser(ctx, api.Database, githubLinkedID(ghUser), verifiedEmails...)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "An internal error occurred.",
Detail: err.Error(),
})
return
}
if user.ID != uuid.Nil && user.LoginType != database.LoginTypeGithub {
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeGithub, user.LoginType),
})
return
}
// If the user doesn't exist, create a new one!
@ -177,10 +184,13 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
})
return
}
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
Email: *verifiedEmail.Email,
Username: *ghUser.Login,
OrganizationID: organizationID,
user, _, err = api.createUser(ctx, createUserRequest{
CreateUserRequest: codersdk.CreateUserRequest{
Email: *verifiedEmail.Email,
Username: *ghUser.Login,
OrganizationID: organizationID,
},
LoginType: database.LoginTypeGithub,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@ -191,12 +201,49 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
}
}
_, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
OAuthAccessToken: state.Token.AccessToken,
OAuthRefreshToken: state.Token.RefreshToken,
OAuthExpiry: state.Token.Expiry,
// This can happen if a user is a built-in user but is signing in
// with Github for the first time.
if link.UserID == uuid.Nil {
link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
LinkedID: githubLinkedID(ghUser),
OAuthAccessToken: state.Token.AccessToken,
OAuthRefreshToken: state.Token.RefreshToken,
OAuthExpiry: state.Token.Expiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "A database error occurred.",
Detail: fmt.Sprintf("insert user link: %s", err.Error()),
})
return
}
}
// LEGACY: Remove 10/2022.
// We started tracking linked IDs later so it's possible for a user to be a
// pre-existing Github user and not have a linked ID. The migration
// to user_links did not populate this field as it requires calling out
// to Github to query the user's ID.
if link.LinkedID == "" {
link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{
UserID: user.ID,
LinkedID: githubLinkedID(ghUser),
LoginType: database.LoginTypeGithub,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "A database error occurred.",
Detail: xerrors.Errorf("update user link: %w", err.Error).Error(),
})
return
}
}
_, created := api.createAPIKey(rw, r, createAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
})
if !created {
return
@ -219,7 +266,10 @@ type OIDCConfig struct {
}
func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
var (
ctx = r.Context()
state = httpmw.OAuth2(r)
)
// See the example here: https://github.com/coreos/go-oidc
rawIDToken, ok := state.Token.Extra("id_token").(string)
@ -230,7 +280,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
return
}
idToken, err := api.OIDCConfig.Verifier.Verify(r.Context(), rawIDToken)
idToken, err := api.OIDCConfig.Verifier.Verify(ctx, rawIDToken)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to verify OIDC token.",
@ -285,29 +335,48 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
}
}
var user database.User
user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
Email: claims.Email,
})
if errors.Is(err, sql.ErrNoRows) {
if !api.OIDCConfig.AllowSignups {
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: "Signups are disabled for OIDC authentication!",
})
return
}
user, link, err := findLinkedUser(ctx, api.Database, oidcLinkedID(idToken), claims.Email)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to find user.",
Detail: err.Error(),
})
return
}
if user.ID == uuid.Nil && !api.OIDCConfig.AllowSignups {
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: "Signups are disabled for OIDC authentication!",
})
return
}
if user.ID != uuid.Nil && user.LoginType != database.LoginTypeOIDC {
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypeOIDC, user.LoginType),
})
return
}
// This can happen if a user is a built-in user but is signing in
// with OIDC for the first time.
if user.ID == uuid.Nil {
var organizationID uuid.UUID
organizations, _ := api.Database.GetOrganizations(r.Context())
organizations, _ := api.Database.GetOrganizations(ctx)
if len(organizations) > 0 {
// Add the user to the first organization. Once multi-organization
// support is added, we should enable a configuration map of user
// email to organization.
organizationID = organizations[0].ID
}
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
Email: claims.Email,
Username: claims.Username,
OrganizationID: organizationID,
user, _, err = api.createUser(ctx, createUserRequest{
CreateUserRequest: codersdk.CreateUserRequest{
Email: claims.Email,
Username: claims.Username,
OrganizationID: organizationID,
},
LoginType: database.LoginTypeOIDC,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@ -316,21 +385,81 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
})
return
}
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get user by email.",
Detail: err.Error(),
})
return
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to insert user auth metadata.",
Detail: err.Error(),
})
return
}
}
_, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
OAuthAccessToken: state.Token.AccessToken,
OAuthRefreshToken: state.Token.RefreshToken,
OAuthExpiry: state.Token.Expiry,
if link.UserID == uuid.Nil {
link, err = api.Database.InsertUserLink(ctx, database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
LinkedID: oidcLinkedID(idToken),
OAuthAccessToken: state.Token.AccessToken,
OAuthRefreshToken: state.Token.RefreshToken,
OAuthExpiry: state.Token.Expiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "A database error occurred.",
Detail: fmt.Sprintf("insert user link: %s", err.Error()),
})
return
}
}
// LEGACY: Remove 10/2022.
// We started tracking linked IDs later so it's possible for a user to be a
// pre-existing OIDC user and not have a linked ID.
// The migration that added the user_links table could not populate
// the 'linked_id' field since it requires fields off the access token.
if link.LinkedID == "" {
link, err = api.Database.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{
UserID: user.ID,
LinkedID: oidcLinkedID(idToken),
LoginType: database.LoginTypeGithub,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "A database error occurred.",
Detail: xerrors.Errorf("update user link: %w", err.Error).Error(),
})
return
}
}
// If the upstream email or username has changed we should mirror
// that in Coder. Many enterprises use a user's email/username as
// security auditing fields so they need to stay synced.
if user.Email != claims.Email || user.Username != claims.Username {
// TODO(JonA): Since we're processing updates to a user's upstream
// email/username, it's possible for a different built-in user to
// have already claimed the username.
// In such cases in the current implementation this user can now no
// longer sign in until an administrator finds the offending built-in
// user and changes their username.
user, err = api.Database.UpdateUserProfile(ctx, database.UpdateUserProfileParams{
ID: user.ID,
Email: claims.Email,
Username: claims.Username,
UpdatedAt: database.Now(),
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update user profile.",
Detail: fmt.Sprintf("update user profile: %s", err.Error()),
})
return
}
}
_, created := api.createAPIKey(rw, r, createAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
})
if !created {
return
@ -342,3 +471,66 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
}
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
}
// githubLinkedID returns the unique ID for a GitHub user.
func githubLinkedID(u *github.User) string {
return strconv.FormatInt(u.GetID(), 10)
}
// oidcLinkedID returns the uniqued ID for an OIDC user.
// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability .
func oidcLinkedID(tok *oidc.IDToken) string {
return strings.Join([]string{tok.Issuer, tok.Subject}, "||")
}
// findLinkedUser tries to find a user by their unique OAuth-linked ID.
// If it doesn't not find it, it returns the user by their email.
func findLinkedUser(ctx context.Context, db database.Store, linkedID string, emails ...string) (database.User, database.UserLink, error) {
var (
user database.User
link database.UserLink
)
link, err := db.GetUserLinkByLinkedID(ctx, linkedID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return user, link, xerrors.Errorf("get user auth by linked ID: %w", err)
}
if err == nil {
user, err = db.GetUserByID(ctx, link.UserID)
if err != nil {
return database.User{}, database.UserLink{}, xerrors.Errorf("get user by id: %w", err)
}
return user, link, nil
}
for _, email := range emails {
user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
Email: email,
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return user, link, xerrors.Errorf("get user by email: %w", err)
}
if errors.Is(err, sql.ErrNoRows) {
continue
}
break
}
if user.ID == uuid.Nil {
// No user found.
return database.User{}, database.UserLink{}, nil
}
// LEGACY: This is annoying but we have to search for the user_link
// again except this time we search by user_id and login_type. It's
// possible that a user_link exists without a populated 'linked_id'.
link, err = db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: user.ID,
LoginType: user.LoginType,
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return database.User{}, database.UserLink{}, xerrors.Errorf("get user link by user id and login type: %w", err)
}
return user, link, nil
}

View File

@ -175,38 +175,7 @@ func TestUserOAuth2Github(t *testing.T) {
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
})
t.Run("Signup", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
AllowOrganizations: []string{"coder"},
AllowSignups: true,
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
Organization: &github.Organization{
Login: github.String("coder"),
},
}}, nil
},
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
return &github.User{
Login: github.String("kyle"),
}, nil
},
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
return []*github.UserEmail{{
Email: github.String("kyle@coder.com"),
Verified: github.Bool(true),
Primary: github.Bool(true),
}}, nil
},
},
})
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
t.Run("Login", func(t *testing.T) {
t.Run("MultiLoginNotAllowed", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{
@ -230,9 +199,50 @@ func TestUserOAuth2Github(t *testing.T) {
},
},
})
// Creates the first user with login_type 'password'.
_ = coderdtest.CreateFirstUser(t, client)
// Attempting to login should give us a 403 since the user
// already has a login_type of 'password'.
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
})
t.Run("Signup", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
AllowOrganizations: []string{"coder"},
AllowSignups: true,
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
Organization: &github.Organization{
Login: github.String("coder"),
},
}}, nil
},
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
return &github.User{
Login: github.String("kyle"),
ID: i64ptr(1234),
}, nil
},
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
return []*github.UserEmail{{
Email: github.String("kyle@coder.com"),
Verified: github.Bool(true),
Primary: github.Bool(true),
}}, nil
},
},
})
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
client.SessionToken = resp.Cookies()[0].Value
user, err := client.User(context.Background(), "me")
require.NoError(t, err)
require.Equal(t, "kyle@coder.com", user.Email)
require.Equal(t, "kyle", user.Username)
})
t.Run("SignupAllowedTeam", func(t *testing.T) {
t.Parallel()
@ -415,11 +425,13 @@ func createOIDCConfig(t *testing.T, claims jwt.MapClaims) *coderd.OIDCConfig {
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
claims["iss"] = "https://coder.com"
claims["sub"] = "hello"
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key)
require.NoError(t, err)
verifier := oidc.NewVerifier("", &oidc.StaticKeySet{
verifier := oidc.NewVerifier("https://coder.com", &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{key.Public()},
}, &oidc.Config{
SkipClientIDCheck: true,
@ -480,3 +492,7 @@ func oidcCallback(t *testing.T, client *codersdk.Client) *http.Response {
t.Log(string(data))
return res
}
func i64ptr(i int64) *int64 {
return &i
}

View File

@ -77,10 +77,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
return
}
user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{
Email: createUser.Email,
Username: createUser.Username,
Password: createUser.Password,
user, organizationID, err := api.createUser(r.Context(), createUserRequest{
CreateUserRequest: codersdk.CreateUserRequest{
Email: createUser.Email,
Username: createUser.Username,
Password: createUser.Password,
},
LoginType: database.LoginTypePassword,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@ -196,14 +199,14 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
return
}
var createUser codersdk.CreateUserRequest
if !httpapi.Read(rw, r, &createUser) {
var req codersdk.CreateUserRequest
if !httpapi.Read(rw, r, &req) {
return
}
// Create the organization member in the org.
if !api.Authorize(r, rbac.ActionCreate,
rbac.ResourceOrganizationMember.InOrg(createUser.OrganizationID)) {
rbac.ResourceOrganizationMember.InOrg(req.OrganizationID)) {
httpapi.ResourceNotFound(rw)
return
}
@ -211,8 +214,8 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
// TODO: @emyrk Authorize the organization create if the createUser will do that.
_, err := api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
Username: createUser.Username,
Email: createUser.Email,
Username: req.Username,
Email: req.Email,
})
if err == nil {
httpapi.Write(rw, http.StatusConflict, codersdk.Response{
@ -228,10 +231,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
return
}
_, err = api.Database.GetOrganizationByID(r.Context(), createUser.OrganizationID)
_, err = api.Database.GetOrganizationByID(r.Context(), req.OrganizationID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Organization does not exist with the provided id %q.", createUser.OrganizationID),
Message: fmt.Sprintf("Organization does not exist with the provided id %q.", req.OrganizationID),
})
return
}
@ -243,7 +246,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
return
}
user, _, err := api.createUser(r.Context(), createUser)
user, _, err := api.createUser(r.Context(), createUserRequest{
CreateUserRequest: req,
LoginType: database.LoginTypePassword,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error creating user.",
@ -257,7 +263,7 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
Users: []telemetry.User{telemetry.ConvertUser(user)},
})
httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{createUser.OrganizationID}))
httpapi.Write(rw, http.StatusCreated, convertUser(user, []uuid.UUID{req.OrganizationID}))
}
// Returns the parameterized user requested. All validation
@ -701,6 +707,13 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
return
}
if user.LoginType != database.LoginTypePassword {
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: fmt.Sprintf("Incorrect login type, attempting to use %q but user is of login type %q", database.LoginTypePassword, user.LoginType),
})
return
}
// If the user logged into a suspended account, reject the login request.
if user.Status != database.UserStatusActive {
httpapi.Write(rw, http.StatusUnauthorized, codersdk.Response{
@ -709,7 +722,7 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
return
}
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
sessionToken, created := api.createAPIKey(rw, r, createAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
@ -732,7 +745,7 @@ func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) {
}
lifeTime := time.Hour * 24 * 7
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
sessionToken, created := api.createAPIKey(rw, r, createAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypePassword,
// All api generated keys will last 1 week. Browser login tokens have
@ -818,7 +831,16 @@ func generateAPIKeyIDSecret() (id string, secret string, err error) {
return id, secret, nil
}
func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) {
type createAPIKeyParams struct {
UserID uuid.UUID
LoginType database.LoginType
// Optional.
ExpiresAt time.Time
LifetimeSeconds int64
}
func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params createAPIKeyParams) (string, bool) {
keyID, keySecret, err := generateAPIKeyIDSecret()
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@ -856,15 +878,11 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat
Valid: true,
},
// Make sure in UTC time for common time zone
ExpiresAt: params.ExpiresAt.UTC(),
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
HashedSecret: hashed[:],
LoginType: params.LoginType,
OAuthAccessToken: params.OAuthAccessToken,
OAuthRefreshToken: params.OAuthRefreshToken,
OAuthIDToken: params.OAuthIDToken,
OAuthExpiry: params.OAuthExpiry,
ExpiresAt: params.ExpiresAt.UTC(),
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
HashedSecret: hashed[:],
LoginType: params.LoginType,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@ -891,7 +909,12 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat
return sessionToken, true
}
func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) {
type createUserRequest struct {
codersdk.CreateUserRequest
LoginType database.LoginType
}
func (api *API) createUser(ctx context.Context, req createUserRequest) (database.User, uuid.UUID, error) {
var user database.User
return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error {
orgRoles := make([]string, 0)
@ -918,6 +941,7 @@ func (api *API) createUser(ctx context.Context, req codersdk.CreateUserRequest)
UpdatedAt: database.Now(),
// All new users are defaulted to members of the site.
RBACRoles: []string{},
LoginType: req.LoginType,
}
// If a user signs up with OAuth, they can have no password!
if req.Password != "" {

View File

@ -27,6 +27,7 @@ type LoginType string
const (
LoginTypePassword LoginType = "password"
LoginTypeGithub LoginType = "github"
LoginTypeOIDC LoginType = "oidc"
)
type UsersRequest struct {

View File

@ -561,7 +561,7 @@ export type LogLevel = "debug" | "error" | "info" | "trace" | "warn"
export type LogSource = "provisioner" | "provisioner_daemon"
// From codersdk/users.go
export type LoginType = "github" | "password"
export type LoginType = "github" | "oidc" | "password"
// From codersdk/parameters.go
export type ParameterDestinationScheme = "environment_variable" | "none" | "provisioner_variable"