mirror of https://github.com/coder/coder.git
278 lines
8.4 KiB
Go
278 lines
8.4 KiB
Go
package oauthpki
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rsa"
|
|
"crypto/sha1" //#nosec // Not used for cryptography.
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/jws"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/promoauth"
|
|
)
|
|
|
|
// Config uses jwt assertions over client_secret for oauth2 authentication of
|
|
// the application. This implementation was made specifically for Azure AD.
|
|
//
|
|
// https://learn.microsoft.com/en-us/azure/active-directory/develop/certificate-credentials
|
|
//
|
|
// However this does mostly follow the standard. We can generalize this as we
|
|
// include support for more IDPs.
|
|
//
|
|
// https://datatracker.ietf.org/doc/html/rfc7523
|
|
type Config struct {
|
|
cfg promoauth.OAuth2Config
|
|
|
|
// These values should match those provided in the oauth2.Config.
|
|
// Because the inner config is an interface, we need to duplicate these
|
|
// values here.
|
|
scopes []string
|
|
clientID string
|
|
tokenURL string
|
|
|
|
// ClientSecret is the private key of the PKI cert.
|
|
// Azure AD only supports RS256 signing algorithm.
|
|
clientKey *rsa.PrivateKey
|
|
// Base64url-encoded SHA-1 thumbprint of the X.509 certificate's DER encoding.
|
|
// This is specific to Azure AD
|
|
x5t string
|
|
}
|
|
|
|
type ConfigParams struct {
|
|
ClientID string
|
|
TokenURL string
|
|
Scopes []string
|
|
PemEncodedKey []byte
|
|
PemEncodedCert []byte
|
|
|
|
Config promoauth.OAuth2Config
|
|
}
|
|
|
|
// NewOauth2PKIConfig creates the oauth2 config for PKI based auth. It requires the certificate and it's private key.
|
|
// The values should be passed in as PEM encoded values, which is the standard encoding for x509 certs saved to disk.
|
|
// It should look like:
|
|
//
|
|
// -----BEGIN RSA PRIVATE KEY----
|
|
// ...
|
|
// -----END RSA PRIVATE KEY-----
|
|
//
|
|
// -----BEGIN CERTIFICATE-----
|
|
// ...
|
|
// -----END CERTIFICATE-----
|
|
func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
|
|
if params.ClientID == "" {
|
|
return nil, xerrors.Errorf("")
|
|
}
|
|
if len(params.Scopes) == 0 {
|
|
return nil, xerrors.Errorf("scopes are required")
|
|
}
|
|
|
|
rsaKey, err := decodeClientKey(params.PemEncodedKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Azure AD requires a certificate. The sha1 of the cert is used to identify the signer.
|
|
// This is not required in the general specification.
|
|
if strings.Contains(strings.ToLower(params.TokenURL), "microsoftonline") && len(params.PemEncodedCert) == 0 {
|
|
return nil, xerrors.Errorf("oidc client certificate is required and missing")
|
|
}
|
|
|
|
block, _ := pem.Decode(params.PemEncodedCert)
|
|
// Used as an identifier, not an actual cryptographic hash.
|
|
//nolint:gosec
|
|
hashed := sha1.Sum(block.Bytes)
|
|
|
|
return &Config{
|
|
clientID: params.ClientID,
|
|
tokenURL: params.TokenURL,
|
|
scopes: params.Scopes,
|
|
cfg: params.Config,
|
|
clientKey: rsaKey,
|
|
x5t: base64.StdEncoding.EncodeToString(hashed[:]),
|
|
}, nil
|
|
}
|
|
|
|
// decodeClientKey decodes a PEM encoded rsa secret.
|
|
func decodeClientKey(pemEncoded []byte) (*rsa.PrivateKey, error) {
|
|
block, _ := pem.Decode(pemEncoded)
|
|
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed to parse private key: %w", err)
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func (ja *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
|
return ja.cfg.AuthCodeURL(state, opts...)
|
|
}
|
|
|
|
// Exchange includes the client_assertion signed JWT.
|
|
func (ja *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
|
signed, err := ja.jwtToken()
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
|
|
}
|
|
opts = append(opts,
|
|
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
|
|
oauth2.SetAuthURLParam("client_assertion", signed),
|
|
)
|
|
return ja.cfg.Exchange(ctx, code, opts...)
|
|
}
|
|
|
|
func (ja *Config) jwtToken() (string, error) {
|
|
now := time.Now()
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
|
"iss": ja.clientID,
|
|
"sub": ja.clientID,
|
|
"aud": ja.tokenURL,
|
|
// 5-10 minutes is recommended in the Azure docs.
|
|
// So we'll use 5 minutes.
|
|
"exp": now.Add(time.Minute * 5).Unix(),
|
|
"jti": uuid.New().String(),
|
|
"nbf": now.Unix(),
|
|
"iat": now.Unix(),
|
|
})
|
|
token.Header["x5t"] = ja.x5t
|
|
|
|
signed, err := token.SignedString(ja.clientKey)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("sign jwt assertion: %w", err)
|
|
}
|
|
return signed, nil
|
|
}
|
|
|
|
func (ja *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
|
|
return oauth2.ReuseTokenSource(token, &jwtTokenSource{
|
|
cfg: ja,
|
|
ctx: ctx,
|
|
refreshToken: token.RefreshToken,
|
|
})
|
|
}
|
|
|
|
type jwtTokenSource struct {
|
|
cfg *Config
|
|
ctx context.Context
|
|
refreshToken string
|
|
}
|
|
|
|
// Token must be safe for concurrent use by multiple go routines
|
|
// Very similar to the RetrieveToken implementation by the oauth2 package.
|
|
// https://github.com/golang/oauth2/blob/master/internal/token.go#L212
|
|
// Oauth2 package keeps this code unexported or in an /internal package,
|
|
// so we have to copy the implementation :(
|
|
func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
|
|
if src.refreshToken == "" {
|
|
return nil, xerrors.New("oauth2: token expired and refresh token is not set")
|
|
}
|
|
cli := http.DefaultClient
|
|
if v, ok := src.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
|
// This client should be the instrumented client already. So no need to
|
|
// handle this manually.
|
|
cli = v
|
|
}
|
|
|
|
token, err := src.cfg.jwtToken()
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
|
|
}
|
|
|
|
v := url.Values{
|
|
"client_assertion": {token},
|
|
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
|
|
"client_id": {src.cfg.clientID},
|
|
"grant_type": {"refresh_token"},
|
|
"scope": {strings.Join(src.cfg.scopes, " ")},
|
|
"refresh_token": {src.refreshToken},
|
|
}
|
|
// Using params based auth
|
|
req, err := http.NewRequest("POST", src.cfg.tokenURL, strings.NewReader(v.Encode()))
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("oauth2: make token refresh request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req = req.WithContext(src.ctx)
|
|
resp, err := cli.Do(req)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("oauth2: cannot get token: %w", err)
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("oauth2: cannot fetch token reading response body: %w", err)
|
|
}
|
|
|
|
var tokenRes struct {
|
|
AccessToken string `json:"access_token"`
|
|
TokenType string `json:"token_type,omitempty"`
|
|
RefreshToken string `json:"refresh_token,omitempty"`
|
|
|
|
// Extra fields returned by the refresh that are needed
|
|
IDToken string `json:"id_token"`
|
|
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
|
// error fields
|
|
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
|
ErrorCode string `json:"error"`
|
|
ErrorDescription string `json:"error_description"`
|
|
ErrorURI string `json:"error_uri"`
|
|
}
|
|
|
|
unmarshalError := json.Unmarshal(body, &tokenRes)
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
|
// Return a standard oauth2 error. Attempt to read some error fields. The error fields
|
|
// can be encoded in a few places, so this does not catch all of them.
|
|
return nil, &oauth2.RetrieveError{
|
|
Response: resp,
|
|
Body: body,
|
|
// Best effort for error fields
|
|
ErrorCode: tokenRes.ErrorCode,
|
|
ErrorDescription: tokenRes.ErrorDescription,
|
|
ErrorURI: tokenRes.ErrorURI,
|
|
}
|
|
}
|
|
|
|
if unmarshalError != nil {
|
|
return nil, xerrors.Errorf("oauth2: cannot unmarshal token: %w", err)
|
|
}
|
|
|
|
newToken := &oauth2.Token{
|
|
AccessToken: tokenRes.AccessToken,
|
|
TokenType: tokenRes.TokenType,
|
|
RefreshToken: tokenRes.RefreshToken,
|
|
}
|
|
|
|
if secs := tokenRes.ExpiresIn; secs > 0 {
|
|
newToken.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
|
}
|
|
|
|
// ID token is a JWT token. We can decode it to get the expiry.
|
|
// Not really sure what to do if the ExpiresIn and JWT expiry differ,
|
|
// but this one is attached in the JWT and guaranteed to be right for local
|
|
// validation. So use this one if found.
|
|
if v := tokenRes.IDToken; v != "" {
|
|
// decode returned id token to get expiry
|
|
claimSet, err := jws.Decode(v)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("oauth2: error decoding JWT token: %w", err)
|
|
}
|
|
newToken.Expiry = time.Unix(claimSet.Exp, 0)
|
|
}
|
|
|
|
return newToken, nil
|
|
}
|