package oidctest

import (
	"context"
	"crypto"
	"crypto/rsa"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"mime"
	"net"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptest"
	"net/http/httputil"
	"net/url"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/go-chi/chi/v5"
	"github.com/go-jose/go-jose/v3"
	"github.com/golang-jwt/jwt/v4"
	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/oauth2"
	"golang.org/x/xerrors"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/v2/coderd"
	"github.com/coder/coder/v2/coderd/externalauth"
	"github.com/coder/coder/v2/coderd/httpapi"
	"github.com/coder/coder/v2/coderd/promoauth"
	"github.com/coder/coder/v2/coderd/util/syncmap"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/testutil"
)

// token is the bookkeeping record stored for every access token the fake IDP
// issues.
type token struct {
	// issued is the time the access token was created.
	issued time.Time
	// email identifies the user the access token was issued for.
	email string
	// exp is when the access token expires. A zero value is treated as
	// "never expires" when the token is validated.
	exp time.Time
}

// deviceFlow is the state of a single device authorization attempt, keyed by
// its device code.
type deviceFlow struct {
	// userInput is the expected input to authenticate the device flow.
	userInput string
	// exp is when the device code stops being accepted by the verify endpoint.
	exp time.Time
	// granted is set once the verify endpoint has accepted the user input.
	granted bool
}

// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
	// issuer is the issuer string; issuerURL is the same value parsed.
	// Both are kept in sync by updateIssuerURL.
	issuer    string
	issuerURL *url.URL
	// key signs the ID tokens and is published on the keys endpoint.
	key *rsa.PrivateKey
	// provider is the payload served from the well-known configuration
	// endpoint.
	provider ProviderJSON
	// handler is the http handler implementing all IDP routes.
	handler http.Handler
	// cfg is the oauth2 config used by the helper flows on this type.
	cfg *oauth2.Config

	// callbackPath allows changing where the callback path to coderd is expected.
	// This only affects using the Login helper functions.
	callbackPath string
	// clientID to be used by coderd
	clientID     string
	clientSecret string
	// externalProviderID is optional to match the provider in coderd for
	// redirectURLs.
	externalProviderID string
	logger             slog.Logger
	// externalAuthValidate will be called when the user tries to validate their
	// external auth. The fake IDP will reject any invalid tokens, so this just
	// controls the response payload after a successfully authed token.
	externalAuthValidate func(email string, rw http.ResponseWriter, r *http.Request)

	// These maps are used to control the state of the IDP.
	// That is the various access tokens, refresh tokens, states, etc.

	// Authorization code -> state.
	codeToStateMap *syncmap.Map[string, string]
	// Access token -> token record (which carries the email).
	accessTokens *syncmap.Map[string, token]
	// Refresh token -> true once it has been exchanged.
	refreshTokensUsed *syncmap.Map[string, bool]
	// Refresh token -> email.
	refreshTokens *syncmap.Map[string, string]
	// State -> ID token claims to issue for that authentication flow.
	stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
	// Refresh token -> ID token claims to issue when it is exchanged.
	refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
	// Device flow
	// Device code -> device flow state.
	deviceCode *syncmap.Map[string, deviceFlow]

	// hooks
	// hookValidRedirectURL can be used to reject a redirect url from the
	// IDP -> Application. Almost all IDPs have the concept of
	// "Authorized Redirect URLs". This can be used to emulate that.
	hookValidRedirectURL func(redirectURL string) error
	// hookUserInfo produces the claims returned by the userinfo endpoint.
	hookUserInfo func(email string) (jwt.MapClaims, error)

	// defaultIDClaims is if a new client connects and we didn't preset
	// some claims.
	defaultIDClaims jwt.MapClaims
	// hookMutateToken, when set, may edit the token response before it is
	// written.
	hookMutateToken func(token map[string]interface{})
	// fakeCoderd, when set, answers in-memory requests that are not addressed
	// to the IDP's own host.
	fakeCoderd func(req *http.Request) (*http.Response, error)
	// hookOnRefresh runs on every refresh_token grant; an error blocks it.
	hookOnRefresh func(email string) error
	// Custom authentication for the client. This is useful if you want
	// to test something like PKI auth vs a client_secret.
	hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
	// serve indicates the IDP runs behind a real http server.
	serve bool
	// optional middlewares
	middlewares chi.Middlewares
	// defaultExpire is the lifetime given to tokens issued by the token
	// endpoint.
	defaultExpire time.Duration
}

// StatusError wraps err so that a hook can also dictate the HTTP status code
// of the response.
// NOTE(review): presumably unpacked by httpErrorCode, which is defined outside
// this view — confirm.
func StatusError(code int, err error) error {
	return statusHookError{
		Err:            err,
		HTTPStatusCode: code,
	}
}

// statusHookError allows a hook to change the returned http status code.
type statusHookError struct { Err error HTTPStatusCode int } func (s statusHookError) Error() string { if s.Err == nil { return "" } return s.Err.Error() } type FakeIDPOpt func(idp *FakeIDP) func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) { return func(f *FakeIDP) { f.hookValidRedirectURL = hook } } func WithMiddlewares(mws ...func(http.Handler) http.Handler) func(*FakeIDP) { return func(f *FakeIDP) { f.middlewares = append(f.middlewares, mws...) } } // WithRefresh is called when a refresh token is used. The email is // the email of the user that is being refreshed assuming the claims are correct. func WithRefresh(hook func(email string) error) func(*FakeIDP) { return func(f *FakeIDP) { f.hookOnRefresh = hook } } func WithDefaultExpire(d time.Duration) func(*FakeIDP) { return func(f *FakeIDP) { f.defaultExpire = d } } func WithCallbackPath(path string) func(*FakeIDP) { return func(f *FakeIDP) { f.callbackPath = path } } func WithStaticCredentials(id, secret string) func(*FakeIDP) { return func(f *FakeIDP) { if id != "" { f.clientID = id } if secret != "" { f.clientSecret = secret } } } // WithExtra returns extra fields that be accessed on the returned Oauth Token. // These extra fields can override the default fields (id_token, access_token, etc). func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) { return func(f *FakeIDP) { f.hookMutateToken = mutateToken } } func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) { return func(f *FakeIDP) { f.hookAuthenticateClient = hook } } // WithLogging is optional, but will log some HTTP calls made to the IDP. 
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { return func(f *FakeIDP) { f.logger = slogtest.Make(t, options) } } func WithLogger(logger slog.Logger) func(*FakeIDP) { return func(f *FakeIDP) { f.logger = logger } } // WithStaticUserInfo is optional, but will return the same user info for // every user on the /userinfo endpoint. func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { return func(f *FakeIDP) { f.hookUserInfo = func(_ string) (jwt.MapClaims, error) { return info, nil } } } func WithDefaultIDClaims(claims jwt.MapClaims) func(*FakeIDP) { return func(f *FakeIDP) { f.defaultIDClaims = claims } } func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) { return func(f *FakeIDP) { f.hookUserInfo = userInfoFunc } } // WithServing makes the IDP run an actual http server. func WithServing() func(*FakeIDP) { return func(f *FakeIDP) { f.serve = true } } func WithIssuer(issuer string) func(*FakeIDP) { return func(f *FakeIDP) { f.issuer = issuer } } type With429Arguments struct { AllPaths bool TokenPath bool AuthorizePath bool KeysPath bool UserInfoPath bool DeviceAuth bool DeviceVerify bool } // With429 will emulate a 429 response for the selected paths. 
func With429(params With429Arguments) func(*FakeIDP) { return func(f *FakeIDP) { f.middlewares = append(f.middlewares, func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { if params.AllPaths { http.Error(rw, "429, being manually blocked (all)", http.StatusTooManyRequests) return } if params.TokenPath && strings.Contains(r.URL.Path, tokenPath) { http.Error(rw, "429, being manually blocked (token)", http.StatusTooManyRequests) return } if params.AuthorizePath && strings.Contains(r.URL.Path, authorizePath) { http.Error(rw, "429, being manually blocked (authorize)", http.StatusTooManyRequests) return } if params.KeysPath && strings.Contains(r.URL.Path, keysPath) { http.Error(rw, "429, being manually blocked (keys)", http.StatusTooManyRequests) return } if params.UserInfoPath && strings.Contains(r.URL.Path, userInfoPath) { http.Error(rw, "429, being manually blocked (userinfo)", http.StatusTooManyRequests) return } if params.DeviceAuth && strings.Contains(r.URL.Path, deviceAuth) { http.Error(rw, "429, being manually blocked (device-auth)", http.StatusTooManyRequests) return } if params.DeviceVerify && strings.Contains(r.URL.Path, deviceVerify) { http.Error(rw, "429, being manually blocked (device-verify)", http.StatusTooManyRequests) return } next.ServeHTTP(rw, r) }) }) } } const ( // nolint:gosec // It thinks this is a secret lol tokenPath = "/oauth2/token" authorizePath = "/oauth2/authorize" keysPath = "/oauth2/keys" userInfoPath = "/oauth2/userinfo" deviceAuth = "/login/device/code" deviceVerify = "/login/device" ) func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { t.Helper() block, _ := pem.Decode([]byte(testRSAPrivateKey)) pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) require.NoError(t, err) idp := &FakeIDP{ key: pkey, clientID: uuid.NewString(), clientSecret: uuid.NewString(), logger: slog.Make(), codeToStateMap: syncmap.New[string, string](), accessTokens: syncmap.New[string, token](), 
refreshTokens: syncmap.New[string, string](), refreshTokensUsed: syncmap.New[string, bool](), stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), deviceCode: syncmap.New[string, deviceFlow](), hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil }, hookValidRedirectURL: func(redirectURL string) error { return nil }, defaultExpire: time.Minute * 5, } for _, opt := range opts { opt(idp) } if idp.issuer == "" { idp.issuer = "https://coder.com" } idp.handler = idp.httpHandler(t) idp.updateIssuerURL(t, idp.issuer) if idp.serve { idp.realServer(t) } return idp } func (f *FakeIDP) WellknownConfig() ProviderJSON { return f.provider } func (f *FakeIDP) IssuerURL() *url.URL { return f.issuerURL } func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { t.Helper() u, err := url.Parse(issuer) require.NoError(t, err, "invalid issuer URL") f.issuer = issuer f.issuerURL = u // ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. f.provider = ProviderJSON{ Issuer: issuer, AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(), UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(), DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(), Algorithms: []string{ "RS256", }, ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(), } } // realServer turns the FakeIDP into a real http server. 
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { t.Helper() srvURL := "localhost:0" issURL, err := url.Parse(f.issuer) if err == nil { if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" { srvURL = issURL.Host } } l, err := net.Listen("tcp", srvURL) require.NoError(t, err, "failed to create listener") ctx, cancel := context.WithCancel(context.Background()) srv := &httptest.Server{ Listener: l, Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5}, } srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx } srv.Start() t.Cleanup(srv.CloseClientConnections) t.Cleanup(srv.Close) t.Cleanup(cancel) f.updateIssuerURL(t, srv.URL) return srv } // GenerateAuthenticatedToken skips all oauth2 flows, and just generates a // valid token for some given claims. func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) { state := uuid.NewString() f.stateToIDTokenClaims.Store(state, claims) code := f.newCode(state) return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code) } // Login does the full OIDC flow starting at the "LoginButton". // The client argument is just to get the URL of the Coder instance. // // The client passed in is just to get the url of the Coder instance. // The actual client that is used is 100% unauthenticated and fresh. func (f *FakeIDP) Login(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { t.Helper() client, resp := f.AttemptLogin(t, client, idTokenClaims, opts...) 
if resp.StatusCode != http.StatusOK { data, err := httputil.DumpResponse(resp, true) if err == nil { t.Logf("Attempt Login response payload\n%s", string(data)) } } require.Equal(t, http.StatusOK, resp.StatusCode, "client failed to login") return client, resp } func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { t.Helper() var err error cli := f.HTTPClient(client.HTTPClient) shallowCpyCli := *cli if shallowCpyCli.Jar == nil { shallowCpyCli.Jar, err = cookiejar.New(nil) require.NoError(t, err, "failed to create cookie jar") } unauthenticated := codersdk.New(client.URL) unauthenticated.HTTPClient = &shallowCpyCli return f.LoginWithClient(t, unauthenticated, idTokenClaims, opts...) } // LoginWithClient reuses the context of the passed in client. This means the same // cookies will be used. This should be an unauthenticated client in most cases. // // This is a niche case, but it is needed for testing ConvertLoginType. func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { t.Helper() path := "/api/v2/users/oidc/callback" if f.callbackPath != "" { path = f.callbackPath } coderOauthURL, err := client.URL.Parse(path) require.NoError(t, err) f.SetRedirect(t, coderOauthURL.String()) cli := f.HTTPClient(client.HTTPClient) cli.CheckRedirect = func(req *http.Request, via []*http.Request) error { // Store the idTokenClaims to the specific state request. This ties // the claims 1:1 with a given authentication flow. 
state := req.URL.Query().Get("state") f.stateToIDTokenClaims.Store(state, idTokenClaims) return nil } req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil) require.NoError(t, err) if cli.Jar == nil { cli.Jar, err = cookiejar.New(nil) require.NoError(t, err, "failed to create cookie jar") } for _, opt := range opts { opt(req) } res, err := cli.Do(req) require.NoError(t, err) // If the coder session token exists, return the new authed client! var user *codersdk.Client cookies := cli.Jar.Cookies(client.URL) for _, cookie := range cookies { if cookie.Name == codersdk.SessionTokenCookie { user = codersdk.New(client.URL) user.SetSessionToken(cookie.Value) } } t.Cleanup(func() { if res.Body != nil { _ = res.Body.Close() } }) return user, res } // ExternalLogin does the oauth2 flow for external auth providers. This requires // an authenticated coder client. func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) { coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID)) require.NoError(t, err) f.SetRedirect(t, coderOauthURL.String()) cli := f.HTTPClient(client.HTTPClient) cli.CheckRedirect = func(req *http.Request, via []*http.Request) error { // Store the idTokenClaims to the specific state request. This ties // the claims 1:1 with a given authentication flow. state := req.URL.Query().Get("state") f.stateToIDTokenClaims.Store(state, jwt.MapClaims{}) return nil } ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil) require.NoError(t, err) // External auth flow requires the user be authenticated. 
headerName := client.SessionTokenHeader if headerName == "" { headerName = codersdk.SessionTokenHeader } req.Header.Set(headerName, client.SessionToken()) if cli.Jar == nil { cli.Jar, err = cookiejar.New(nil) require.NoError(t, err, "failed to create cookie jar") } for _, opt := range opts { opt(req) } res, err := cli.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, res.StatusCode, "client failed to login") _ = res.Body.Close() } // DeviceLogin does the oauth2 device flow for external auth providers. func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) { // First we need to initiate the device flow. This will have Coder hit the // fake IDP and get a device code. device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID) require.NoError(t, err) // Now the user needs to go to the fake IDP page and click "allow" and enter // the device code input. For our purposes, we just send an http request to // the verification url. No additional user input is needed. ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil) require.NoError(t, err) defer resp.Body.Close() // Now we need to exchange the device code for an access token. We do this // in this method because it is the user that does the polling for the device // auth flow, not the backend. err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{ DeviceCode: device.DeviceCode, }) require.NoError(t, err) } // CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing // unit tests, it's easier to skip this step sometimes. It does make an actual // request to the IDP, so it should be equivalent to doing this "manually" with // actual requests. 
func (f *FakeIDP) CreateAuthCode(t testing.TB, state string) string { // We need to store some claims, because this is also an OIDC provider, and // it expects some claims to be present. f.stateToIDTokenClaims.Store(state, jwt.MapClaims{}) code, err := OAuth2GetCode(f.cfg.AuthCodeURL(state), func(req *http.Request) (*http.Response, error) { rw := httptest.NewRecorder() f.handler.ServeHTTP(rw, req) resp := rw.Result() return resp, nil }) require.NoError(t, err, "failed to get auth code") return code } // OIDCCallback will emulate the IDP redirecting back to the Coder callback. // This is helpful if no Coderd exists because the IDP needs to redirect to // something. // Essentially this is used to fake the Coderd side of the exchange. // The flow starts at the user hitting the OIDC login page. func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) { t.Helper() if f.serve { panic("cannot use OIDCCallback with WithServing. This is only for the in memory usage") } f.stateToIDTokenClaims.Store(state, idTokenClaims) cli := f.HTTPClient(nil) u := f.cfg.AuthCodeURL(state) req, err := http.NewRequest("GET", u, nil) require.NoError(t, err) resp, err := cli.Do(req.WithContext(context.Background())) require.NoError(t, err) t.Cleanup(func() { if resp.Body != nil { _ = resp.Body.Close() } }) return resp, nil } // ProviderJSON is the .well-known/configuration JSON type ProviderJSON struct { Issuer string `json:"issuer"` AuthURL string `json:"authorization_endpoint"` TokenURL string `json:"token_endpoint"` JWKSURL string `json:"jwks_uri"` UserInfoURL string `json:"userinfo_endpoint"` DeviceCodeURL string `json:"device_authorization_endpoint"` Algorithms []string `json:"id_token_signing_alg_values_supported"` // This is custom ExternalAuthURL string `json:"external_auth_url"` } // newCode enforces the code exchanged is actually a valid code // created by the IDP. 
func (f *FakeIDP) newCode(state string) string { code := uuid.NewString() f.codeToStateMap.Store(code, state) return code } // newToken enforces the access token exchanged is actually a valid access token // created by the IDP. func (f *FakeIDP) newToken(email string, expires time.Time) string { accessToken := uuid.NewString() f.accessTokens.Store(accessToken, token{ issued: time.Now(), email: email, exp: expires, }) return accessToken } func (f *FakeIDP) newRefreshTokens(email string) string { refreshToken := uuid.NewString() f.refreshTokens.Store(refreshToken, email) return refreshToken } // authenticateBearerTokenRequest enforces the access token is valid. func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) { t.Helper() auth := req.Header.Get("Authorization") token := strings.TrimPrefix(auth, "Bearer ") authToken, ok := f.accessTokens.Load(token) if !ok { return "", xerrors.New("invalid access token") } if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) { return "", xerrors.New("access token expired") } return token, nil } // authenticateOIDCClientRequest enforces the client_id and client_secret are valid. 
// It returns the form values parsed from the request body. When a custom
// client authentication hook is configured, the hook fully replaces the
// default check. Mismatches are reported on t via assert (so the test is
// marked failed) and are also returned as an error.
func (f *FakeIDP) authenticateOIDCClientRequest(t testing.TB, req *http.Request) (url.Values, error) {
	t.Helper()

	if f.hookAuthenticateClient != nil {
		return f.hookAuthenticateClient(t, req)
	}

	// The credentials are expected form-encoded in the request body.
	data, err := io.ReadAll(req.Body)
	if !assert.NoError(t, err, "read token request body") {
		return nil, xerrors.Errorf("authenticate request, read body: %w", err)
	}
	values, err := url.ParseQuery(string(data))
	if !assert.NoError(t, err, "parse token request values") {
		return nil, xerrors.New("invalid token request")
	}

	if !assert.Equal(t, f.clientID, values.Get("client_id"), "client_id mismatch") {
		return nil, xerrors.New("client_id mismatch")
	}

	if !assert.Equal(t, f.clientSecret, values.Get("client_secret"), "client_secret mismatch") {
		return nil, xerrors.New("client_secret mismatch")
	}

	return values, nil
}

// encodeClaims is a helper func to convert claims to a valid JWT.
// Missing "exp", "aud" and "iss" claims are filled in on the passed map, which
// is therefore mutated. The token is signed with RS256 using the IDP key.
func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string {
	t.Helper()

	if _, ok := claims["exp"]; !ok {
		// NOTE(review): JWT "exp" is conventionally expressed in seconds
		// (RFC 7519 NumericDate) while this stores milliseconds — confirm
		// this is intentional.
		claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
	}

	if _, ok := claims["aud"]; !ok {
		claims["aud"] = f.clientID
	}

	if _, ok := claims["iss"]; !ok {
		claims["iss"] = f.issuer
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key)
	require.NoError(t, err)

	return signed
}

// httpHandler is the IDP http server.
// It registers every IDP route on a chi mux with the configured middlewares
// applied. Unexpected input is reported on t in addition to the http error
// response.
func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
	t.Helper()

	mux := chi.NewMux()
	mux.Use(f.middlewares...)
	// This endpoint is required to initialize the OIDC provider.
	// It is used to get the OIDC configuration.
	mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
		f.logger.Info(r.Context(), "http OIDC config", slog.F("url", r.URL.String()))

		_ = json.NewEncoder(rw).Encode(f.provider)
	})

	// Authorize is called when the user is redirected to the IDP to login.
	// This is the browser hitting the IDP and the user logging into Google or
	// w/e and clicking "Allow". They will be redirected back to the redirect
	// when this is done.
	mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		f.logger.Info(r.Context(), "http call authorize", slog.F("url", r.URL.String()))

		clientID := r.URL.Query().Get("client_id")
		if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") {
			http.Error(rw, "invalid client_id", http.StatusBadRequest)
			return
		}

		redirectURI := r.URL.Query().Get("redirect_uri")
		state := r.URL.Query().Get("state")

		// An empty scope fails the test but does not abort the request.
		scope := r.URL.Query().Get("scope")
		assert.NotEmpty(t, scope, "scope is empty")

		// Only the authorization code response type is supported.
		responseType := r.URL.Query().Get("response_type")
		switch responseType {
		case "code":
		case "token":
			t.Errorf("response_type %q not supported", responseType)
			http.Error(rw, "invalid response_type", http.StatusBadRequest)
			return
		default:
			t.Errorf("unexpected response_type %q", responseType)
			http.Error(rw, "invalid response_type", http.StatusBadRequest)
			return
		}

		err := f.hookValidRedirectURL(redirectURI)
		if err != nil {
			t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
			http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
			return
		}

		ru, err := url.Parse(redirectURI)
		if err != nil {
			t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error())
			http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
			return
		}

		// Send the browser back to the application with the state echoed
		// and a freshly minted authorization code.
		q := ru.Query()
		q.Set("state", state)
		q.Set("code", f.newCode(state))
		ru.RawQuery = q.Encode()

		http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect)
	}))

	// Token exchanges an authorization code, refresh token, or device code
	// for a new set of tokens.
	mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		var values url.Values
		var err error
		// The device code grant is read from the query string and skips
		// client authentication. Every other grant authenticates the client
		// and takes its values from the request body.
		if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
			values = r.URL.Query()
		} else {
			values, err = f.authenticateOIDCClientRequest(t, r)
		}
		f.logger.Info(r.Context(), "http idp call token",
			slog.F("url", r.URL.String()),
			slog.F("valid", err == nil),
			slog.F("grant_type", values.Get("grant_type")),
			slog.F("values", values.Encode()),
		)
		if err != nil {
			http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
			return
		}
		// getEmail extracts the "email" claim, falling back to a
		// placeholder when it is missing or not a string.
		getEmail := func(claims jwt.MapClaims) string {
			email, ok := claims["email"]
			if !ok {
				return "unknown"
			}
			emailStr, ok := email.(string)
			if !ok {
				return "wrong-type"
			}
			return emailStr
		}

		var claims jwt.MapClaims
		switch values.Get("grant_type") {
		case "authorization_code":
			code := values.Get("code")
			if !assert.NotEmpty(t, code, "code is empty") {
				http.Error(rw, "invalid code", http.StatusBadRequest)
				return
			}
			stateStr, ok := f.codeToStateMap.Load(code)
			if !assert.True(t, ok, "invalid code") {
				http.Error(rw, "invalid code", http.StatusBadRequest)
				return
			}
			// Always invalidate the code after it is used.
			f.codeToStateMap.Delete(code)

			idTokenClaims, ok := f.getClaims(f.stateToIDTokenClaims, stateStr)
			if !ok {
				t.Errorf("missing id token claims")
				http.Error(rw, "missing id token claims", http.StatusBadRequest)
				return
			}
			claims = idTokenClaims
		case "refresh_token":
			refreshToken := values.Get("refresh_token")
			if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") {
				http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
				return
			}

			_, ok := f.refreshTokens.Load(refreshToken)
			if !assert.True(t, ok, "invalid refresh_token") {
				http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
				return
			}

			idTokenClaims, ok := f.getClaims(f.refreshIDTokenClaims, refreshToken)
			if !ok {
				t.Errorf("missing id token claims in refresh")
				http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest)
				return
			}

			claims = idTokenClaims
			// The refresh hook can veto the refresh. The token is only
			// marked used and invalidated when the hook allows it.
			err := f.hookOnRefresh(getEmail(claims))
			if err != nil {
				http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
				return
			}

			f.refreshTokensUsed.Store(refreshToken, true)
			// Always invalidate the refresh token after it is used.
			f.refreshTokens.Delete(refreshToken)
		case "urn:ietf:params:oauth:grant-type:device_code":
			// Device flow
			var resp externalauth.ExchangeDeviceCodeResponse
			deviceCode := values.Get("device_code")
			if deviceCode == "" {
				resp.Error = "invalid_request"
				resp.ErrorDescription = "missing device_code"
				httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
				return
			}

			deviceFlow, ok := f.deviceCode.Load(deviceCode)
			if !ok {
				resp.Error = "invalid_request"
				resp.ErrorDescription = "device_code provided not found"
				httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
				return
			}

			if !deviceFlow.granted {
				// Status code ok with the error as pending.
				resp.Error = "authorization_pending"
				resp.ErrorDescription = ""
				httpapi.Write(r.Context(), rw, http.StatusOK, resp)
				return
			}

			// Would be nice to get an actual email here.
			claims = jwt.MapClaims{
				"email": "unknown-dev-auth",
			}
		default:
			t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
			http.Error(rw, "invalid grant_type", http.StatusBadRequest)
			return
		}

		// Every successful grant gets a new access token, refresh token and
		// ID token, all expiring after defaultExpire.
		exp := time.Now().Add(f.defaultExpire)
		// NOTE(review): "exp" is written in milliseconds here; JWT
		// NumericDate is conventionally seconds — confirm this is intended.
		claims["exp"] = exp.UnixMilli()
		email := getEmail(claims)
		refreshToken := f.newRefreshTokens(email)
		token := map[string]interface{}{
			"access_token":  f.newToken(email, exp),
			"refresh_token": refreshToken,
			"token_type":    "Bearer",
			"expires_in":    int64((f.defaultExpire).Seconds()),
			"id_token":      f.encodeClaims(t, claims),
		}
		if f.hookMutateToken != nil {
			f.hookMutateToken(token)
		}
		// Store the claims for the next refresh
		f.refreshIDTokenClaims.Store(refreshToken, claims)

		// The response encoding follows the request's Accept header.
		mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept"))
		if mediaType == "application/x-www-form-urlencoded" {
			// This val encode might not work for some data structures.
			// It's good enough for now...
			rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
			vals := url.Values{}
			for k, v := range token {
				vals.Set(k, fmt.Sprintf("%v", v))
			}
			_, _ = rw.Write([]byte(vals.Encode()))
			return
		}
		// Default to json since the oauth2 package doesn't use Accept headers.
		if mediaType == "application/json" || mediaType == "" {
			rw.Header().Set("Content-Type", "application/json")
			_ = json.NewEncoder(rw).Encode(token)
			return
		}

		// If we get something we don't support, throw an error.
		httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
			Message: "'Accept' header contains unsupported media type",
			Detail:  fmt.Sprintf("Found %q", mediaType),
		})
	}))

	// validateMW authenticates the bearer token of a request and resolves
	// the email it was issued for. On failure it writes the 401 response
	// itself and returns ok=false.
	validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
		token, err := f.authenticateBearerTokenRequest(t, r)
		if err != nil {
			http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized)
			return "", false
		}

		authToken, ok := f.accessTokens.Load(token)
		if !ok {
			t.Errorf("access token user for user_info has no email to indicate which user")
			http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized)
			return "", false
		}

		if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
			http.Error(rw, "auth token expired", http.StatusUnauthorized)
			return "", false
		}

		return authToken.email, true
	}
	// UserInfo returns the claims produced by the user info hook for the
	// user the access token belongs to.
	mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		email, ok := validateMW(rw, r)
		f.logger.Info(r.Context(), "http userinfo endpoint",
			slog.F("valid", ok),
			slog.F("email", email),
		)
		if !ok {
			return
		}

		claims, err := f.hookUserInfo(email)
		if err != nil {
			http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
			return
		}
		_ = json.NewEncoder(rw).Encode(claims)
	}))

	// There is almost no difference between this and /userinfo.
	// The main tweak is that this route is "mounted" vs "handle" because "/userinfo"
	// should be strict, and this one needs to handle sub routes.
	mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		email, ok := validateMW(rw, r)
		f.logger.Info(r.Context(), "http external auth validate",
			slog.F("valid", ok),
			slog.F("email", email),
		)
		if !ok {
			return
		}

		if f.externalAuthValidate == nil {
			t.Errorf("missing external auth validate handler")
			http.Error(rw, "missing external auth validate handler", http.StatusBadRequest)
			return
		}
		f.externalAuthValidate(email, rw, r)
	}))

	// Keys publishes the public half of the signing key as a JWK set.
	mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		f.logger.Info(r.Context(), "http call idp /keys")
		set := jose.JSONWebKeySet{
			Keys: []jose.JSONWebKey{
				{
					Key:       f.key.Public(),
					KeyID:     "test-key",
					Algorithm: "RSA",
				},
			},
		}
		_ = json.NewEncoder(rw).Encode(set)
	}))

	// Device verify emulates the user entering the user code on the IDP
	// page. A matching, unexpired code marks the device flow as granted.
	mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		f.logger.Info(r.Context(), "http call device verify")

		inputParam := "user_input"
		userInput := r.URL.Query().Get(inputParam)
		if userInput == "" {
			httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid user input",
				Detail:  fmt.Sprintf("Hit this url again with ?%s=", inputParam),
			})
			return
		}

		deviceCode := r.URL.Query().Get("device_code")
		if deviceCode == "" {
			httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid device code",
				Detail:  "Hit this url again with ?device_code=",
			})
			return
		}

		flow, ok := f.deviceCode.Load(deviceCode)
		if !ok {
			httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid device code",
				Detail:  "Device code not found.",
			})
			return
		}

		if time.Now().After(flow.exp) {
			httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid device code",
				Detail:  "Device code expired.",
			})
			return
		}

		// Surrounding whitespace is ignored when comparing the user code.
		if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
			httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid device code",
				Detail:  "user code does not match",
			})
			return
		}

		f.deviceCode.Store(deviceCode, deviceFlow{
			userInput: flow.userInput,
			exp:       flow.exp,
			granted:   true,
		})
		httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
			Message: "Device authenticated!",
		})
	}))

	// Device auth starts a device flow: it issues a device code and a user
	// code, and returns the URL the user must visit to verify them.
	mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		f.logger.Info(r.Context(), "http call device auth")

		p := httpapi.NewQueryParamParser()
		p.RequiredNotEmpty("client_id")
		clientID := p.String(r.URL.Query(), "", "client_id")
		_ = p.String(r.URL.Query(), "", "scopes")
		if len(p.Errors) > 0 {
			httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
				Message:     "Invalid query params",
				Validations: p.Errors,
			})
			return
		}

		if clientID != f.clientID {
			httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
				Message: "Invalid client id",
			})
			return
		}

		deviceCode := uuid.NewString()
		// Device codes are valid for 900 seconds.
		lifetime := time.Second * 900
		flow := deviceFlow{
			//nolint:gosec
			userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
		}
		f.deviceCode.Store(deviceCode, deviceFlow{
			userInput: flow.userInput,
			exp:       time.Now().Add(lifetime),
		})

		// The verification URL already carries both codes, so requesting it
		// is enough to grant the flow.
		verifyURL := f.issuerURL.ResolveReference(&url.URL{
			Path: deviceVerify,
			RawQuery: url.Values{
				"device_code": {deviceCode},
				"user_input":  {flow.userInput},
			}.Encode(),
		}).String()

		if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
			httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
				"device_code":      deviceCode,
				"user_code":        flow.userInput,
				"verification_uri": verifyURL,
				"expires_in":       int(lifetime.Seconds()),
				"interval":         3,
			})
			return
		}

		// By default, GitHub form encodes these.
		_, _ = fmt.Fprint(rw, url.Values{
			"device_code":      {deviceCode},
			"user_code":        {flow.userInput},
			"verification_uri": {verifyURL},
			"expires_in":       {strconv.Itoa(int(lifetime.Seconds()))},
			"interval":         {"3"},
		}.Encode())
	}))

	// Any other path is a test failure: the IDP does not support it.
	mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
		f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
		t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
	})

	return mux
}

// HTTPClient returns the http client to use when talking to the IDP.
//
// If WithServing is used, the passed in client is returned as is, or a fresh
// default client when it is nil or has no Transport.
// NOTE(review): in that fallback the cookie jar of the passed in client is not
// carried over, unlike in the in-memory path below — confirm this is intended.
//
// If WithServing is not used, then it will return a client that will make requests
// to the IDP all in memory. If a request is not to the IDP, then the fake
// Coderd callback is used when set, otherwise the passed in client's
// Transport. If neither is available, any regular network requests will fail.
func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
	if f.serve {
		if rest == nil || rest.Transport == nil {
			return &http.Client{}
		}
		return rest
	}

	var jar http.CookieJar
	if rest != nil {
		jar = rest.Jar
	}
	return &http.Client{
		Jar: jar,
		Transport: fakeRoundTripper{
			roundTrip: func(req *http.Request) (*http.Response, error) {
				// Requests addressed to the issuer host are served in
				// memory by the IDP handler; everything else is delegated.
				u, _ := url.Parse(f.issuer)
				if req.URL.Host != u.Host {
					if f.fakeCoderd != nil {
						return f.fakeCoderd(req)
					}
					if rest == nil || rest.Transport == nil {
						return nil, xerrors.Errorf("unexpected network request to %q", req.URL.Host)
					}
					return rest.Transport.RoundTrip(req)
				}
				resp := httptest.NewRecorder()
				f.handler.ServeHTTP(resp, req)
				return resp.Result(), nil
			},
		},
	}
}

// RefreshUsed returns if the refresh token has been used. All refresh tokens
// can only be used once, then they are deleted.
func (f *FakeIDP) RefreshUsed(refreshToken string) bool {
	used, _ := f.refreshTokensUsed.Load(refreshToken)
	return used
}

// UpdateRefreshClaims allows the caller to change what claims are returned
// for a given refresh token. By default, all refreshes use the same claims as
// the original IDToken issuance.
func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) { f.refreshIDTokenClaims.Store(refreshToken, claims) } // SetRedirect is required for the IDP to know where to redirect and call // Coderd. func (f *FakeIDP) SetRedirect(t testing.TB, u string) { t.Helper() f.cfg.RedirectURL = u } // SetCoderdCallback is optional and only works if not using the IsServing. // It will setup a fake "Coderd" for the IDP to call when the IDP redirects // back after authenticating. func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) { if f.serve { panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") } f.fakeCoderd = callback } func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) { f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) { resp := httptest.NewRecorder() handler.ServeHTTP(resp, req) return resp.Result(), nil }) } // ExternalAuthConfigOptions exists to provide additional functionality ontop // of the standard "validate" url. Some providers like github we actually parse // the response from the validate URL to gain additional information. type ExternalAuthConfigOptions struct { // ValidatePayload is the payload that is used when the user calls the // equivalent of "userinfo" for oauth2. This is not standardized, so is // different for each provider type. ValidatePayload func(email string) interface{} // routes is more advanced usage. This allows the caller to // completely customize the response. It captures all routes under the /external-auth-validate/* // so the caller can do whatever they want and even add routes. 
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request) UseDeviceAuth bool } func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions { if route == "/" || route == "" || route == "/user" { panic("cannot override the /user route. Use ValidatePayload instead") } if !strings.HasPrefix(route, "/") { route = "/" + route } if o.routes == nil { o.routes = make(map[string]func(email string, rw http.ResponseWriter, r *http.Request)) } o.routes[route] = handle return o } // ExternalAuthConfig is the config for external auth providers. func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAuthConfigOptions, opts ...func(cfg *externalauth.Config)) *externalauth.Config { if custom == nil { custom = &ExternalAuthConfigOptions{} } f.externalProviderID = id f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) { newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate") switch newPath { // /user is ALWAYS supported under the `/` path too. 
case "/user", "/", "": var payload interface{} = "OK" if custom.ValidatePayload != nil { payload = custom.ValidatePayload(email) } _ = json.NewEncoder(rw).Encode(payload) default: if custom.routes == nil { custom.routes = make(map[string]func(email string, rw http.ResponseWriter, r *http.Request)) } handle, ok := custom.routes[newPath] if !ok { t.Errorf("missing route handler for %s", newPath) http.Error(rw, fmt.Sprintf("missing route handler for %s", newPath), http.StatusBadRequest) return } handle(email, rw, r) } } instrumentF := promoauth.NewFactory(prometheus.NewRegistry()) oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil)) cfg := &externalauth.Config{ DisplayName: id, InstrumentedOAuth2Config: oauthCfg, ID: id, // No defaults for these fields by omitting the type Type: "", DisplayIcon: f.WellknownConfig().UserInfoURL, // Omit the /user for the validate so we can easily append to it when modifying // the cfg for advanced tests. ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), DeviceAuth: &externalauth.DeviceAuth{ Config: oauthCfg, ClientID: f.clientID, TokenURL: f.provider.TokenURL, Scopes: []string{}, CodeURL: f.provider.DeviceCodeURL, }, } if !custom.UseDeviceAuth { cfg.DeviceAuth = nil } for _, opt := range opts { opt(cfg) } f.updateIssuerURL(t, f.issuer) return cfg } func (f *FakeIDP) AppCredentials() (clientID string, clientSecret string) { return f.clientID, f.clientSecret } // OIDCConfig returns the OIDC config to use for Coderd. 
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { t.Helper() if len(scopes) == 0 { scopes = []string{"openid", "email", "profile"} } oauthCfg := &oauth2.Config{ ClientID: f.clientID, ClientSecret: f.clientSecret, Endpoint: oauth2.Endpoint{ AuthURL: f.provider.AuthURL, TokenURL: f.provider.TokenURL, AuthStyle: oauth2.AuthStyleInParams, }, // If the user is using a real network request, they will need to do // 'fake.SetRedirect()' RedirectURL: "https://redirect.com", Scopes: scopes, } ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil)) p, err := oidc.NewProvider(ctx, f.provider.Issuer) require.NoError(t, err, "failed to create OIDC provider") cfg := &coderd.OIDCConfig{ OAuth2Config: oauthCfg, Provider: p, Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{ PublicKeys: []crypto.PublicKey{f.key.Public()}, }, &oidc.Config{ ClientID: oauthCfg.ClientID, SupportedSigningAlgs: []string{ "RS256", }, // Todo: add support for Now() }), UsernameField: "preferred_username", EmailField: "email", AuthURLParams: map[string]string{"access_type": "offline"}, } for _, opt := range opts { if opt == nil { continue } opt(cfg) } f.cfg = oauthCfg return cfg } func (f *FakeIDP) getClaims(m *syncmap.Map[string, jwt.MapClaims], key string) (jwt.MapClaims, bool) { v, ok := m.Load(key) if !ok { if f.defaultIDClaims != nil { return f.defaultIDClaims, true } return nil, false } return v, true } func httpErrorCode(defaultCode int, err error) int { var stautsErr statusHookError status := defaultCode if errors.As(err, &stautsErr) { status = stautsErr.HTTPStatusCode } return status } type fakeRoundTripper struct { roundTrip func(req *http.Request) (*http.Response, error) } func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return f.roundTrip(req) } //nolint:gosec // these are test credentials const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY----- 
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92 5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0 wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9 pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8 YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8 -----END RSA PRIVATE KEY-----`