coder/coderd/coderdtest/oidctest/idp.go

1443 lines
44 KiB
Go

package oidctest
import (
"context"
"crypto"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"math/rand"
"mime"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-chi/chi/v5"
"github.com/go-jose/go-jose/v3"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/coderd/util/syncmap"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// token is the metadata the IDP records for every access token it issues.
type token struct {
// issued is the time the access token was created.
issued time.Time
// email identifies the user the access token was issued for.
email string
// exp is the expiry of the access token. The bearer-token checks treat a
// zero value as "never expires".
exp time.Time
}
// deviceFlow is the state of a single in-progress device authorization,
// stored in FakeIDP.deviceCode keyed by the device code.
type deviceFlow struct {
// userInput is the expected input to authenticate the device flow.
userInput string
// exp is when the device code stops being accepted by the verify endpoint.
exp time.Time
// granted is set once the verify endpoint has accepted the user input; the
// token endpoint answers "authorization_pending" until then.
granted bool
}
// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
// issuer is the issuer URL as a string and issuerURL is its parsed form.
// Both are written together by updateIssuerURL.
issuer string
issuerURL *url.URL
// key signs ID tokens; its public half is served on the keys endpoint.
key *rsa.PrivateKey
// provider is the discovery document served at
// /.well-known/openid-configuration.
provider ProviderJSON
// handler serves every IDP endpoint. It is built once in NewFakeIDP.
handler http.Handler
// cfg is the oauth2 client config the helper methods use to build authorize
// URLs and exchange codes. NewFakeIDP does not populate it.
cfg *oauth2.Config
// callbackPath allows changing where the callback path to coderd is expected.
// This only affects using the Login helper functions.
callbackPath string
// clientID to be used by coderd
clientID string
// clientSecret is the secret the token endpoint expects alongside clientID.
clientSecret string
// externalProviderID is optional to match the provider in coderd for
// redirectURLs.
externalProviderID string
// logger receives a log line for most HTTP calls made to the IDP.
logger slog.Logger
// externalAuthValidate will be called when the user tries to validate their
// external auth. The fake IDP will reject any invalid tokens, so this just
// controls the response payload after a successfully authed token.
externalAuthValidate func(email string, rw http.ResponseWriter, r *http.Request)
// These maps are used to control the state of the IDP.
// That is the various access tokens, refresh tokens, states, etc.
// codeToStateMap: authorization code -> state of the flow that created it.
codeToStateMap *syncmap.Map[string, string]
// accessTokens: access token -> token metadata (issue time, email, expiry).
accessTokens *syncmap.Map[string, token]
// refreshTokensUsed: refresh token -> true once it has been redeemed.
refreshTokensUsed *syncmap.Map[string, bool]
// refreshTokens: refresh token -> email. Entries are removed when redeemed.
refreshTokens *syncmap.Map[string, string]
// stateToIDTokenClaims: state -> ID token claims to issue for that flow.
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
// refreshIDTokenClaims: refresh token -> claims to reuse on the next refresh.
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
// Device flow
// deviceCode: device code -> state of that device authorization.
deviceCode *syncmap.Map[string, deviceFlow]
// hooks
// hookValidRedirectURL can be used to reject a redirect url from the
// IDP -> Application. Almost all IDPs have the concept of
// "Authorized Redirect URLs". This can be used to emulate that.
hookValidRedirectURL func(redirectURL string) error
// hookUserInfo produces the payload of the userinfo endpoint for the user
// identified by email.
hookUserInfo func(email string) (jwt.MapClaims, error)
// defaultIDClaims is if a new client connects and we didn't preset
// some claims.
defaultIDClaims jwt.MapClaims
// hookMutateToken, when set, may edit the token response fields right
// before they are written to the client.
hookMutateToken func(token map[string]interface{})
// fakeCoderd, when set, answers in-memory client requests whose host is not
// the IDP's issuer host.
fakeCoderd func(req *http.Request) (*http.Response, error)
// hookOnRefresh runs on every refresh_token grant; an error rejects it.
hookOnRefresh func(email string) error
// Custom authentication for the client. This is useful if you want
// to test something like PKI auth vs a client_secret.
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
// serve makes NewFakeIDP start a real HTTP server for the IDP.
serve bool
// optional middlewares
middlewares chi.Middlewares
// defaultExpire is the lifetime given to issued tokens (5 minutes unless
// overridden).
defaultExpire time.Duration
}
// StatusError wraps err together with an HTTP status code. Hooks can return
// it to change the status code the IDP responds with.
func StatusError(code int, err error) error {
	wrapped := statusHookError{HTTPStatusCode: code}
	wrapped.Err = err
	return wrapped
}

// statusHookError allows a hook to change the returned http status code.
type statusHookError struct {
	// Err is the underlying error; it may be nil.
	Err error
	// HTTPStatusCode is the status code requested by the hook.
	HTTPStatusCode int
}

// Error returns the message of the wrapped error, or the empty string when no
// error is wrapped.
func (s statusHookError) Error() string {
	if s.Err != nil {
		return s.Err.Error()
	}
	return ""
}
// FakeIDPOpt is a functional option applied to a FakeIDP by NewFakeIDP before
// the HTTP handler is built.
type FakeIDPOpt func(idp *FakeIDP)
// WithAuthorizedRedirectURL sets the hook the authorize endpoint consults for
// every redirect_uri. A non-nil error from hook rejects the request, which
// emulates an IDP's list of "Authorized Redirect URLs".
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.hookValidRedirectURL = hook
	}
	return apply
}
// WithMiddlewares appends mws to the middlewares installed on the IDP router.
// Middlewares from earlier options stay in front of the ones added here.
func WithMiddlewares(mws ...func(http.Handler) http.Handler) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		chain := idp.middlewares
		idp.middlewares = append(chain, mws...)
	}
	return apply
}
// WithRefresh registers a hook that runs whenever a refresh token is redeemed.
// The hook receives the email found in the claims stored for that refresh
// token; returning an error rejects the refresh.
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.hookOnRefresh = hook
	}
	return apply
}
// WithDefaultExpire overrides the lifetime the IDP gives to the tokens it
// issues.
func WithDefaultExpire(d time.Duration) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.defaultExpire = d
	}
	return apply
}
// WithCallbackPath changes the coderd callback path the Login helpers start
// the flow at.
func WithCallbackPath(path string) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.callbackPath = path
	}
	return apply
}
// WithStaticCredentials pins the client ID and/or client secret of the IDP.
// An empty id or secret leaves the corresponding generated value untouched.
func WithStaticCredentials(id, secret string) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		if len(id) > 0 {
			idp.clientID = id
		}
		if len(secret) > 0 {
			idp.clientSecret = secret
		}
	}
	return apply
}
// WithMutateToken registers a callback that receives the token response
// fields just before they are sent to the client. The callback may add extra
// fields or override the default ones (id_token, access_token, etc).
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.hookMutateToken = mutateToken
	}
	return apply
}
// WithCustomClientAuth replaces the default client_id/client_secret check on
// the token endpoint with hook. The hook returns the parsed request values, or
// an error to reject the request.
func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.hookAuthenticateClient = hook
	}
	return apply
}
// WithLogging is optional. It makes the IDP log some of the HTTP calls it
// receives through a test logger built from t and options.
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.logger = slogtest.Make(t, options)
	}
	return apply
}
// WithLogger makes the IDP log through the supplied logger.
func WithLogger(logger slog.Logger) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.logger = logger
	}
	return apply
}
// WithStaticUserInfo is optional. It makes the /userinfo endpoint answer with
// the same info for every user.
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
	// The email is ignored: every user gets the same claims.
	static := func(string) (jwt.MapClaims, error) {
		return info, nil
	}
	return func(idp *FakeIDP) {
		idp.hookUserInfo = static
	}
}
// WithDefaultIDClaims sets the ID token claims used when no claims were
// preset for a flow.
func WithDefaultIDClaims(claims jwt.MapClaims) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.defaultIDClaims = claims
	}
	return apply
}
// WithDynamicUserInfo makes the /userinfo endpoint compute its payload by
// calling userInfoFunc with the email tied to the presented access token.
func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.hookUserInfo = userInfoFunc
	}
	return apply
}
// WithServing makes the IDP run an actual http server instead of only
// answering requests in memory.
func WithServing() func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.serve = true
	}
	return apply
}
// WithIssuer sets the issuer URL of the IDP. When it is left empty,
// NewFakeIDP falls back to "https://coder.com".
func WithIssuer(issuer string) func(*FakeIDP) {
	apply := func(idp *FakeIDP) {
		idp.issuer = issuer
	}
	return apply
}
// With429Arguments selects which IDP endpoints With429 answers with
// "429 Too Many Requests".
type With429Arguments struct {
// AllPaths blocks every request, regardless of the other fields.
AllPaths bool
// TokenPath blocks the token endpoint.
TokenPath bool
// AuthorizePath blocks the authorize endpoint.
AuthorizePath bool
// KeysPath blocks the JWKS endpoint.
KeysPath bool
// UserInfoPath blocks the userinfo endpoint.
UserInfoPath bool
// DeviceAuth blocks the device authorization (device code) endpoint.
DeviceAuth bool
// DeviceVerify blocks the device verification endpoint.
DeviceVerify bool
}
// With429 will emulate a 429 response for the selected paths.
//
// The returned option installs a middleware that answers with
// http.StatusTooManyRequests for every request whose path matches one of the
// endpoints enabled in params, and passes all other requests through.
func With429(params With429Arguments) func(*FakeIDP) {
	// blockedBy returns the label of the first rule that blocks path, or the
	// empty string when the request should pass through.
	blockedBy := func(path string) string {
		switch {
		case params.AllPaths:
			return "all"
		case params.TokenPath && strings.Contains(path, tokenPath):
			return "token"
		case params.AuthorizePath && strings.Contains(path, authorizePath):
			return "authorize"
		case params.KeysPath && strings.Contains(path, keysPath):
			return "keys"
		case params.UserInfoPath && strings.Contains(path, userInfoPath):
			return "userinfo"
		case params.DeviceAuth && strings.Contains(path, deviceAuth):
			return "device-auth"
		// The verify path is a prefix of the device-auth path, so the
		// device-auth path must be excluded explicitly. Otherwise blocking
		// only "verify" would block device authorization as well.
		case params.DeviceVerify && strings.Contains(path, deviceVerify) && !strings.Contains(path, deviceAuth):
			return "device-verify"
		}
		return ""
	}

	return func(f *FakeIDP) {
		f.middlewares = append(f.middlewares, func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
				if label := blockedBy(r.URL.Path); label != "" {
					http.Error(rw, fmt.Sprintf("429, being manually blocked (%s)", label), http.StatusTooManyRequests)
					return
				}
				next.ServeHTTP(rw, r)
			})
		})
	}
}
// URL paths of the endpoints served by the fake IDP.
const (
// tokenPath is the token endpoint.
// nolint:gosec // It thinks this is a secret lol
tokenPath = "/oauth2/token"
// authorizePath is the authorization endpoint.
authorizePath = "/oauth2/authorize"
// keysPath is the JWKS endpoint.
keysPath = "/oauth2/keys"
// userInfoPath is the userinfo endpoint.
userInfoPath = "/oauth2/userinfo"
// deviceAuth is the device authorization endpoint that hands out device codes.
deviceAuth = "/login/device/code"
// deviceVerify is the endpoint that approves a device code. Note that it is
// a prefix of deviceAuth.
deviceVerify = "/login/device"
)
// NewFakeIDP builds a FakeIDP with a fixed test signing key, random client
// credentials, empty state maps and permissive default hooks, then applies
// opts in order.
//
// After the options ran, the issuer defaults to "https://coder.com" when
// unset, the HTTP handler is built (so it sees all configured middlewares),
// and, if WithServing was used, a real HTTP server is started. The test fails
// immediately when the embedded key or the issuer URL cannot be parsed.
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
	t.Helper()

	block, _ := pem.Decode([]byte(testRSAPrivateKey))
	// pem.Decode returns a nil block when no PEM data is found. Fail with a
	// clear message instead of a nil pointer dereference below.
	require.NotNil(t, block, "test RSA private key is not valid PEM")
	pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	require.NoError(t, err)

	idp := &FakeIDP{
		key:                  pkey,
		clientID:             uuid.NewString(),
		clientSecret:         uuid.NewString(),
		logger:               slog.Make(),
		codeToStateMap:       syncmap.New[string, string](),
		accessTokens:         syncmap.New[string, token](),
		refreshTokens:        syncmap.New[string, string](),
		refreshTokensUsed:    syncmap.New[string, bool](),
		stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
		refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
		deviceCode:           syncmap.New[string, deviceFlow](),
		// Default hooks accept everything and return empty user info.
		hookOnRefresh:        func(_ string) error { return nil },
		hookUserInfo:         func(_ string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
		hookValidRedirectURL: func(_ string) error { return nil },
		defaultExpire:        time.Minute * 5,
	}

	for _, opt := range opts {
		opt(idp)
	}

	if idp.issuer == "" {
		idp.issuer = "https://coder.com"
	}

	// The handler must be built after the options so that middlewares added
	// by them are installed on the router.
	idp.handler = idp.httpHandler(t)
	idp.updateIssuerURL(t, idp.issuer)
	if idp.serve {
		// realServer replaces the issuer with the URL of the started server.
		idp.realServer(t)
	}

	return idp
}
// WellknownConfig returns a copy of the discovery document the IDP serves at
// /.well-known/openid-configuration.
func (f *FakeIDP) WellknownConfig() ProviderJSON {
	discovery := f.provider
	return discovery
}
// IssuerURL returns the parsed issuer URL of the IDP. The returned pointer is
// the IDP's own value, not a copy.
func (f *FakeIDP) IssuerURL() *url.URL {
	issuer := f.issuerURL
	return issuer
}
// updateIssuerURL sets the issuer of the IDP and rebuilds the discovery
// document so that every advertised endpoint is rooted at the new issuer.
// The test fails immediately when issuer is not a parseable URL.
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
	t.Helper()

	base, err := url.Parse(issuer)
	require.NoError(t, err, "invalid issuer URL")

	// endpoint resolves an absolute path against the issuer URL.
	endpoint := func(path string) string {
		ref := url.URL{Path: path}
		return base.ResolveReference(&ref).String()
	}

	f.issuer, f.issuerURL = issuer, base

	// ProviderJSON is the JSON representation of the OpenID Connect provider.
	// These are all the urls that the IDP will respond to.
	f.provider = ProviderJSON{
		Issuer:          issuer,
		AuthURL:         endpoint(authorizePath),
		TokenURL:        endpoint(tokenPath),
		JWKSURL:         endpoint(keysPath),
		UserInfoURL:     endpoint(userInfoPath),
		DeviceCodeURL:   endpoint(deviceAuth),
		Algorithms:      []string{"RS256"},
		ExternalAuthURL: endpoint("/external-auth-validate/user"),
	}
}
// realServer turns the FakeIDP into a real http server.
//
// The server listens on a random localhost port unless the configured issuer
// already points at localhost or 127.0.0.1, in which case that host:port is
// used. Once started, the issuer is replaced by the server's URL. Shutdown is
// registered with t.Cleanup.
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
	t.Helper()

	listenAddr := "localhost:0"
	if parsed, parseErr := url.Parse(f.issuer); parseErr == nil {
		switch parsed.Hostname() {
		case "localhost", "127.0.0.1":
			listenAddr = parsed.Host
		}
	}

	listener, err := net.Listen("tcp", listenAddr)
	require.NoError(t, err, "failed to create listener")

	baseCtx, stop := context.WithCancel(context.Background())
	server := &httptest.Server{
		Listener: listener,
		Config: &http.Server{
			Handler:           f.handler,
			ReadHeaderTimeout: 5 * time.Second,
			// Every request context derives from baseCtx so that cancelling
			// it on cleanup cancels in-flight requests.
			BaseContext: func(net.Listener) context.Context {
				return baseCtx
			},
		},
	}
	server.Start()

	// Cleanups run in reverse registration order: cancel the base context,
	// close the server, then close any remaining client connections.
	t.Cleanup(server.CloseClientConnections)
	t.Cleanup(server.Close)
	t.Cleanup(stop)

	f.updateIssuerURL(t, server.URL)
	return server
}
// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
// valid token for some given claims.
//
// It registers claims under a fresh random state, mints an authorization code
// for that state and exchanges the code through the IDP's oauth2 config. An
// error is returned when the oauth2 config has not been set on the IDP yet or
// when the exchange fails.
func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) {
	// NewFakeIDP does not populate cfg; report that as an error rather than
	// panicking on a nil config.
	if f.cfg == nil {
		return nil, xerrors.New("oauth2 config is not set on the fake IDP")
	}

	state := uuid.NewString()
	f.stateToIDTokenClaims.Store(state, claims)
	code := f.newCode(state)

	ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil))
	return f.cfg.Exchange(ctx, code)
}
// Login does the full OIDC flow starting at the "LoginButton" and requires it
// to end with a 200 response.
//
// The client passed in is just to get the url of the Coder instance.
// The actual client that is used is 100% unauthenticated and fresh.
// On a non-200 response the response is dumped to the test log before the
// test is failed.
func (f *FakeIDP) Login(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
	t.Helper()

	authed, resp := f.AttemptLogin(t, client, idTokenClaims, opts...)
	if resp.StatusCode != http.StatusOK {
		// Best effort: the dump only helps debugging the failure below.
		if dump, err := httputil.DumpResponse(resp, true); err == nil {
			t.Logf("Attempt Login response payload\n%s", string(dump))
		}
	}
	require.Equal(t, http.StatusOK, resp.StatusCode, "client failed to login")
	return authed, resp
}
// AttemptLogin runs the OIDC login flow with a fresh, unauthenticated client
// that targets the same URL as client. Unlike Login it does not assert on the
// response status, so it can be used to test failing logins.
func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
	t.Helper()

	// Work on a shallow copy so that attaching a cookie jar does not modify
	// the http client the copy was taken from.
	httpCli := *f.HTTPClient(client.HTTPClient)
	if httpCli.Jar == nil {
		jar, err := cookiejar.New(nil)
		require.NoError(t, err, "failed to create cookie jar")
		httpCli.Jar = jar
	}

	fresh := codersdk.New(client.URL)
	fresh.HTTPClient = &httpCli

	return f.LoginWithClient(t, fresh, idTokenClaims, opts...)
}
// LoginWithClient reuses the context of the passed in client. This means the same
// cookies will be used. This should be an unauthenticated client in most cases.
//
// This is a niche case, but it is needed for testing ConvertLoginType.
//
// The returned client is nil when the flow did not leave a session token
// cookie in the jar. The response body is closed on test cleanup.
func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
	t.Helper()

	callback := f.callbackPath
	if callback == "" {
		callback = "/api/v2/users/oidc/callback"
	}
	coderOauthURL, err := client.URL.Parse(callback)
	require.NoError(t, err)
	f.SetRedirect(t, coderOauthURL.String())

	httpCli := f.HTTPClient(client.HTTPClient)
	httpCli.CheckRedirect = func(req *http.Request, _ []*http.Request) error {
		// Bind the claims to the state of the request being followed so they
		// are tied 1:1 to this authentication flow.
		f.stateToIDTokenClaims.Store(req.URL.Query().Get("state"), idTokenClaims)
		return nil
	}

	req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil)
	require.NoError(t, err)
	if httpCli.Jar == nil {
		httpCli.Jar, err = cookiejar.New(nil)
		require.NoError(t, err, "failed to create cookie jar")
	}

	for i := range opts {
		opts[i](req)
	}

	res, err := httpCli.Do(req)
	require.NoError(t, err)

	// If the coder session token exists, return the new authed client!
	var user *codersdk.Client
	for _, c := range httpCli.Jar.Cookies(client.URL) {
		if c.Name != codersdk.SessionTokenCookie {
			continue
		}
		user = codersdk.New(client.URL)
		user.SetSessionToken(c.Value)
	}

	t.Cleanup(func() {
		if res.Body != nil {
			_ = res.Body.Close()
		}
	})
	return user, res
}
// ExternalLogin does the oauth2 flow for external auth providers. This requires
// an authenticated coder client.
//
// The flow starts at coderd's external auth callback for the IDP's external
// provider ID and must finish with a 200 response, otherwise the test fails.
// opts may modify the initial request before it is sent.
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) {
	t.Helper()

	coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID))
	require.NoError(t, err)
	f.SetRedirect(t, coderOauthURL.String())

	cli := f.HTTPClient(client.HTTPClient)
	cli.CheckRedirect = func(req *http.Request, _ []*http.Request) error {
		// Register (empty) claims for the state of this specific flow so the
		// IDP has claims on record when the code is exchanged.
		state := req.URL.Query().Get("state")
		f.stateToIDTokenClaims.Store(state, jwt.MapClaims{})
		return nil
	}

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)
	req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil)
	require.NoError(t, err)

	// External auth flow requires the user be authenticated.
	headerName := client.SessionTokenHeader
	if headerName == "" {
		headerName = codersdk.SessionTokenHeader
	}
	req.Header.Set(headerName, client.SessionToken())

	if cli.Jar == nil {
		cli.Jar, err = cookiejar.New(nil)
		require.NoError(t, err, "failed to create cookie jar")
	}

	for _, opt := range opts {
		opt(req)
	}

	res, err := cli.Do(req)
	require.NoError(t, err)
	// Deferred so the body is also closed when the assertion below aborts the
	// test.
	defer func() { _ = res.Body.Close() }()

	require.Equal(t, http.StatusOK, res.StatusCode, "client failed to login")
}
// DeviceLogin does the oauth2 device flow for external auth providers.
//
// It asks coderd to start the device flow for externalAuthID, approves the
// device by calling the returned verification URI, and finally exchanges the
// device code through coderd. Any failing step fails the test.
func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) {
	t.Helper()

	// First we need to initiate the device flow. This will have Coder hit the
	// fake IDP and get a device code.
	device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID)
	require.NoError(t, err)

	// Now the user needs to go to the fake IDP page and click "allow" and enter
	// the device code input. For our purposes, we just send an http request to
	// the verification url. No additional user input is needed.
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil)
	require.NoError(t, err)
	// The close error carries no information for the test; drop it explicitly.
	defer func() { _ = resp.Body.Close() }()

	// Now we need to exchange the device code for an access token. We do this
	// in this method because it is the user that does the polling for the device
	// auth flow, not the backend.
	err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{
		DeviceCode: device.DeviceCode,
	})
	require.NoError(t, err)
}
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
// unit tests, it's easier to skip this step sometimes. It does make an actual
// request to the IDP, so it should be equivalent to doing this "manually" with
// actual requests.
//
// The request is served in memory by the IDP handler. The returned value is
// the authorization code issued for state. The test fails when the oauth2
// config is not set or no code could be obtained.
func (f *FakeIDP) CreateAuthCode(t testing.TB, state string) string {
	t.Helper()

	// NewFakeIDP does not populate cfg; fail with a clear message instead of
	// a nil pointer dereference when building the authorize URL.
	require.NotNil(t, f.cfg, "oauth2 config must be set before creating an auth code")

	// We need to store some claims, because this is also an OIDC provider, and
	// it expects some claims to be present.
	f.stateToIDTokenClaims.Store(state, jwt.MapClaims{})

	// serveInMemory answers the authorize request straight from the handler.
	serveInMemory := func(req *http.Request) (*http.Response, error) {
		rec := httptest.NewRecorder()
		f.handler.ServeHTTP(rec, req)
		return rec.Result(), nil
	}

	code, err := OAuth2GetCode(f.cfg.AuthCodeURL(state), serveInMemory)
	require.NoError(t, err, "failed to get auth code")
	return code
}
// OIDCCallback will emulate the IDP redirecting back to the Coder callback.
// This is helpful if no Coderd exists because the IDP needs to redirect to
// something.
// Essentially this is used to fake the Coderd side of the exchange.
// The flow starts at the user hitting the OIDC login page.
//
// It panics when the IDP was created with WithServing, since it only works
// with the in-memory client. The response body is closed on test cleanup and
// the returned error is always nil; request failures fail the test instead.
func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) {
	t.Helper()
	if f.serve {
		panic("cannot use OIDCCallback with WithServing. This is only for the in memory usage")
	}

	f.stateToIDTokenClaims.Store(state, idTokenClaims)

	inMemory := f.HTTPClient(nil)
	authURL := f.cfg.AuthCodeURL(state)
	req, err := http.NewRequestWithContext(context.Background(), "GET", authURL, nil)
	require.NoError(t, err)

	resp, err := inMemory.Do(req)
	require.NoError(t, err)

	t.Cleanup(func() {
		if body := resp.Body; body != nil {
			_ = body.Close()
		}
	})
	return resp, nil
}
// ProviderJSON is the JSON document served at
// /.well-known/openid-configuration. Every URL in it is derived from the
// issuer.
type ProviderJSON struct {
// Issuer is the issuer URL of the IDP.
Issuer string `json:"issuer"`
// AuthURL is the authorization endpoint.
AuthURL string `json:"authorization_endpoint"`
// TokenURL is the token endpoint.
TokenURL string `json:"token_endpoint"`
// JWKSURL is where the signing keys are published.
JWKSURL string `json:"jwks_uri"`
// UserInfoURL is the userinfo endpoint.
UserInfoURL string `json:"userinfo_endpoint"`
// DeviceCodeURL is the device authorization endpoint.
DeviceCodeURL string `json:"device_authorization_endpoint"`
// Algorithms lists the supported ID token signing algorithms.
Algorithms []string `json:"id_token_signing_alg_values_supported"`
// This is custom
// ExternalAuthURL is the fake IDP's external-auth validation endpoint.
ExternalAuthURL string `json:"external_auth_url"`
}
// newCode mints a random authorization code and remembers which state it
// belongs to. The token endpoint only accepts codes recorded here, which
// enforces that an exchanged code was actually created by the IDP.
func (f *FakeIDP) newCode(state string) string {
	issued := uuid.NewString()
	f.codeToStateMap.Store(issued, state)
	return issued
}
// newToken mints a random access token for email that expires at expires and
// records it, so that bearer-token checks only accept access tokens created
// by the IDP.
func (f *FakeIDP) newToken(email string, expires time.Time) string {
	meta := token{
		issued: time.Now(),
		email:  email,
		exp:    expires,
	}
	issued := uuid.NewString()
	f.accessTokens.Store(issued, meta)
	return issued
}
// newRefreshTokens mints a random refresh token for email and records it as
// redeemable.
func (f *FakeIDP) newRefreshTokens(email string) string {
	issued := uuid.NewString()
	f.refreshTokens.Store(issued, email)
	return issued
}
// authenticateBearerTokenRequest enforces the access token is valid.
//
// The token is read from the Authorization header with an optional "Bearer "
// prefix removed. It must have been issued by this IDP and must not be
// expired; a zero expiry never expires. On success the raw token is returned.
func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) {
	t.Helper()

	bearer := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
	stored, found := f.accessTokens.Load(bearer)
	switch {
	case !found:
		return "", xerrors.New("invalid access token")
	case !stored.exp.IsZero() && stored.exp.Before(time.Now()):
		return "", xerrors.New("access token expired")
	}
	return bearer, nil
}
// authenticateOIDCClientRequest enforces the client_id and client_secret are valid.
//
// When a custom client authentication hook is configured, the decision is
// delegated to it entirely. Otherwise the request body is read and parsed as
// form values, and both credentials must match the IDP's. Every failure is
// recorded on t as well as returned as an error. On success the parsed form
// values are returned.
func (f *FakeIDP) authenticateOIDCClientRequest(t testing.TB, req *http.Request) (url.Values, error) {
	t.Helper()

	if hook := f.hookAuthenticateClient; hook != nil {
		return hook(t, req)
	}

	body, err := io.ReadAll(req.Body)
	if ok := assert.NoError(t, err, "read token request body"); !ok {
		return nil, xerrors.Errorf("authenticate request, read body: %w", err)
	}

	form, err := url.ParseQuery(string(body))
	if ok := assert.NoError(t, err, "parse token request values"); !ok {
		return nil, xerrors.New("invalid token request")
	}

	if ok := assert.Equal(t, f.clientID, form.Get("client_id"), "client_id mismatch"); !ok {
		return nil, xerrors.New("client_id mismatch")
	}
	if ok := assert.Equal(t, f.clientSecret, form.Get("client_secret"), "client_secret mismatch"); !ok {
		return nil, xerrors.New("client_secret mismatch")
	}

	return form, nil
}
// encodeClaims is a helper func to convert claims to a valid JWT.
//
// Missing "exp", "aud" and "iss" claims are filled in (one hour from now, the
// IDP's client ID and the IDP's issuer respectively). The defaults are written
// into the passed-in claims map. The result is signed with RS256 using the
// IDP's key; a signing failure fails the test.
func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string {
	t.Helper()

	// setIfAbsent fills in a claim only when the caller did not provide it.
	setIfAbsent := func(name string, value interface{}) {
		if _, present := claims[name]; present {
			return
		}
		claims[name] = value
	}

	// NOTE(review): "exp" is written in Unix milliseconds, while RFC 7519
	// defines NumericDate in seconds. Confirm this is intentional.
	setIfAbsent("exp", time.Now().Add(time.Hour).UnixMilli())
	setIfAbsent("aud", f.clientID)
	setIfAbsent("iss", f.issuer)

	unsigned := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	signed, err := unsigned.SignedString(f.key)
	require.NoError(t, err)
	return signed
}
// httpHandler is the IDP http server.
//
// It builds the router that serves every IDP endpoint: OIDC discovery,
// authorize, token, userinfo, external-auth validation, JWKS and the two
// device-flow endpoints. The middlewares configured on f at call time are
// installed first. Many validation failures are recorded on t (through
// assert / t.Errorf) in addition to being returned to the HTTP caller, and
// requests to unknown paths fail the test.
func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
t.Helper()
mux := chi.NewMux()
mux.Use(f.middlewares...)
// This endpoint is required to initialize the OIDC provider.
// It is used to get the OIDC configuration.
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http OIDC config", slog.F("url", r.URL.String()))
_ = json.NewEncoder(rw).Encode(f.provider)
})
// Authorize is called when the user is redirected to the IDP to login.
// This is the browser hitting the IDP and the user logging into Google or
// w/e and clicking "Allow". They will be redirected back to the redirect
// when this is done.
mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call authorize", slog.F("url", r.URL.String()))
clientID := r.URL.Query().Get("client_id")
if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") {
http.Error(rw, "invalid client_id", http.StatusBadRequest)
return
}
redirectURI := r.URL.Query().Get("redirect_uri")
state := r.URL.Query().Get("state")
scope := r.URL.Query().Get("scope")
// An empty scope is recorded as a test failure but does not abort the request.
assert.NotEmpty(t, scope, "scope is empty")
// Only the authorization code flow ("code") is supported.
responseType := r.URL.Query().Get("response_type")
switch responseType {
case "code":
case "token":
t.Errorf("response_type %q not supported", responseType)
http.Error(rw, "invalid response_type", http.StatusBadRequest)
return
default:
t.Errorf("unexpected response_type %q", responseType)
http.Error(rw, "invalid response_type", http.StatusBadRequest)
return
}
// Let the configured hook veto the redirect URI before it is parsed.
err := f.hookValidRedirectURL(redirectURI)
if err != nil {
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
ru, err := url.Parse(redirectURI)
if err != nil {
t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error())
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
return
}
// Send the caller back to the redirect URI with the state echoed and a
// freshly minted authorization code attached.
q := ru.Query()
q.Set("state", state)
q.Set("code", f.newCode(state))
ru.RawQuery = q.Encode()
http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect)
}))
// Token endpoint: exchanges an authorization code, a refresh token or a
// device code for a new set of tokens.
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
var values url.Values
var err error
// Device-code grants pass their parameters in the URL query and skip
// client authentication; every other grant authenticates the client and
// reads the parameters from the request body.
if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
values = r.URL.Query()
} else {
values, err = f.authenticateOIDCClientRequest(t, r)
}
// The call is logged before the error is handled, so values may be nil here.
f.logger.Info(r.Context(), "http idp call token",
slog.F("url", r.URL.String()),
slog.F("valid", err == nil),
slog.F("grant_type", values.Get("grant_type")),
slog.F("values", values.Encode()),
)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
// getEmail extracts the "email" claim, substituting a placeholder when the
// claim is missing or is not a string.
getEmail := func(claims jwt.MapClaims) string {
email, ok := claims["email"]
if !ok {
return "unknown"
}
emailStr, ok := email.(string)
if !ok {
return "wrong-type"
}
return emailStr
}
// claims is filled in by the grant-specific branch below.
var claims jwt.MapClaims
switch values.Get("grant_type") {
case "authorization_code":
code := values.Get("code")
if !assert.NotEmpty(t, code, "code is empty") {
http.Error(rw, "invalid code", http.StatusBadRequest)
return
}
stateStr, ok := f.codeToStateMap.Load(code)
if !assert.True(t, ok, "invalid code") {
http.Error(rw, "invalid code", http.StatusBadRequest)
return
}
// Always invalidate the code after it is used.
f.codeToStateMap.Delete(code)
idTokenClaims, ok := f.getClaims(f.stateToIDTokenClaims, stateStr)
if !ok {
t.Errorf("missing id token claims")
http.Error(rw, "missing id token claims", http.StatusBadRequest)
return
}
claims = idTokenClaims
case "refresh_token":
refreshToken := values.Get("refresh_token")
if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") {
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
return
}
_, ok := f.refreshTokens.Load(refreshToken)
if !assert.True(t, ok, "invalid refresh_token") {
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
return
}
idTokenClaims, ok := f.getClaims(f.refreshIDTokenClaims, refreshToken)
if !ok {
t.Errorf("missing id token claims in refresh")
http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest)
return
}
claims = idTokenClaims
// The refresh hook may veto the refresh; the token is only consumed when
// the hook accepts it.
err := f.hookOnRefresh(getEmail(claims))
if err != nil {
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
f.refreshTokensUsed.Store(refreshToken, true)
// Always invalidate the refresh token after it is used.
f.refreshTokens.Delete(refreshToken)
case "urn:ietf:params:oauth:grant-type:device_code":
// Device flow
// NOTE(review): the flow's expiry (deviceFlow.exp) is not checked in this
// branch, only on the verify endpoint. Confirm that is intended.
var resp externalauth.ExchangeDeviceCodeResponse
deviceCode := values.Get("device_code")
if deviceCode == "" {
resp.Error = "invalid_request"
resp.ErrorDescription = "missing device_code"
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
return
}
deviceFlow, ok := f.deviceCode.Load(deviceCode)
if !ok {
resp.Error = "invalid_request"
resp.ErrorDescription = "device_code provided not found"
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
return
}
if !deviceFlow.granted {
// Status code ok with the error as pending.
resp.Error = "authorization_pending"
resp.ErrorDescription = ""
httpapi.Write(r.Context(), rw, http.StatusOK, resp)
return
}
// Would be nice to get an actual email here.
claims = jwt.MapClaims{
"email": "unknown-dev-auth",
}
default:
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
return
}
// Every successful grant issues a new access token, refresh token and ID
// token, all expiring after the configured default lifetime.
exp := time.Now().Add(f.defaultExpire)
// NOTE(review): "exp" is written in Unix milliseconds, while RFC 7519 defines
// NumericDate in seconds. Confirm this is intentional.
claims["exp"] = exp.UnixMilli()
email := getEmail(claims)
refreshToken := f.newRefreshTokens(email)
token := map[string]interface{}{
"access_token": f.newToken(email, exp),
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": int64((f.defaultExpire).Seconds()),
"id_token": f.encodeClaims(t, claims),
}
if f.hookMutateToken != nil {
f.hookMutateToken(token)
}
// Store the claims for the next refresh
f.refreshIDTokenClaims.Store(refreshToken, claims)
// The response encoding follows the Accept header: form-encoded when asked
// for explicitly, JSON otherwise.
mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept"))
if mediaType == "application/x-www-form-urlencoded" {
// This val encode might not work for some data structures.
// It's good enough for now...
rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
vals := url.Values{}
for k, v := range token {
vals.Set(k, fmt.Sprintf("%v", v))
}
_, _ = rw.Write([]byte(vals.Encode()))
return
}
// Default to json since the oauth2 package doesn't use Accept headers.
if mediaType == "application/json" || mediaType == "" {
rw.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(rw).Encode(token)
return
}
// If we get something we don't support, throw an error.
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "'Accept' header contains unsupported media type",
Detail: fmt.Sprintf("Found %q", mediaType),
})
}))
// validateMW authenticates the bearer token of a request and returns the
// email the token was issued for. On failure it writes a 401 response itself
// and returns ok == false.
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
token, err := f.authenticateBearerTokenRequest(t, r)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized)
return "", false
}
// authenticateBearerTokenRequest already checked that the token exists and
// has not expired; the token is loaded again here to obtain its email.
authToken, ok := f.accessTokens.Load(token)
if !ok {
t.Errorf("access token user for user_info has no email to indicate which user")
http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized)
return "", false
}
if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
http.Error(rw, "auth token expired", http.StatusUnauthorized)
return "", false
}
return authToken.email, true
}
// Userinfo endpoint: returns the claims produced by the user info hook for
// the user that owns the presented access token.
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
email, ok := validateMW(rw, r)
f.logger.Info(r.Context(), "http userinfo endpoint",
slog.F("valid", ok),
slog.F("email", email),
)
if !ok {
return
}
claims, err := f.hookUserInfo(email)
if err != nil {
http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
_ = json.NewEncoder(rw).Encode(claims)
}))
// There is almost no difference between this and /userinfo.
// The main tweak is that this route is "mounted" vs "handle" because "/userinfo"
// should be strict, and this one needs to handle sub routes.
mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
email, ok := validateMW(rw, r)
f.logger.Info(r.Context(), "http external auth validate",
slog.F("valid", ok),
slog.F("email", email),
)
if !ok {
return
}
if f.externalAuthValidate == nil {
t.Errorf("missing external auth validate handler")
http.Error(rw, "missing external auth validate handler", http.StatusBadRequest)
return
}
f.externalAuthValidate(email, rw, r)
}))
// Keys endpoint: publishes the public half of the signing key as a JWKS.
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call idp /keys")
set := jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: f.key.Public(),
KeyID: "test-key",
Algorithm: "RSA",
},
},
}
_ = json.NewEncoder(rw).Encode(set)
}))
// Device verify endpoint: emulates the user approving a device. Both the
// device code and the matching user input must be passed as query params.
mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call device verify")
inputParam := "user_input"
userInput := r.URL.Query().Get(inputParam)
if userInput == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid user input",
Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam),
})
return
}
deviceCode := r.URL.Query().Get("device_code")
if deviceCode == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Hit this url again with ?device_code=<device_code>",
})
return
}
flow, ok := f.deviceCode.Load(deviceCode)
if !ok {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Device code not found.",
})
return
}
if time.Now().After(flow.exp) {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Device code expired.",
})
return
}
// The user input is compared with surrounding whitespace ignored.
if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "user code does not match",
})
return
}
// Mark the flow as granted so the token endpoint stops answering
// "authorization_pending" for this device code.
f.deviceCode.Store(deviceCode, deviceFlow{
userInput: flow.userInput,
exp: flow.exp,
granted: true,
})
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
Message: "Device authenticated!",
})
}))
// Device auth endpoint: starts a device flow and hands out a device code,
// a user code and the URL that approves them.
mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call device auth")
p := httpapi.NewQueryParamParser()
p.RequiredNotEmpty("client_id")
clientID := p.String(r.URL.Query(), "", "client_id")
_ = p.String(r.URL.Query(), "", "scopes")
if len(p.Errors) > 0 {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query params",
Validations: p.Errors,
})
return
}
if clientID != f.clientID {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid client id",
})
return
}
deviceCode := uuid.NewString()
lifetime := time.Second * 900
// The user code is a random 9-digit number rendered as a string.
flow := deviceFlow{
//nolint:gosec
userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
}
f.deviceCode.Store(deviceCode, deviceFlow{
userInput: flow.userInput,
exp: time.Now().Add(lifetime),
})
// The verification URL already carries the device code and the user input,
// so requesting it is enough to approve the device.
verifyURL := f.issuerURL.ResolveReference(&url.URL{
Path: deviceVerify,
RawQuery: url.Values{
"device_code": {deviceCode},
"user_input": {flow.userInput},
}.Encode(),
}).String()
// JSON is only returned when explicitly requested through the Accept header.
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
"device_code": deviceCode,
"user_code": flow.userInput,
"verification_uri": verifyURL,
"expires_in": int(lifetime.Seconds()),
"interval": 3,
})
return
}
// By default, GitHub form encodes these.
_, _ = fmt.Fprint(rw, url.Values{
"device_code": {deviceCode},
"user_code": {flow.userInput},
"verification_uri": {verifyURL},
"expires_in": {strconv.Itoa(int(lifetime.Seconds()))},
"interval": {"3"},
}.Encode())
}))
// Any request to a path the IDP does not serve is logged and fails the test.
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
})
return mux
}
// HTTPClient returns the HTTP client that should be used to reach the fake IDP.
//
// When the IDP is being served (f.serve is set), no in-memory routing is
// performed: rest is returned unchanged if it has a Transport, otherwise a
// zero-value client is returned.
//
// When the IDP is not being served, the returned client answers any request
// whose host matches the issuer entirely in memory by invoking the IDP handler
// directly. Requests for any other host are passed to the fake coderd callback
// if one is set, and otherwise to the Transport of rest. If neither exists the
// request fails with an error. The cookie jar of rest, if any, is reused.
func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
	if f.serve {
		if rest != nil && rest.Transport != nil {
			return rest
		}
		return &http.Client{}
	}

	var jar http.CookieJar
	if rest != nil {
		jar = rest.Jar
	}

	inMemory := func(req *http.Request) (*http.Response, error) {
		issuer, _ := url.Parse(f.issuer)
		if req.URL.Host == issuer.Host {
			// The request targets the IDP: serve it without touching the network.
			rec := httptest.NewRecorder()
			f.handler.ServeHTTP(rec, req)
			return rec.Result(), nil
		}

		// Anything else is either the fake coderd or a real network request.
		switch {
		case f.fakeCoderd != nil:
			return f.fakeCoderd(req)
		case rest == nil || rest.Transport == nil:
			return nil, xerrors.Errorf("unexpected network request to %q", req.URL.Host)
		default:
			return rest.Transport.RoundTrip(req)
		}
	}

	return &http.Client{
		Jar:       jar,
		Transport: fakeRoundTripper{roundTrip: inMemory},
	}
}
// RefreshUsed reports the value recorded for refreshToken in the IDP's
// used-refresh-token map. A token that has no entry reports whatever zero
// value the map returns for a missing key.
func (f *FakeIDP) RefreshUsed(refreshToken string) bool {
	wasUsed, _ := f.refreshTokensUsed.Load(refreshToken)
	return wasUsed
}
// UpdateRefreshClaims overrides the claims associated with refreshToken by
// storing claims under that token in the IDP's refresh claims map. Without an
// override, refreshes reuse the claims of the original IDToken issuance.
func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) {
	f.refreshIDTokenClaims.Store(refreshToken, claims)
}
// SetRedirect sets the redirect URL on the IDP's current OAuth2 config. It is
// required for the IDP to know where to redirect and call Coderd.
//
// The OAuth2 config is created by OIDCConfig, so SetRedirect must be called
// after it. Calling it earlier fails the test with a descriptive message
// rather than panicking on a nil config.
func (f *FakeIDP) SetRedirect(t testing.TB, u string) {
	t.Helper()

	require.NotNil(t, f.cfg, "SetRedirect called before OIDCConfig: there is no OAuth2 config to update")
	f.cfg.RedirectURL = u
}
// SetCoderdCallback installs a fake "Coderd" that the in-memory HTTP client
// invokes for requests that are not addressed to the IDP, such as the redirect
// back after authenticating. It is optional.
//
// It panics when the IDP is being served, because in that mode a real Coderd
// must receive the callback.
func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) {
	if !f.serve {
		f.fakeCoderd = callback
		return
	}
	panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'")
}
// SetCoderdCallbackHandler is a convenience wrapper around SetCoderdCallback
// that accepts an http.HandlerFunc. Each request is run against handler using
// an in-memory response recorder, and the recorded response is returned.
func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) {
	viaRecorder := func(req *http.Request) (*http.Response, error) {
		rec := httptest.NewRecorder()
		handler(rec, req)
		return rec.Result(), nil
	}
	f.SetCoderdCallback(viaRecorder)
}
// ExternalAuthConfigOptions exists to provide additional functionality on top
// of the standard "validate" url. For some providers, like github, the
// response from the validate URL is parsed to gain additional information.
type ExternalAuthConfigOptions struct {
	// ValidatePayload is the payload that is used when the user calls the
	// equivalent of "userinfo" for oauth2. This is not standardized, so is
	// different for each provider type.
	ValidatePayload func(email string) any

	// routes is more advanced usage. This allows the caller to completely
	// customize the response. It captures all routes under the
	// /external-auth-validate/* path so the caller can do whatever they want
	// and even add routes. Keys always start with "/". Populate it with
	// AddRoute.
	routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)

	// UseDeviceAuth keeps the device auth section of the generated external
	// auth config. When false, that section is removed.
	UseDeviceAuth bool
}

// AddRoute registers handle for route beneath the external auth validate
// path and returns o so that calls can be chained. A missing leading "/" is
// added to route before it is stored.
//
// It panics if route resolves to "/" or "/user" (with or without the leading
// slash), because those paths are always served from ValidatePayload and a
// handler registered for them would never be called.
func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
	// Normalize first so the reserved-route check also catches "user".
	if !strings.HasPrefix(route, "/") {
		route = "/" + route
	}
	if route == "/" || route == "/user" {
		panic("cannot override the /user route. Use ValidatePayload instead")
	}
	if o.routes == nil {
		o.routes = make(map[string]func(email string, rw http.ResponseWriter, r *http.Request))
	}
	o.routes[route] = handle
	return o
}
// ExternalAuthConfig returns an external auth provider config, identified by
// id, that is backed by this fake IDP.
//
// custom may be nil, in which case default options are used. Each non-nil
// entry of opts is applied to the config, in order, before it is returned.
//
// Besides building the config, this method records id as the IDP's external
// provider ID and installs the IDP's external auth validate handler. It also
// calls OIDCConfig, which replaces the IDP's current OAuth2 config.
func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAuthConfigOptions, opts ...func(cfg *externalauth.Config)) *externalauth.Config {
	t.Helper()

	if custom == nil {
		custom = &ExternalAuthConfigOptions{}
	}
	f.externalProviderID = id
	f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) {
		newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate")
		switch newPath {
		// /user is ALWAYS supported under the `/` path too.
		case "/user", "/", "":
			var payload interface{} = "OK"
			if custom.ValidatePayload != nil {
				payload = custom.ValidatePayload(email)
			}
			// Best effort: there is nothing useful to do if the write fails.
			_ = json.NewEncoder(rw).Encode(payload)
		default:
			// Looking up a key in a nil map is safe, so the handler never
			// needs to write to the shared options while serving a request.
			handle, ok := custom.routes[newPath]
			if !ok {
				t.Errorf("missing route handler for %s", newPath)
				http.Error(rw, fmt.Sprintf("missing route handler for %s", newPath), http.StatusBadRequest)
				return
			}
			handle(email, rw, r)
		}
	}

	instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
	oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil))
	cfg := &externalauth.Config{
		DisplayName:              id,
		InstrumentedOAuth2Config: oauthCfg,
		ID:                       id,
		// No defaults for these fields by omitting the type
		Type:        "",
		DisplayIcon: f.WellknownConfig().UserInfoURL,
		// Omit the /user for the validate so we can easily append to it when modifying
		// the cfg for advanced tests.
		ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
		DeviceAuth: &externalauth.DeviceAuth{
			Config:   oauthCfg,
			ClientID: f.clientID,
			TokenURL: f.provider.TokenURL,
			Scopes:   []string{},
			CodeURL:  f.provider.DeviceCodeURL,
		},
	}

	// Device auth is opt-in.
	if !custom.UseDeviceAuth {
		cfg.DeviceAuth = nil
	}

	for _, opt := range opts {
		// Tolerate nil options, matching OIDCConfig.
		if opt == nil {
			continue
		}
		opt(cfg)
	}

	f.updateIssuerURL(t, f.issuer)
	return cfg
}
// AppCredentials returns the OAuth2 client ID and client secret that the fake
// IDP is configured with.
func (f *FakeIDP) AppCredentials() (clientID string, clientSecret string) {
	clientID, clientSecret = f.clientID, f.clientSecret
	return clientID, clientSecret
}
// OIDCConfig builds the OIDC config that Coderd should use to authenticate
// against this fake IDP.
//
// When scopes is empty, "openid", "email" and "profile" are requested. Each
// non-nil entry of opts is applied to the config, in order, before it is
// returned. The test fails immediately if the OIDC provider cannot be created.
//
// As a side effect, the generated OAuth2 config becomes the IDP's current
// config, replacing any previous one.
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
	t.Helper()

	if len(scopes) == 0 {
		scopes = []string{"openid", "email", "profile"}
	}

	endpoint := oauth2.Endpoint{
		AuthURL:   f.provider.AuthURL,
		TokenURL:  f.provider.TokenURL,
		AuthStyle: oauth2.AuthStyleInParams,
	}
	oauth := &oauth2.Config{
		ClientID:     f.clientID,
		ClientSecret: f.clientSecret,
		Endpoint:     endpoint,
		// If the user is using a real network request, they will need to do
		// 'fake.SetRedirect()'
		RedirectURL: "https://redirect.com",
		Scopes:      scopes,
	}

	// Provider discovery goes through the IDP's own HTTP client.
	discoveryCtx := oidc.ClientContext(context.Background(), f.HTTPClient(nil))
	provider, err := oidc.NewProvider(discoveryCtx, f.provider.Issuer)
	require.NoError(t, err, "failed to create OIDC provider")

	// Tokens are verified against the IDP's signing key directly.
	keySet := &oidc.StaticKeySet{
		PublicKeys: []crypto.PublicKey{f.key.Public()},
	}
	verifier := oidc.NewVerifier(f.provider.Issuer, keySet, &oidc.Config{
		ClientID:             oauth.ClientID,
		SupportedSigningAlgs: []string{"RS256"},
		// Todo: add support for Now()
	})

	oidcCfg := &coderd.OIDCConfig{
		OAuth2Config:  oauth,
		Provider:      provider,
		Verifier:      verifier,
		UsernameField: "preferred_username",
		EmailField:    "email",
		AuthURLParams: map[string]string{"access_type": "offline"},
	}
	for _, apply := range opts {
		if apply != nil {
			apply(oidcCfg)
		}
	}

	f.cfg = oauth
	return oidcCfg
}
// getClaims looks up the claims stored under key in m. If there is no entry,
// the IDP's default ID claims are returned when they are set. The boolean
// result is false only when neither source provides claims.
func (f *FakeIDP) getClaims(m *syncmap.Map[string, jwt.MapClaims], key string) (jwt.MapClaims, bool) {
	if claims, found := m.Load(key); found {
		return claims, true
	}
	if f.defaultIDClaims == nil {
		return nil, false
	}
	return f.defaultIDClaims, true
}
// httpErrorCode picks the HTTP status code to respond with for err. If err is,
// or wraps, a statusHookError then that error's HTTPStatusCode is returned.
// Otherwise defaultCode is returned.
func httpErrorCode(defaultCode int, err error) int {
	var hookErr statusHookError
	if !errors.As(err, &hookErr) {
		return defaultCode
	}
	return hookErr.HTTPStatusCode
}
type fakeRoundTripper struct {
roundTrip func(req *http.Request) (*http.Response, error)
}
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return f.roundTrip(req)
}
// testRSAPrivateKey is a PEM-encoded RSA private key (an "RSA PRIVATE KEY"
// block) that is hard-coded in the source. It is a test credential and is
// intended only for use by tests.
//
//nolint:gosec // these are test credentials
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
-----END RSA PRIVATE KEY-----`