mirror of https://github.com/coder/coder.git
chore: add claims to oauth link in db for debug (#10827)
* chore: add claims to oauth link in db for debug
This commit is contained in:
parent
0534f8f59b
commit
abb2c7656a
|
@ -480,6 +480,37 @@ const docTemplate = `{
|
|||
}
|
||||
}
|
||||
},
|
||||
"/debug/{user}/debug-link": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"summary": "Debug OIDC context for a user",
|
||||
"operationId": "debug-oidc-context-for-a-user",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, name, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/deployment/config": {
|
||||
"get": {
|
||||
"security": [
|
||||
|
|
|
@ -408,6 +408,35 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/debug/{user}/debug-link": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Agents"],
|
||||
"summary": "Debug OIDC context for a user",
|
||||
"operationId": "debug-oidc-context-for-a-user",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, name, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/deployment/config": {
|
||||
"get": {
|
||||
"security": [
|
||||
|
|
|
@ -972,6 +972,10 @@ func New(options *Options) *API {
|
|||
r.Get("/tailnet", api.debugTailnet)
|
||||
r.Get("/health", api.debugDeploymentHealth)
|
||||
r.Get("/ws", (&healthcheck.WebsocketEchoServer{}).ServeHTTP)
|
||||
r.Route("/{user}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractUserParam(options.Database))
|
||||
r.Get("/debug-link", api.userDebugOIDC)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package oidctest
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -77,6 +78,7 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code
|
|||
OAuthExpiry: time.Now().Add(time.Hour * -1),
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
DebugContext: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err, "expire user link")
|
||||
|
||||
|
|
|
@ -1022,6 +1022,7 @@ func (s *MethodTestSuite) TestUser() {
|
|||
OAuthExpiry: link.OAuthExpiry,
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
DebugContext: json.RawMessage("{}"),
|
||||
}).Asserts(link, rbac.ActionUpdate).Returns(link)
|
||||
}))
|
||||
s.Run("UpdateUserRoles", s.Subtest(func(db database.Store, check *expects) {
|
||||
|
|
|
@ -513,6 +513,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
|
|||
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
|
||||
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
|
||||
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
|
||||
DebugContext: takeFirstSlice(orig.DebugContext, json.RawMessage("{}")),
|
||||
})
|
||||
|
||||
require.NoError(t, err, "insert link")
|
||||
|
|
|
@ -5106,6 +5106,7 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
|
|||
OAuthRefreshToken: args.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID,
|
||||
OAuthExpiry: args.OAuthExpiry,
|
||||
DebugContext: args.DebugContext,
|
||||
}
|
||||
|
||||
q.userLinks = append(q.userLinks, link)
|
||||
|
@ -6188,6 +6189,7 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
|
|||
link.OAuthRefreshToken = params.OAuthRefreshToken
|
||||
link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID
|
||||
link.OAuthExpiry = params.OAuthExpiry
|
||||
link.DebugContext = params.DebugContext
|
||||
|
||||
q.userLinks[i] = link
|
||||
return link, nil
|
||||
|
|
|
@ -870,13 +870,16 @@ CREATE TABLE user_links (
|
|||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
oauth_access_token_key_id text,
|
||||
oauth_refresh_token_key_id text
|
||||
oauth_refresh_token_key_id text,
|
||||
debug_context jsonb DEFAULT '{}'::jsonb NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN user_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
|
||||
|
||||
COMMENT ON COLUMN user_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';
|
||||
|
||||
COMMENT ON COLUMN user_links.debug_context IS 'Debug information includes information like id_token and userinfo claims.';
|
||||
|
||||
CREATE TABLE workspace_agent_log_sources (
|
||||
workspace_agent_id uuid NOT NULL,
|
||||
id uuid NOT NULL,
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
BEGIN;
|
||||
|
||||
ALTER TABLE user_links DROP COLUMN debug_context;
|
||||
|
||||
COMMIT;
|
|
@ -0,0 +1,6 @@
|
|||
BEGIN;
|
||||
|
||||
ALTER TABLE user_links ADD COLUMN debug_context jsonb DEFAULT '{}' NOT NULL;
|
||||
COMMENT ON COLUMN user_links.debug_context IS 'Debug information includes information like id_token and userinfo claims.';
|
||||
|
||||
COMMIT;
|
|
@ -2127,6 +2127,8 @@ type UserLink struct {
|
|||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
// The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
// Debug information includes information like id_token and userinfo claims.
|
||||
DebugContext json.RawMessage `db:"debug_context" json:"debug_context"`
|
||||
}
|
||||
|
||||
// Visible fields of users are allowed to be joined with other tables for including context of other resources.
|
||||
|
|
|
@ -6548,7 +6548,7 @@ func (q *sqlQuerier) InsertTemplateVersionVariable(ctx context.Context, arg Inse
|
|||
|
||||
const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one
|
||||
SELECT
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
|
@ -6567,13 +6567,14 @@ func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string)
|
|||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.DebugContext,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one
|
||||
SELECT
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
|
@ -6597,12 +6598,13 @@ func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUs
|
|||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.DebugContext,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserLinksByUserID = `-- name: GetUserLinksByUserID :many
|
||||
SELECT user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM user_links WHERE user_id = $1
|
||||
SELECT user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context FROM user_links WHERE user_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error) {
|
||||
|
@ -6623,6 +6625,7 @@ func (q *sqlQuerier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID)
|
|||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.DebugContext,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -6647,21 +6650,23 @@ INSERT INTO
|
|||
oauth_access_token_key_id,
|
||||
oauth_refresh_token,
|
||||
oauth_refresh_token_key_id,
|
||||
oauth_expiry
|
||||
oauth_expiry,
|
||||
debug_context
|
||||
)
|
||||
VALUES
|
||||
( $1, $2, $3, $4, $5, $6, $7, $8 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context
|
||||
`
|
||||
|
||||
type InsertUserLinkParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
LinkedID string `db:"linked_id" json:"linked_id"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
LinkedID string `db:"linked_id" json:"linked_id"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
DebugContext json.RawMessage `db:"debug_context" json:"debug_context"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) {
|
||||
|
@ -6674,6 +6679,7 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam
|
|||
arg.OAuthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
arg.OAuthExpiry,
|
||||
arg.DebugContext,
|
||||
)
|
||||
var i UserLink
|
||||
err := row.Scan(
|
||||
|
@ -6685,6 +6691,7 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam
|
|||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.DebugContext,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -6697,19 +6704,21 @@ SET
|
|||
oauth_access_token_key_id = $2,
|
||||
oauth_refresh_token = $3,
|
||||
oauth_refresh_token_key_id = $4,
|
||||
oauth_expiry = $5
|
||||
oauth_expiry = $5,
|
||||
debug_context = $6
|
||||
WHERE
|
||||
user_id = $6 AND login_type = $7 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
user_id = $7 AND login_type = $8 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context
|
||||
`
|
||||
|
||||
type UpdateUserLinkParams struct {
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
DebugContext json.RawMessage `db:"debug_context" json:"debug_context"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) {
|
||||
|
@ -6719,6 +6728,7 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam
|
|||
arg.OAuthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
arg.OAuthExpiry,
|
||||
arg.DebugContext,
|
||||
arg.UserID,
|
||||
arg.LoginType,
|
||||
)
|
||||
|
@ -6732,6 +6742,7 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam
|
|||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.DebugContext,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -6742,7 +6753,7 @@ UPDATE
|
|||
SET
|
||||
linked_id = $1
|
||||
WHERE
|
||||
user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, debug_context
|
||||
`
|
||||
|
||||
type UpdateUserLinkedIDParams struct {
|
||||
|
@ -6763,6 +6774,7 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke
|
|||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.DebugContext,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
|
|
@ -27,10 +27,11 @@ INSERT INTO
|
|||
oauth_access_token_key_id,
|
||||
oauth_refresh_token,
|
||||
oauth_refresh_token_key_id,
|
||||
oauth_expiry
|
||||
oauth_expiry,
|
||||
debug_context
|
||||
)
|
||||
VALUES
|
||||
( $1, $2, $3, $4, $5, $6, $7, $8 ) RETURNING *;
|
||||
( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) RETURNING *;
|
||||
|
||||
-- name: UpdateUserLinkedID :one
|
||||
UPDATE
|
||||
|
@ -48,6 +49,7 @@ SET
|
|||
oauth_access_token_key_id = $2,
|
||||
oauth_refresh_token = $3,
|
||||
oauth_refresh_token_key_id = $4,
|
||||
oauth_expiry = $5
|
||||
oauth_expiry = $5,
|
||||
debug_context = $6
|
||||
WHERE
|
||||
user_id = $6 AND login_type = $7 RETURNING *;
|
||||
user_id = $7 AND login_type = $8 RETURNING *;
|
||||
|
|
|
@ -378,6 +378,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
|||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
DebugContext: link.DebugContext,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
|
|
|
@ -1674,6 +1674,7 @@ func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig ht
|
|||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
DebugContext: link.DebugContext,
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("update user link: %w", err)
|
||||
|
|
|
@ -3,6 +3,7 @@ package coderd
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -631,6 +632,7 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
|||
Email: verifiedEmail.GetEmail(),
|
||||
Username: ghUser.GetLogin(),
|
||||
AvatarURL: ghUser.GetAvatarURL(),
|
||||
DebugContext: OauthDebugContext{},
|
||||
}).SetInitAuditRequest(func(params *audit.RequestParams) (*audit.Request[database.User], func()) {
|
||||
return audit.InitRequest[database.User](rw, params)
|
||||
})
|
||||
|
@ -770,8 +772,8 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
// "email_verified" is an optional claim that changes the behavior
|
||||
// of our OIDC handler, so each property must be pulled manually out
|
||||
// of the claim mapping.
|
||||
claims := map[string]interface{}{}
|
||||
err = idToken.Claims(&claims)
|
||||
idtokenClaims := map[string]interface{}{}
|
||||
err = idToken.Claims(&idtokenClaims)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "oauth2: unable to extract OIDC claims", slog.Error(err))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
|
@ -783,8 +785,8 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
|
||||
logger.Debug(ctx, "got oidc claims",
|
||||
slog.F("source", "id_token"),
|
||||
slog.F("claim_fields", claimFields(claims)),
|
||||
slog.F("blank", blankFields(claims)),
|
||||
slog.F("claim_fields", claimFields(idtokenClaims)),
|
||||
slog.F("blank", blankFields(idtokenClaims)),
|
||||
)
|
||||
|
||||
// Not all claims are necessarily embedded in the `id_token`.
|
||||
|
@ -797,10 +799,12 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
// Some providers (e.g. ADFS) do not support custom OIDC claims in the
|
||||
// UserInfo endpoint, so we allow users to disable it and only rely on the
|
||||
// ID token.
|
||||
userInfoClaims := make(map[string]interface{})
|
||||
// If user info is skipped, the idtokenClaims are the claims.
|
||||
mergedClaims := idtokenClaims
|
||||
if !api.OIDCConfig.IgnoreUserInfo {
|
||||
userInfo, err := api.OIDCConfig.Provider.UserInfo(ctx, oauth2.StaticTokenSource(state.Token))
|
||||
if err == nil {
|
||||
userInfoClaims := map[string]interface{}{}
|
||||
err = userInfo.Claims(&userInfoClaims)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "oauth2: unable to unmarshal user info claims", slog.Error(err))
|
||||
|
@ -818,13 +822,13 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Merge the claims from the ID token and the UserInfo endpoint.
|
||||
// Information from UserInfo takes precedence.
|
||||
claims = mergeClaims(claims, userInfoClaims)
|
||||
mergedClaims = mergeClaims(idtokenClaims, userInfoClaims)
|
||||
|
||||
// Log all of the field names after merging.
|
||||
logger.Debug(ctx, "got oidc claims",
|
||||
slog.F("source", "merged"),
|
||||
slog.F("claim_fields", claimFields(claims)),
|
||||
slog.F("blank", blankFields(claims)),
|
||||
slog.F("claim_fields", claimFields(mergedClaims)),
|
||||
slog.F("blank", blankFields(mergedClaims)),
|
||||
)
|
||||
} else if !strings.Contains(err.Error(), "user info endpoint is not supported by this provider") {
|
||||
logger.Error(ctx, "oauth2: unable to obtain user information claims", slog.Error(err))
|
||||
|
@ -841,13 +845,13 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
usernameRaw, ok := claims[api.OIDCConfig.UsernameField]
|
||||
usernameRaw, ok := mergedClaims[api.OIDCConfig.UsernameField]
|
||||
var username string
|
||||
if ok {
|
||||
username, _ = usernameRaw.(string)
|
||||
}
|
||||
|
||||
emailRaw, ok := claims[api.OIDCConfig.EmailField]
|
||||
emailRaw, ok := mergedClaims[api.OIDCConfig.EmailField]
|
||||
if !ok {
|
||||
// Email is an optional claim in OIDC and
|
||||
// instead the email is frequently sent in
|
||||
|
@ -871,7 +875,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
verifiedRaw, ok := claims["email_verified"]
|
||||
verifiedRaw, ok := mergedClaims["email_verified"]
|
||||
if ok {
|
||||
verified, ok := verifiedRaw.(bool)
|
||||
if ok && !verified {
|
||||
|
@ -891,7 +895,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
// This is so we can support manual group assignment.
|
||||
if api.OIDCConfig.GroupField != "" {
|
||||
usingGroups = true
|
||||
groupsRaw, ok := claims[api.OIDCConfig.GroupField]
|
||||
groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField]
|
||||
if ok && api.OIDCConfig.GroupField != "" {
|
||||
// Convert the []interface{} we get to a []string.
|
||||
groupsInterface, ok := groupsRaw.([]interface{})
|
||||
|
@ -926,7 +930,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// This conditional is purely to warn the user they might have misconfigured their OIDC
|
||||
// configuration.
|
||||
if _, groupClaimExists := claims["groups"]; !usingGroups && groupClaimExists {
|
||||
if _, groupClaimExists := mergedClaims["groups"]; !usingGroups && groupClaimExists {
|
||||
logger.Debug(ctx, "claim 'groups' was returned, but 'oidc-group-field' is not set, check your coder oidc settings")
|
||||
}
|
||||
|
||||
|
@ -961,7 +965,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var picture string
|
||||
pictureRaw, ok := claims["picture"]
|
||||
pictureRaw, ok := mergedClaims["picture"]
|
||||
if ok {
|
||||
picture, _ = pictureRaw.(string)
|
||||
}
|
||||
|
@ -978,7 +982,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
|
||||
roles := api.OIDCConfig.UserRolesDefault
|
||||
if api.OIDCConfig.RoleSyncEnabled() {
|
||||
rolesRow, ok := claims[api.OIDCConfig.UserRoleField]
|
||||
rolesRow, ok := mergedClaims[api.OIDCConfig.UserRoleField]
|
||||
if !ok {
|
||||
// If no claim is provided than we can assume the user is just
|
||||
// a member. This is because there is no way to tell the difference
|
||||
|
@ -1055,6 +1059,10 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||
Groups: groups,
|
||||
CreateMissingGroups: api.OIDCConfig.CreateMissingGroups,
|
||||
GroupFilter: api.OIDCConfig.GroupFilter,
|
||||
DebugContext: OauthDebugContext{
|
||||
IDTokenClaims: idtokenClaims,
|
||||
UserInfoClaims: userInfoClaims,
|
||||
},
|
||||
}).SetInitAuditRequest(func(params *audit.RequestParams) (*audit.Request[database.User], func()) {
|
||||
return audit.InitRequest[database.User](rw, params)
|
||||
})
|
||||
|
@ -1123,6 +1131,13 @@ func mergeClaims(a, b map[string]interface{}) map[string]interface{} {
|
|||
return c
|
||||
}
|
||||
|
||||
// OauthDebugContext provides helpful information for admins to debug
|
||||
// OAuth login issues.
|
||||
type OauthDebugContext struct {
|
||||
IDTokenClaims map[string]interface{} `json:"id_token_claims"`
|
||||
UserInfoClaims map[string]interface{} `json:"user_info_claims"`
|
||||
}
|
||||
|
||||
type oauthLoginParams struct {
|
||||
User database.User
|
||||
Link database.UserLink
|
||||
|
@ -1147,6 +1162,8 @@ type oauthLoginParams struct {
|
|||
UsingRoles bool
|
||||
Roles []string
|
||||
|
||||
DebugContext OauthDebugContext
|
||||
|
||||
commitLock sync.Mutex
|
||||
initAuditRequest func(params *audit.RequestParams) *audit.Request[database.User]
|
||||
commits []func()
|
||||
|
@ -1326,6 +1343,11 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
|
|||
}
|
||||
}
|
||||
|
||||
debugContext, err := json.Marshal(params.DebugContext)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal debug context: %w", err)
|
||||
}
|
||||
|
||||
if link.UserID == uuid.Nil {
|
||||
//nolint:gocritic // System needs to insert the user link (linked_id, oauth_token, oauth_expiry).
|
||||
link, err = tx.InsertUserLink(dbauthz.AsSystemRestricted(ctx), database.InsertUserLinkParams{
|
||||
|
@ -1337,6 +1359,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
|
|||
OAuthRefreshToken: params.State.Token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
|
||||
OAuthExpiry: params.State.Token.Expiry,
|
||||
DebugContext: debugContext,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user link: %w", err)
|
||||
|
@ -1353,6 +1376,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
|
|||
OAuthRefreshToken: params.State.Token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
|
||||
OAuthExpiry: params.State.Token.Expiry,
|
||||
DebugContext: debugContext,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update user link: %w", err)
|
||||
|
|
|
@ -28,6 +28,47 @@ import (
|
|||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// userDebugOIDC returns the OIDC debug context for the user.
|
||||
// Not going to expose this via swagger as the return payload is not guaranteed
|
||||
// to be consistent between releases.
|
||||
//
|
||||
// @Summary Debug OIDC context for a user
|
||||
// @ID debug-oidc-context-for-a-user
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Agents
|
||||
// @Success 200 "Success"
|
||||
// @Param user path string true "User ID, name, or me"
|
||||
// @Router /debug/{user}/debug-link [get]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) userDebugOIDC(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
)
|
||||
|
||||
if user.LoginType != database.LoginTypeOIDC {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "User is not an OIDC user.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
link, err := api.Database.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get user links.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// This will encode properly because it is a json.RawMessage.
|
||||
httpapi.Write(ctx, rw, http.StatusOK, link.DebugContext)
|
||||
}
|
||||
|
||||
// Returns whether the initial user has been created or not.
|
||||
//
|
||||
// @Summary Check initial user created
|
||||
|
|
|
@ -43,6 +43,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
|||
OAuthExpiry: userLink.OAuthExpiry,
|
||||
UserID: uid,
|
||||
LoginType: userLink.LoginType,
|
||||
DebugContext: userLink.DebugContext,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
|
||||
}
|
||||
|
@ -132,6 +133,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
|||
OAuthExpiry: userLink.OAuthExpiry,
|
||||
UserID: uid,
|
||||
LoginType: userLink.LoginType,
|
||||
DebugContext: userLink.DebugContext,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
|
@ -55,6 +56,7 @@ func TestUserLinks(t *testing.T) {
|
|||
OAuthRefreshToken: "refresh",
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
DebugContext: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "access", updated.OAuthAccessToken)
|
||||
|
|
Loading…
Reference in New Issue