// coder/coderd/oauthpki/oidcpki.go
//
// 278 lines
// 8.4 KiB
// Go

package oauthpki
import (
"context"
"crypto/rsa"
"crypto/sha1" //#nosec // Not used for cryptography.
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"golang.org/x/oauth2"
"golang.org/x/oauth2/jws"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/promoauth"
)
// Config uses jwt assertions over client_secret for oauth2 authentication of
// the application. This implementation was made specifically for Azure AD.
//
// https://learn.microsoft.com/en-us/azure/active-directory/develop/certificate-credentials
//
// However this does mostly follow the standard. We can generalize this as we
// include support for more IDPs.
//
// https://datatracker.ietf.org/doc/html/rfc7523
type Config struct {
	// cfg is the wrapped OAuth2 config; AuthCodeURL and Exchange delegate to it.
	cfg promoauth.OAuth2Config

	// These values should match those provided in the oauth2.Config.
	// Because the inner config is an interface, we need to duplicate these
	// values here.
	scopes   []string
	clientID string
	tokenURL string

	// clientKey is the private key of the PKI cert. It takes the place of a
	// client secret: it signs the RS256 client assertion JWT.
	// Azure AD only supports the RS256 signing algorithm.
	clientKey *rsa.PrivateKey

	// x5t is the SHA-1 thumbprint of the X.509 certificate's DER encoding,
	// sent as the "x5t" JWT header. This is specific to Azure AD.
	// NOTE(review): it is encoded with standard base64, whereas RFC 7515
	// describes base64url; confirm against the IDP before changing the
	// encoding, since existing deployments rely on the current value.
	x5t string
}
// ConfigParams holds the inputs accepted by NewOauth2PKIConfig.
type ConfigParams struct {
	// ClientID is the OAuth2 client ID. It is used as the issuer and subject
	// of the signed assertion and sent as "client_id" on refresh. Required.
	ClientID string
	// TokenURL is the IDP token endpoint. It is the audience of the signed
	// assertion and the URL that refresh requests are POSTed to.
	TokenURL string
	// Scopes are sent with refresh requests. At least one is required.
	Scopes []string
	// PemEncodedKey is the PEM encoded RSA private key that signs assertions.
	PemEncodedKey []byte
	// PemEncodedCert is the PEM encoded certificate for PemEncodedKey. Its
	// SHA-1 thumbprint becomes the "x5t" JWT header. It is mandatory when
	// TokenURL contains "microsoftonline" (Azure AD).
	PemEncodedCert []byte
	// Config is the wrapped OAuth2 config used by AuthCodeURL and Exchange.
	// Its client ID, token URL and scopes should match the fields above.
	Config promoauth.OAuth2Config
}
// NewOauth2PKIConfig creates the oauth2 config for PKI based auth. It requires
// the private key and, for Azure AD, the matching certificate.
// The values should be passed in as PEM encoded values, which is the standard
// encoding for x509 certs saved to disk. They should look like:
//
//	-----BEGIN RSA PRIVATE KEY-----
//	...
//	-----END RSA PRIVATE KEY-----
//
//	-----BEGIN CERTIFICATE-----
//	...
//	-----END CERTIFICATE-----
//
// ClientID and at least one scope are required. The certificate is mandatory
// when TokenURL points at Azure AD ("microsoftonline"). For other providers it
// is optional; when omitted, no "x5t" thumbprint is computed.
func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
	if params.ClientID == "" {
		return nil, xerrors.Errorf("client id is required")
	}
	if len(params.Scopes) == 0 {
		return nil, xerrors.Errorf("scopes are required")
	}

	rsaKey, err := decodeClientKey(params.PemEncodedKey)
	if err != nil {
		return nil, err
	}

	// Azure AD requires a certificate. The sha1 of the cert is used to identify the signer.
	// This is not required in the general specification.
	if strings.Contains(strings.ToLower(params.TokenURL), "microsoftonline") && len(params.PemEncodedCert) == 0 {
		return nil, xerrors.Errorf("oidc client certificate is required and missing")
	}

	var x5t string
	if len(params.PemEncodedCert) > 0 {
		block, _ := pem.Decode(params.PemEncodedCert)
		// pem.Decode returns a nil block when no PEM data is found; report it
		// instead of dereferencing nil.
		if block == nil {
			return nil, xerrors.Errorf("decode client certificate: no PEM data found")
		}
		// Used as an identifier, not an actual cryptographic hash.
		//nolint:gosec
		hashed := sha1.Sum(block.Bytes)
		x5t = base64.StdEncoding.EncodeToString(hashed[:])
	}

	return &Config{
		clientID:  params.ClientID,
		tokenURL:  params.TokenURL,
		scopes:    params.Scopes,
		cfg:       params.Config,
		clientKey: rsaKey,
		x5t:       x5t,
	}, nil
}
// decodeClientKey decodes a PEM encoded RSA private key.
//
// Both PKCS#1 ("RSA PRIVATE KEY") and PKCS#8 ("PRIVATE KEY") encodings are
// accepted. A PKCS#8 key must wrap an RSA key, because assertions are signed
// with RS256. It returns an error, rather than panicking, when the input holds
// no PEM block.
func decodeClientKey(pemEncoded []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(pemEncoded)
	if block == nil {
		return nil, xerrors.Errorf("failed to parse private key: no PEM data found")
	}

	key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err == nil {
		return key, nil
	}

	// Fall back to PKCS#8, the format many current key-generation tools emit
	// by default.
	parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if pkcs8Err != nil {
		// Keep reporting the PKCS#1 error, as before, for keys in neither format.
		return nil, xerrors.Errorf("failed to parse private key: %w", err)
	}
	rsaKey, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		return nil, xerrors.Errorf("failed to parse private key: expected an RSA key, got %T", parsed)
	}
	return rsaKey, nil
}
// AuthCodeURL returns the authorization URL produced by the wrapped config.
// No client assertion is involved here; the JWT is only attached on calls
// to the token endpoint.
func (ja *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
	inner := ja.cfg
	return inner.AuthCodeURL(state, opts...)
}
// Exchange trades an authorization code for a token via the wrapped config.
// In place of a client secret, it authenticates the client with a freshly
// signed JWT. The JWT is sent in the client_assertion and
// client_assertion_type parameters (RFC 7523).
func (ja *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
	signed, err := ja.jwtToken()
	if err != nil {
		return nil, xerrors.Errorf("failed jwt assertion: %w", err)
	}

	// Cap the capacity at the length so append always allocates. Otherwise it
	// could write into the backing array of a slice the caller passed as
	// `opts...`.
	opts = append(opts[:len(opts):len(opts)],
		oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
		oauth2.SetAuthURLParam("client_assertion", signed),
	)
	return ja.cfg.Exchange(ctx, code, opts...)
}
// jwtToken builds and signs the client assertion JWT that replaces a client
// secret. The token is signed with RS256 using clientKey. It names the client
// as both issuer and subject and uses the token URL as the audience. It is
// valid for 5 minutes and carries a fresh random "jti" on every call.
func (ja *Config) jwtToken() (string, error) {
	now := time.Now()
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
		"iss": ja.clientID,
		"sub": ja.clientID,
		"aud": ja.tokenURL,
		// 5-10 minutes is recommended in the Azure docs.
		// So we'll use 5 minutes.
		"exp": now.Add(time.Minute * 5).Unix(),
		"jti": uuid.New().String(),
		"nbf": now.Unix(),
		"iat": now.Unix(),
	})
	// x5t identifies the signing certificate (required by Azure AD). Only send
	// it when a thumbprint is known; an empty "x5t" is not a valid thumbprint.
	if ja.x5t != "" {
		token.Header["x5t"] = ja.x5t
	}

	signed, err := token.SignedString(ja.clientKey)
	if err != nil {
		return "", xerrors.Errorf("sign jwt assertion: %w", err)
	}
	return signed, nil
}
// TokenSource returns a TokenSource that returns token while it is valid.
// After that, it refreshes using the refresh_token grant, authenticated with a
// signed JWT assertion. ctx is kept for those refresh requests.
//
// Like oauth2.Config.TokenSource, a nil token is accepted. Refreshing then
// fails, because there is no refresh token.
func (ja *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
	src := &jwtTokenSource{
		cfg: ja,
		ctx: ctx,
	}
	if token != nil {
		src.refreshToken = token.RefreshToken
	}
	// ReuseTokenSource serializes calls to src.Token.
	return oauth2.ReuseTokenSource(token, src)
}
// jwtTokenSource refreshes tokens for a Config with the refresh_token grant,
// authenticating each request with a signed JWT assertion.
type jwtTokenSource struct {
	// cfg supplies the token URL, client ID, scopes and signing key.
	cfg *Config
	// ctx is used for refresh requests and may carry an *http.Client under
	// oauth2.HTTPClient. It is stored because oauth2.TokenSource.Token takes
	// no context argument.
	ctx context.Context
	// refreshToken is the refresh token sent with refresh requests.
	refreshToken string
}
// Token refreshes the access token with the refresh_token grant, using a
// signed JWT assertion (RFC 7523) for client authentication.
//
// Very similar to the RetrieveToken implementation by the oauth2 package.
// https://github.com/golang/oauth2/blob/master/internal/token.go#L212
// Oauth2 package keeps this code unexported or in an /internal package,
// so we have to copy the implementation :(
//
// WARNING: Token updates src.refreshToken when the server rotates the refresh
// token, so it is not safe for concurrent use on its own. TokenSource wraps it
// in oauth2.ReuseTokenSource, which serializes calls. This is the same
// arrangement the oauth2 package uses for its own refresher.
func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
	if src.refreshToken == "" {
		return nil, xerrors.New("oauth2: token expired and refresh token is not set")
	}

	cli := http.DefaultClient
	if v, ok := src.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
		// This client should be the instrumented client already. So no need to
		// handle this manually.
		cli = v
	}

	assertion, err := src.cfg.jwtToken()
	if err != nil {
		return nil, xerrors.Errorf("failed jwt assertion: %w", err)
	}

	form := url.Values{
		"client_assertion":      {assertion},
		"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
		"client_id":             {src.cfg.clientID},
		"grant_type":            {"refresh_token"},
		"scope":                 {strings.Join(src.cfg.scopes, " ")},
		"refresh_token":         {src.refreshToken},
	}

	// Using params based auth
	req, err := http.NewRequestWithContext(src.ctx, http.MethodPost, src.cfg.tokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, xerrors.Errorf("oauth2: make token refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := cli.Do(req)
	if err != nil {
		return nil, xerrors.Errorf("oauth2: cannot get token: %w", err)
	}
	defer resp.Body.Close()

	// Cap the body at 1 MiB, as the oauth2 package does, so a misbehaving
	// token endpoint cannot exhaust memory.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return nil, xerrors.Errorf("oauth2: cannot fetch token reading response body: %w", err)
	}

	var tokenRes struct {
		AccessToken  string `json:"access_token"`
		TokenType    string `json:"token_type,omitempty"`
		RefreshToken string `json:"refresh_token,omitempty"`
		// Extra fields returned by the refresh that are needed
		IDToken string `json:"id_token"`
		// Relative seconds from now. json.Number accepts both 3600 and "3600",
		// since some servers send this value as a string.
		ExpiresIn json.Number `json:"expires_in"`

		// error fields
		// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
		ErrorCode        string `json:"error"`
		ErrorDescription string `json:"error_description"`
		ErrorURI         string `json:"error_uri"`
	}

	unmarshalErr := json.Unmarshal(body, &tokenRes)
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Return a standard oauth2 error. Attempt to read some error fields. The error fields
		// can be encoded in a few places, so this does not catch all of them.
		return nil, &oauth2.RetrieveError{
			Response: resp,
			Body:     body,
			// Best effort for error fields
			ErrorCode:        tokenRes.ErrorCode,
			ErrorDescription: tokenRes.ErrorDescription,
			ErrorURI:         tokenRes.ErrorURI,
		}
	}
	if unmarshalErr != nil {
		return nil, xerrors.Errorf("oauth2: cannot unmarshal token: %w", unmarshalErr)
	}
	if tokenRes.AccessToken == "" {
		return nil, xerrors.New("oauth2: server response missing access_token")
	}

	newToken := &oauth2.Token{
		AccessToken:  tokenRes.AccessToken,
		TokenType:    tokenRes.TokenType,
		RefreshToken: tokenRes.RefreshToken,
	}
	// A refresh response may omit refresh_token, meaning the current one stays
	// valid. Don't overwrite it with an empty value (matches the oauth2 package).
	if newToken.RefreshToken == "" {
		newToken.RefreshToken = src.refreshToken
	}
	// Remember a rotated refresh token for the next refresh.
	src.refreshToken = newToken.RefreshToken

	if secs, err := tokenRes.ExpiresIn.Int64(); err == nil && secs > 0 {
		newToken.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
	}

	// ID token is a JWT token. We can decode it to get the expiry.
	// Not really sure what to do if the ExpiresIn and JWT expiry differ,
	// but this one is attached in the JWT and guaranteed to be right for local
	// validation. So use this one if found.
	if v := tokenRes.IDToken; v != "" {
		// decode returned id token to get expiry
		claimSet, err := jws.Decode(v)
		if err != nil {
			return nil, xerrors.Errorf("oauth2: error decoding JWT token: %w", err)
		}
		// Without an "exp" claim, time.Unix(0, 0) would make the token look
		// permanently expired, so keep the expires_in based expiry instead.
		if claimSet.Exp > 0 {
			newToken.Expiry = time.Unix(claimSet.Exp, 0)
		}
	}

	// Expose the raw response fields (e.g. "id_token") through Token.Extra,
	// as the oauth2 package does for JSON token responses.
	var raw map[string]any
	if err := json.Unmarshal(body, &raw); err == nil {
		newToken = newToken.WithExtra(raw)
	}
	return newToken, nil
}