coder/enterprise/coderd/identityprovider/tokens.go

379 lines
12 KiB
Go

package identityprovider
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"net/url"
"time"
"github.com/google/uuid"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/userpassword"
"github.com/coder/coder/v2/codersdk"
)
// Sentinel errors returned by the grant handlers. Callers match them with
// errors.Is. Their text is written into HTTP responses via err.Error(), so the
// wording and capitalization are part of the API surface and must not change.
var (
	// errBadSecret is returned when the supplied client secret is malformed,
	// unknown, or does not match the stored hash.
	errBadSecret = xerrors.New("Invalid client secret")
	// errBadCode is returned when the supplied authorization code is
	// malformed, unknown, expired, or does not match the stored hash.
	errBadCode = xerrors.New("Invalid code")
	// errBadToken is returned when the supplied refresh token is malformed,
	// unknown, expired, or does not match the stored hash.
	errBadToken = xerrors.New("Invalid token")
)
// tokenParams holds the form values accepted by the token endpoint, as
// extracted by extractTokenParams.
type tokenParams struct {
	// clientID is the "client_id" form value.
	clientID string
	// clientSecret is the "client_secret" form value.
	clientSecret string
	// code is the "code" form value (the authorization code).
	code string
	// grantType is the parsed "grant_type" form value.
	grantType codersdk.OAuth2ProviderGrantType
	// redirectURL is the "redirect_uri" form value as parsed by the query
	// parameter parser against the app's callback URL.
	redirectURL *url.URL
	// refreshToken is the "refresh_token" form value.
	refreshToken string
}
// extractTokenParams parses the request form and extracts the parameters used
// by the token endpoint.
//
// "grant_type" is always required. The refresh token grant additionally
// requires "refresh_token"; the authorization code grant additionally requires
// "client_secret", "client_id" and "code". callbackURL is handed to the
// parser when reading "redirect_uri".
//
// On success it returns the populated tokenParams with nil validation errors
// and a nil error. If the form cannot be parsed it returns a wrapped error and
// no validation errors. If the parser recorded any validation errors it
// returns them together with an error summarizing them.
func extractTokenParams(r *http.Request, callbackURL *url.URL) (tokenParams, []codersdk.ValidationError, error) {
	p := httpapi.NewQueryParamParser()
	err := r.ParseForm()
	if err != nil {
		return tokenParams{}, nil, xerrors.Errorf("parse form: %w", err)
	}
	vals := r.Form

	p.RequiredNotEmpty("grant_type")
	grantType := httpapi.ParseCustom(p, vals, "", "grant_type", httpapi.ParseEnum[codersdk.OAuth2ProviderGrantType])
	// The remaining required fields depend on which grant is being requested.
	switch grantType {
	case codersdk.OAuth2ProviderGrantTypeRefreshToken:
		p.RequiredNotEmpty("refresh_token")
	case codersdk.OAuth2ProviderGrantTypeAuthorizationCode:
		p.RequiredNotEmpty("client_secret", "client_id", "code")
	}

	params := tokenParams{
		clientID:     p.String(vals, "", "client_id"),
		clientSecret: p.String(vals, "", "client_secret"),
		code:         p.String(vals, "", "code"),
		grantType:    grantType,
		redirectURL:  p.RedirectURL(vals, callbackURL, "redirect_uri"),
		refreshToken: p.String(vals, "", "refresh_token"),
	}

	// Must run after every expected parameter has been read from vals.
	p.ErrorExcessParams(vals)
	if len(p.Errors) > 0 {
		// p.Errors is returned as []codersdk.ValidationError, i.e. a slice
		// rather than an error value, so it is formatted with %v. The %w verb
		// is only valid for operands that implement error and would otherwise
		// render as "%!w(...)" in the message.
		return tokenParams{}, p.Errors, xerrors.Errorf("invalid query params: %v", p.Errors)
	}
	return params, nil, nil
}
// Tokens returns the HTTP handler for the OAuth2 token endpoint.
//
// The handler reads the OAuth2 provider app from the request with
// httpmw.OAuth2ProviderApp, extracts the form parameters, and dispatches on
// the grant type to exchange either an authorization code or a refresh token
// for a new token. defaultLifetime is passed through to the grant functions,
// which use it as the default lifetime of the generated API key.
//
// Responses:
//   - 200 with the resulting oauth2.Token on success.
//   - 400 when the parameters are invalid or the grant type is unhandled.
//   - 401 when the client secret, authorization code, or refresh token is
//     rejected (errBadSecret, errBadCode, errBadToken).
//   - 500 when the app's callback URL cannot be parsed or the exchange fails
//     for any other reason.
func Tokens(db database.Store, defaultLifetime time.Duration) http.HandlerFunc {
	return func(rw http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		app := httpmw.OAuth2ProviderApp(r)

		callbackURL, err := url.Parse(app.CallbackURL)
		if err != nil {
			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
				Message: "Failed to validate form values.",
				Detail:  err.Error(),
			})
			return
		}

		params, validationErrs, err := extractTokenParams(r, callbackURL)
		if err != nil {
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message:     "Invalid query params.",
				Detail:      err.Error(),
				Validations: validationErrs,
			})
			return
		}

		var token oauth2.Token
		//nolint:gocritic,revive // More cases will be added later.
		switch params.grantType {
		// TODO: Client creds, device code.
		case codersdk.OAuth2ProviderGrantTypeRefreshToken:
			token, err = refreshTokenGrant(ctx, db, app, defaultLifetime, params)
		case codersdk.OAuth2ProviderGrantTypeAuthorizationCode:
			token, err = authorizationCodeGrant(ctx, db, app, defaultLifetime, params)
		default:
			// Grant types are validated by the parser, so getting through here means
			// the developer added a type but forgot to add a case here.
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: "Unhandled grant type.",
				Detail:  fmt.Sprintf("Grant type %q is unhandled", params.grantType),
			})
			return
		}

		// Credential failures are the caller's fault, not a server error.
		// errBadToken must be included here: it is what refreshTokenGrant
		// returns for an invalid or expired refresh token, and without it
		// those requests would be reported as an internal server error.
		if errors.Is(err, errBadCode) || errors.Is(err, errBadSecret) || errors.Is(err, errBadToken) {
			httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
				Message: err.Error(),
			})
			return
		}
		if err != nil {
			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
				Message: "Failed to exchange token",
				Detail:  err.Error(),
			})
			return
		}

		// Some client libraries allow this to be "application/x-www-form-urlencoded". We can implement that upon
		// request. The same libraries should also accept JSON. If implemented, choose based on "Accept" header.
		httpapi.Write(ctx, rw, http.StatusOK, token)
	}
}
// authorizationCodeGrant exchanges an authorization code for a new access
// token (an API key session token) and a refresh token.
//
// It validates params.clientSecret and params.code against their stored
// hashes, rejects expired codes, and then, inside a single transaction run as
// the code's user: deletes the code, deletes any previous API key with the
// same per-user/per-app token name, inserts the new API key, and inserts the
// refresh token record.
//
// It returns errBadSecret when the client secret is malformed, unknown, or
// mismatched, and errBadCode when the code is malformed, unknown, mismatched,
// or expired. Any other failure is returned as-is or wrapped with context.
//
// NOTE(review): nothing visible in this function checks that the secret or the
// code belong to app; app is only used to build the token name. Confirm that
// binding is enforced elsewhere.
func authorizationCodeGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, defaultLifetime time.Duration, params tokenParams) (oauth2.Token, error) {
	// Step 1: check the client secret.
	clientSecret, err := parseSecret(params.clientSecret)
	if err != nil {
		return oauth2.Token{}, errBadSecret
	}
	//nolint:gocritic // Users cannot read secrets so we must use the system.
	storedSecret, err := db.GetOAuth2ProviderAppSecretByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(clientSecret.prefix))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return oauth2.Token{}, errBadSecret
	case err != nil:
		return oauth2.Token{}, err
	}
	secretMatches, err := userpassword.Compare(string(storedSecret.HashedSecret), clientSecret.secret)
	if err != nil {
		return oauth2.Token{}, xerrors.Errorf("unable to compare secret: %w", err)
	}
	if !secretMatches {
		return oauth2.Token{}, errBadSecret
	}

	// Step 2: check the authorization code.
	authCode, err := parseSecret(params.code)
	if err != nil {
		return oauth2.Token{}, errBadCode
	}
	//nolint:gocritic // There is no user yet so we must use the system.
	storedCode, err := db.GetOAuth2ProviderAppCodeByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(authCode.prefix))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return oauth2.Token{}, errBadCode
	case err != nil:
		return oauth2.Token{}, err
	}
	codeMatches, err := userpassword.Compare(string(storedCode.HashedSecret), authCode.secret)
	if err != nil {
		return oauth2.Token{}, xerrors.Errorf("unable to compare code: %w", err)
	}
	if !codeMatches {
		return oauth2.Token{}, errBadCode
	}
	// An expired code is treated the same as an unknown one.
	if storedCode.ExpiresAt.Before(dbtime.Now()) {
		return oauth2.Token{}, errBadCode
	}

	// Step 3: mint the refresh token and the API key that replaces the code.
	refresh, err := GenerateSecret()
	if err != nil {
		return oauth2.Token{}, err
	}
	// TODO: We are ignoring scopes for now.
	// For now, we allow only one token per app and user at a time, so the
	// name is derived from the user and app IDs.
	keyName := fmt.Sprintf("%s_%s_oauth_session_token", storedCode.UserID, app.ID)
	keyParams, sessionToken, err := apikey.Generate(apikey.CreateParams{
		UserID:    storedCode.UserID,
		LoginType: database.LoginTypeOAuth2ProviderApp,
		// TODO: This is just the lifetime for api keys, maybe have its own config
		// settings. #11693
		DefaultLifetime: defaultLifetime,
		TokenName:       keyName,
	})
	if err != nil {
		return oauth2.Token{}, err
	}

	// Step 4: build the user's RBAC subject so the exchange runs as the user.
	//nolint:gocritic // In the token exchange, there is no user actor.
	userRoles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), storedCode.UserID)
	if err != nil {
		return oauth2.Token{}, err
	}
	actor := rbac.Subject{
		ID:     storedCode.UserID.String(),
		Roles:  rbac.RoleNames(userRoles.Roles),
		Groups: userRoles.Groups,
		Scope:  rbac.ScopeAll,
	}

	// Step 5: perform the exchange atomically.
	err = db.InTx(func(tx database.Store) error {
		userCtx := dbauthz.As(ctx, actor)

		// Delete the code in the same transaction that issues the token.
		err := tx.DeleteOAuth2ProviderAppCodeByID(userCtx, storedCode.ID)
		if err != nil {
			return xerrors.Errorf("delete oauth2 app code: %w", err)
		}

		// Delete the previous key, if any. A missing key is not an error.
		oldKey, err := tx.GetAPIKeyByName(userCtx, database.GetAPIKeyByNameParams{
			UserID:    storedCode.UserID,
			TokenName: keyName,
		})
		if err == nil {
			err = tx.DeleteAPIKeyByID(userCtx, oldKey.ID)
		}
		if err != nil && !errors.Is(err, sql.ErrNoRows) {
			return xerrors.Errorf("delete api key by name: %w", err)
		}

		insertedKey, err := tx.InsertAPIKey(userCtx, keyParams)
		if err != nil {
			return xerrors.Errorf("insert oauth2 access token: %w", err)
		}

		// The refresh token record shares the API key's expiry.
		_, err = tx.InsertOAuth2ProviderAppToken(userCtx, database.InsertOAuth2ProviderAppTokenParams{
			ID:          uuid.New(),
			CreatedAt:   dbtime.Now(),
			ExpiresAt:   keyParams.ExpiresAt,
			HashPrefix:  []byte(refresh.Prefix),
			RefreshHash: []byte(refresh.Hashed),
			AppSecretID: storedSecret.ID,
			APIKeyID:    insertedKey.ID,
		})
		if err != nil {
			return xerrors.Errorf("insert oauth2 refresh token: %w", err)
		}
		return nil
	}, nil)
	if err != nil {
		return oauth2.Token{}, err
	}

	issued := oauth2.Token{
		AccessToken:  sessionToken,
		TokenType:    "Bearer",
		RefreshToken: refresh.Formatted,
		Expiry:       keyParams.ExpiresAt,
	}
	return issued, nil
}
// refreshTokenGrant exchanges a refresh token for a new access token (an API
// key session token) and a new refresh token.
//
// It validates params.refreshToken against its stored hash, rejects expired
// tokens, looks up the API key the refresh token is attached to, and then,
// inside a single transaction run as that key's user: deletes the previous API
// key, inserts the new API key, and inserts the new refresh token record.
//
// It returns errBadToken when the refresh token is malformed, unknown,
// mismatched, or expired. Any other failure is returned as-is or wrapped with
// context.
//
// NOTE(review): nothing visible in this function checks that the refresh
// token belongs to app; app is only used to build the token name. Confirm
// that binding is enforced elsewhere.
func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, defaultLifetime time.Duration, params tokenParams) (oauth2.Token, error) {
	// Step 1: check the refresh token.
	presented, err := parseSecret(params.refreshToken)
	if err != nil {
		return oauth2.Token{}, errBadToken
	}
	//nolint:gocritic // There is no user yet so we must use the system.
	storedToken, err := db.GetOAuth2ProviderAppTokenByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(presented.prefix))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return oauth2.Token{}, errBadToken
	case err != nil:
		return oauth2.Token{}, err
	}
	tokenMatches, err := userpassword.Compare(string(storedToken.RefreshHash), presented.secret)
	if err != nil {
		return oauth2.Token{}, xerrors.Errorf("unable to compare token: %w", err)
	}
	if !tokenMatches {
		return oauth2.Token{}, errBadToken
	}
	// An expired token is treated the same as an unknown one.
	if storedToken.ExpiresAt.Before(dbtime.Now()) {
		return oauth2.Token{}, errBadToken
	}

	// Step 2: find the owning user through the previous API key and build the
	// RBAC subject so the refresh runs as that user.
	//nolint:gocritic // There is no user yet so we must use the system.
	oldKey, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), storedToken.APIKeyID)
	if err != nil {
		return oauth2.Token{}, err
	}
	//nolint:gocritic // There is no user yet so we must use the system.
	userRoles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), oldKey.UserID)
	if err != nil {
		return oauth2.Token{}, err
	}
	actor := rbac.Subject{
		ID:     oldKey.UserID.String(),
		Roles:  rbac.RoleNames(userRoles.Roles),
		Groups: userRoles.Groups,
		Scope:  rbac.ScopeAll,
	}

	// Step 3: mint the replacement refresh token and API key.
	refresh, err := GenerateSecret()
	if err != nil {
		return oauth2.Token{}, err
	}
	// TODO: We are ignoring scopes for now.
	// For now, we allow only one token per app and user at a time, so the
	// name is derived from the user and app IDs.
	keyName := fmt.Sprintf("%s_%s_oauth_session_token", oldKey.UserID, app.ID)
	keyParams, sessionToken, err := apikey.Generate(apikey.CreateParams{
		UserID:    oldKey.UserID,
		LoginType: database.LoginTypeOAuth2ProviderApp,
		// TODO: This is just the lifetime for api keys, maybe have its own config
		// settings. #11693
		DefaultLifetime: defaultLifetime,
		TokenName:       keyName,
	})
	if err != nil {
		return oauth2.Token{}, err
	}

	// Step 4: replace the token atomically.
	err = db.InTx(func(tx database.Store) error {
		userCtx := dbauthz.As(ctx, actor)

		err := tx.DeleteAPIKeyByID(userCtx, oldKey.ID) // This cascades to the token.
		if err != nil {
			return xerrors.Errorf("delete oauth2 app token: %w", err)
		}

		insertedKey, err := tx.InsertAPIKey(userCtx, keyParams)
		if err != nil {
			return xerrors.Errorf("insert oauth2 access token: %w", err)
		}

		// The new refresh token record shares the new API key's expiry and
		// stays tied to the same app secret as the token it replaces.
		_, err = tx.InsertOAuth2ProviderAppToken(userCtx, database.InsertOAuth2ProviderAppTokenParams{
			ID:          uuid.New(),
			CreatedAt:   dbtime.Now(),
			ExpiresAt:   keyParams.ExpiresAt,
			HashPrefix:  []byte(refresh.Prefix),
			RefreshHash: []byte(refresh.Hashed),
			AppSecretID: storedToken.AppSecretID,
			APIKeyID:    insertedKey.ID,
		})
		if err != nil {
			return xerrors.Errorf("insert oauth2 refresh token: %w", err)
		}
		return nil
	}, nil)
	if err != nil {
		return oauth2.Token{}, err
	}

	issued := oauth2.Token{
		AccessToken:  sessionToken,
		TokenType:    "Bearer",
		RefreshToken: refresh.Formatted,
		Expiry:       keyParams.ExpiresAt,
	}
	return issued, nil
}