From 8e0a153725623075ba4daf59b4acda5ecd9b584c Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 22 Jan 2024 14:46:05 -0600 Subject: [PATCH] chore: implement device auth flow for fake idp (#11707) * chore: implement device auth flow for fake idp --- coderd/coderdtest/oidctest/idp.go | 268 ++++++++++++++++++++++++++-- coderd/externalauth/externalauth.go | 32 +++- coderd/externalauth_test.go | 46 +++++ scripts/testidp/main.go | 10 +- 4 files changed, 333 insertions(+), 23 deletions(-) diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index e830bb0511..044db86ce0 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -10,11 +10,14 @@ import ( "errors" "fmt" "io" + "math/rand" + "mime" "net" "net/http" "net/http/cookiejar" "net/http/httptest" "net/url" + "strconv" "strings" "testing" "time" @@ -34,9 +37,11 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/externalauth" + "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/coderd/util/syncmap" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" ) type token struct { @@ -45,6 +50,13 @@ type token struct { exp time.Time } +type deviceFlow struct { + // userInput is the expected input to authenticate the device flow. + userInput string + exp time.Time + granted bool +} + // FakeIDP is a functional OIDC provider. // It only supports 1 OIDC client. type FakeIDP struct { @@ -77,6 +89,8 @@ type FakeIDP struct { refreshTokens *syncmap.Map[string, string] stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims] refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims] + // Device flow + deviceCode *syncmap.Map[string, deviceFlow] // hooks // hookValidRedirectURL can be used to reject a redirect url from the @@ -226,6 +240,8 @@ const ( authorizePath = "/oauth2/authorize" keysPath = "/oauth2/keys" userInfoPath = "/oauth2/userinfo" + deviceAuth = "/login/device/code" + deviceVerify = "/login/device" ) func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { @@ -246,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { refreshTokensUsed: syncmap.New[string, bool](), stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), + deviceCode: syncmap.New[string, deviceFlow](), hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil }, hookValidRedirectURL: func(redirectURL string) error { return nil }, @@ -288,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { // ProviderJSON is the JSON representation of the OpenID Connect provider // These are all the urls that the IDP will respond to. f.provider = ProviderJSON{ - Issuer: issuer, - AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), - TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), - JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(), - UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(), + Issuer: issuer, + AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), + TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), + JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(), + UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(), + DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(), Algorithms: []string{ "RS256", }, @@ -467,6 +485,31 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f _ = res.Body.Close() } +// DeviceLogin does the oauth2 device flow for external auth providers. +func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) { + // First we need to initiate the device flow. This will have Coder hit the + // fake IDP and get a device code. + device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID) + require.NoError(t, err) + + // Now the user needs to go to the fake IDP page and click "allow" and enter + // the device code input. For our purposes, we just send an http request to + // the verification url. No additional user input is needed. + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil) + require.NoError(t, err) + defer resp.Body.Close() + + // Now we need to exchange the device code for an access token. We do this + // in this method because it is the user that does the polling for the device + // auth flow, not the backend. + err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{ + DeviceCode: device.DeviceCode, + }) + require.NoError(t, err) +} + // CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing // unit tests, it's easier to skip this step sometimes. It does make an actual // request to the IDP, so it should be equivalent to doing this "manually" with @@ -536,12 +579,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map // ProviderJSON is the .well-known/configuration JSON type ProviderJSON struct { - Issuer string `json:"issuer"` - AuthURL string `json:"authorization_endpoint"` - TokenURL string `json:"token_endpoint"` - JWKSURL string `json:"jwks_uri"` - UserInfoURL string `json:"userinfo_endpoint"` - Algorithms []string `json:"id_token_signing_alg_values_supported"` + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + JWKSURL string `json:"jwks_uri"` + UserInfoURL string `json:"userinfo_endpoint"` + DeviceCodeURL string `json:"device_authorization_endpoint"` + Algorithms []string `json:"id_token_signing_alg_values_supported"` // This is custom ExternalAuthURL string `json:"external_auth_url"` } @@ -709,8 +753,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { })) mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - values, err := f.authenticateOIDCClientRequest(t, r) + var values url.Values + var err error + if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" { + values = r.URL.Query() + } else { + values, err = f.authenticateOIDCClientRequest(t, r) + } f.logger.Info(r.Context(), "http idp call token", + slog.F("url", r.URL.String()), slog.F("valid", err == nil), slog.F("grant_type", values.Get("grant_type")), slog.F("values", values.Encode()), @@ -784,6 +835,37 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { f.refreshTokensUsed.Store(refreshToken, true) // Always invalidate the refresh token after it is used. f.refreshTokens.Delete(refreshToken) + case "urn:ietf:params:oauth:grant-type:device_code": + // Device flow + var resp externalauth.ExchangeDeviceCodeResponse + deviceCode := values.Get("device_code") + if deviceCode == "" { + resp.Error = "invalid_request" + resp.ErrorDescription = "missing device_code" + httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp) + return + } + + deviceFlow, ok := f.deviceCode.Load(deviceCode) + if !ok { + resp.Error = "invalid_request" + resp.ErrorDescription = "device_code provided not found" + httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp) + return + } + + if !deviceFlow.granted { + // Status code ok with the error as pending. + resp.Error = "authorization_pending" + resp.ErrorDescription = "" + httpapi.Write(r.Context(), rw, http.StatusOK, resp) + return + } + + // Would be nice to get an actual email here. + claims = jwt.MapClaims{ + "email": "unknown-dev-auth", + } default: t.Errorf("unexpected grant_type %q", values.Get("grant_type")) http.Error(rw, "invalid grant_type", http.StatusBadRequest) @@ -807,8 +889,30 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // Store the claims for the next refresh f.refreshIDTokenClaims.Store(refreshToken, claims) - rw.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(rw).Encode(token) + mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")) + if mediaType == "application/x-www-form-urlencoded" { + // This val encode might not work for some data structures. + // It's good enough for now... + rw.Header().Set("Content-Type", "application/x-www-form-urlencoded") + vals := url.Values{} + for k, v := range token { + vals.Set(k, fmt.Sprintf("%v", v)) + } + _, _ = rw.Write([]byte(vals.Encode())) + return + } + // Default to json since the oauth2 package doesn't use Accept headers. + if mediaType == "application/json" || mediaType == "" { + rw.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(rw).Encode(token) + return + } + + // If we get something we don't support, throw an error. + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "'Accept' header contains unsupported media type", + Detail: fmt.Sprintf("Found %q", mediaType), + }) })) validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) { @@ -886,6 +990,125 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { _ = json.NewEncoder(rw).Encode(set) })) + mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + f.logger.Info(r.Context(), "http call device verify") + + inputParam := "user_input" + userInput := r.URL.Query().Get(inputParam) + if userInput == "" { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid user input", + Detail: fmt.Sprintf("Hit this url again with ?%s=", inputParam), + }) + return + } + + deviceCode := r.URL.Query().Get("device_code") + if deviceCode == "" { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid device code", + Detail: "Hit this url again with ?device_code=", + }) + return + } + + flow, ok := f.deviceCode.Load(deviceCode) + if !ok { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid device code", + Detail: "Device code not found.", + }) + return + } + + if time.Now().After(flow.exp) { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid device code", + Detail: "Device code expired.", + }) + return + } + + if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid device code", + Detail: "user code does not match", + }) + return + } + + f.deviceCode.Store(deviceCode, deviceFlow{ + userInput: flow.userInput, + exp: flow.exp, + granted: true, + }) + httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ + Message: "Device authenticated!", + }) + })) + + mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + f.logger.Info(r.Context(), "http call device auth") + + p := httpapi.NewQueryParamParser() + p.Required("client_id") + clientID := p.String(r.URL.Query(), "", "client_id") + _ = p.String(r.URL.Query(), "", "scopes") + if len(p.Errors) > 0 { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid query params", + Validations: p.Errors, + }) + return + } + + if clientID != f.clientID { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid client id", + }) + return + } + + deviceCode := uuid.NewString() + lifetime := time.Second * 900 + flow := deviceFlow{ + //nolint:gosec + userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8), + } + f.deviceCode.Store(deviceCode, deviceFlow{ + userInput: flow.userInput, + exp: time.Now().Add(lifetime), + }) + + verifyURL := f.issuerURL.ResolveReference(&url.URL{ + Path: deviceVerify, + RawQuery: url.Values{ + "device_code": {deviceCode}, + "user_input": {flow.userInput}, + }.Encode(), + }).String() + + if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" { + httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{ + "device_code": deviceCode, + "user_code": flow.userInput, + "verification_uri": verifyURL, + "expires_in": int(lifetime.Seconds()), + "interval": 3, + }) + return + } + + // By default, GitHub form encodes these. + _, _ = fmt.Fprint(rw, url.Values{ + "device_code": {deviceCode}, + "user_code": {flow.userInput}, + "verification_uri": {verifyURL}, + "expires_in": {strconv.Itoa(int(lifetime.Seconds()))}, + "interval": {"3"}, + }.Encode()) + })) + mux.NotFound(func(rw http.ResponseWriter, r *http.Request) { f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path)) t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path) @@ -987,6 +1210,8 @@ type ExternalAuthConfigOptions struct { // completely customize the response. It captures all routes under the /external-auth-validate/* // so the caller can do whatever they want and even add routes. routes map[string]func(email string, rw http.ResponseWriter, r *http.Request) + + UseDeviceAuth bool } func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions { @@ -1033,9 +1258,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu } } instrumentF := promoauth.NewFactory(prometheus.NewRegistry()) + oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil)) cfg := &externalauth.Config{ DisplayName: id, - InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)), + InstrumentedOAuth2Config: oauthCfg, ID: id, // No defaults for these fields by omitting the type Type: "", @@ -1043,7 +1269,19 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu // Omit the /user for the validate so we can easily append to it when modifying // the cfg for advanced tests. ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), + DeviceAuth: &externalauth.DeviceAuth{ + Config: oauthCfg, + ClientID: f.clientID, + TokenURL: f.provider.TokenURL, + Scopes: []string{}, + CodeURL: f.provider.DeviceCodeURL, + }, } + + if !custom.UseDeviceAuth { + cfg.DeviceAuth = nil + } + for _, opt := range opts { opt(cfg) } diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 72d02b5139..0c936743a0 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -6,9 +6,11 @@ import ( "encoding/json" "fmt" "io" + "mime" "net/http" "net/url" "regexp" + "strconv" "strings" "time" @@ -321,13 +323,31 @@ func (c *DeviceAuth) AuthorizeDevice(ctx context.Context) (*codersdk.ExternalAut } err = json.NewDecoder(resp.Body).Decode(&r) if err != nil { - // Some status codes do not return json payloads, and we should - // return a better error. - switch resp.StatusCode { - case http.StatusTooManyRequests: - return nil, xerrors.New("rate limit hit, unable to authorize device. please try again later") + mediaType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + mediaType = "unknown" + } + + // If the json fails to decode, do a best effort to return a better error. + switch { + case resp.StatusCode == http.StatusTooManyRequests: + retryIn := "please try again later" + resetIn := resp.Header.Get("x-ratelimit-reset") + if resetIn != "" { + // Best effort to tell the user exactly how long they need + // to wait for. + unix, err := strconv.ParseInt(resetIn, 10, 64) + if err == nil { + waitFor := time.Unix(unix, 0).Sub(time.Now().Truncate(time.Second)) + retryIn = fmt.Sprintf(" retry after %s", waitFor.Truncate(time.Second)) + } + } + // 429 returns a plaintext payload with a message. + return nil, xerrors.New(fmt.Sprintf("rate limit hit, unable to authorize device. %s", retryIn)) + case mediaType == "application/x-www-form-urlencoded": + return nil, xerrors.Errorf("status_code=%d, payload response is form-url encoded, expected a json payload", resp.StatusCode) default: - return nil, xerrors.Errorf("status_code=%d: %w", resp.StatusCode, err) + return nil, fmt.Errorf("status_code=%d, mediaType=%s: %w", resp.StatusCode, mediaType, err) } } if r.ErrorDescription != "" { diff --git a/coderd/externalauth_test.go b/coderd/externalauth_test.go index 17adfac69d..db40ccf38a 100644 --- a/coderd/externalauth_test.go +++ b/coderd/externalauth_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "regexp" "strings" "testing" @@ -264,6 +265,27 @@ func TestExternalAuthManagement(t *testing.T) { func TestExternalAuthDevice(t *testing.T) { t.Parallel() + // This is an example test on how to do device auth flow using our fake idp. + t.Run("WithFakeIDP", func(t *testing.T) { + t.Parallel() + fake := oidctest.NewFakeIDP(t, oidctest.WithServing()) + externalID := "fake-idp" + cfg := fake.ExternalAuthConfig(t, externalID, &oidctest.ExternalAuthConfigOptions{ + UseDeviceAuth: true, + }) + + client := coderdtest.New(t, &coderdtest.Options{ + ExternalAuthConfigs: []*externalauth.Config{cfg}, + }) + coderdtest.CreateFirstUser(t, client) + // Login! + fake.DeviceLogin(t, client, externalID) + + extAuth, err := client.ExternalAuthByID(context.Background(), externalID) + require.NoError(t, err) + require.True(t, extAuth.Authenticated) + }) + t.Run("NotSupported", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{ @@ -363,6 +385,30 @@ func TestExternalAuthDevice(t *testing.T) { _, err := client.ExternalAuthDeviceByID(context.Background(), "test") require.ErrorContains(t, err, "rate limit hit") }) + + // If we forget to add the accept header, we get a form encoded body instead. + t.Run("FormEncodedBody", func(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, _ = w.Write([]byte(url.Values{"access_token": {"hey"}}.Encode())) + })) + defer srv.Close() + client := coderdtest.New(t, &coderdtest.Options{ + ExternalAuthConfigs: []*externalauth.Config{{ + ID: "test", + DeviceAuth: &externalauth.DeviceAuth{ + ClientID: "test", + CodeURL: srv.URL, + Scopes: []string{"repo"}, + }, + }}, + }) + coderdtest.CreateFirstUser(t, client) + _, err := client.ExternalAuthDeviceByID(context.Background(), "test") + require.Error(t, err) + require.ErrorContains(t, err, "is form-url encoded") + }) } // nolint:bodyclose diff --git a/scripts/testidp/main.go b/scripts/testidp/main.go index 49902eca17..82fc10c936 100644 --- a/scripts/testidp/main.go +++ b/scripts/testidp/main.go @@ -23,6 +23,7 @@ var ( expiry = flag.Duration("expiry", time.Minute*5, "Token expiry") clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random") clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random") + deviceFlow = flag.Bool("device-flow", false, "Enable device flow") // By default, no regex means it will never match anything. So at least default to matching something. extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex") ) @@ -66,14 +67,18 @@ func RunIDP() func(t *testing.T) { id, sec := idp.AppCredentials() prov := idp.WellknownConfig() const appID = "fake" - coderCfg := idp.ExternalAuthConfig(t, appID, nil) + coderCfg := idp.ExternalAuthConfig(t, appID, &oidctest.ExternalAuthConfigOptions{ + UseDeviceAuth: *deviceFlow, + }) log.Println("IDP Issuer URL", idp.IssuerURL()) log.Println("Coderd Flags") + deviceCodeURL := "" if coderCfg.DeviceAuth != nil { deviceCodeURL = coderCfg.DeviceAuth.CodeURL } + cfg := withClientSecret{ ClientSecret: sec, ExternalAuthConfig: codersdk.ExternalAuthConfig{ @@ -89,13 +94,14 @@ func RunIDP() func(t *testing.T) { NoRefresh: false, Scopes: []string{"openid", "email", "profile"}, ExtraTokenKeys: coderCfg.ExtraTokenKeys, - DeviceFlow: coderCfg.DeviceAuth != nil, + DeviceFlow: *deviceFlow, DeviceCodeURL: deviceCodeURL, Regex: *extRegex, DisplayName: coderCfg.DisplayName, DisplayIcon: coderCfg.DisplayIcon, }, } + data, err := json.Marshal([]withClientSecret{cfg}) require.NoError(t, err) log.Printf(`--external-auth-providers='%s'`, string(data))