chore: implement device auth flow for fake idp (#11707)

* chore: implement device auth flow for fake idp
This commit is contained in:
Steven Masley 2024-01-22 14:46:05 -06:00 committed by GitHub
parent 16c6cefde8
commit 8e0a153725
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 333 additions and 23 deletions

View File

@ -10,11 +10,14 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"mime"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"
"time"
@ -34,9 +37,11 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/coderd/util/syncmap"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
type token struct {
@ -45,6 +50,13 @@ type token struct {
exp time.Time
}
type deviceFlow struct {
// userInput is the expected input to authenticate the device flow.
userInput string
exp time.Time
granted bool
}
// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
@ -77,6 +89,8 @@ type FakeIDP struct {
refreshTokens *syncmap.Map[string, string]
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
// Device flow
deviceCode *syncmap.Map[string, deviceFlow]
// hooks
// hookValidRedirectURL can be used to reject a redirect url from the
@ -226,6 +240,8 @@ const (
authorizePath = "/oauth2/authorize"
keysPath = "/oauth2/keys"
userInfoPath = "/oauth2/userinfo"
deviceAuth = "/login/device/code"
deviceVerify = "/login/device"
)
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
@ -246,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
refreshTokensUsed: syncmap.New[string, bool](),
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
deviceCode: syncmap.New[string, deviceFlow](),
hookOnRefresh: func(_ string) error { return nil },
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
hookValidRedirectURL: func(redirectURL string) error { return nil },
@ -288,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
// ProviderJSON is the JSON representation of the OpenID Connect provider
// These are all the urls that the IDP will respond to.
f.provider = ProviderJSON{
Issuer: issuer,
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
Issuer: issuer,
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(),
Algorithms: []string{
"RS256",
},
@ -467,6 +485,31 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
_ = res.Body.Close()
}
// DeviceLogin does the oauth2 device flow for external auth providers.
func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) {
// First we need to initiate the device flow. This will have Coder hit the
// fake IDP and get a device code.
device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID)
require.NoError(t, err)
// Now the user needs to go to the fake IDP page and click "allow" and enter
// the device code input. For our purposes, we just send an http request to
// the verification url. No additional user input is needed.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil)
require.NoError(t, err)
defer resp.Body.Close()
// Now we need to exchange the device code for an access token. We do this
// in this method because it is the user that does the polling for the device
// auth flow, not the backend.
err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{
DeviceCode: device.DeviceCode,
})
require.NoError(t, err)
}
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
// unit tests, it's easier to skip this step sometimes. It does make an actual
// request to the IDP, so it should be equivalent to doing this "manually" with
@ -536,12 +579,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
// ProviderJSON is the .well-known/configuration JSON
type ProviderJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
DeviceCodeURL string `json:"device_authorization_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
// This is custom
ExternalAuthURL string `json:"external_auth_url"`
}
@ -709,8 +753,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
}))
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
values, err := f.authenticateOIDCClientRequest(t, r)
var values url.Values
var err error
if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
values = r.URL.Query()
} else {
values, err = f.authenticateOIDCClientRequest(t, r)
}
f.logger.Info(r.Context(), "http idp call token",
slog.F("url", r.URL.String()),
slog.F("valid", err == nil),
slog.F("grant_type", values.Get("grant_type")),
slog.F("values", values.Encode()),
@ -784,6 +835,37 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
f.refreshTokensUsed.Store(refreshToken, true)
// Always invalidate the refresh token after it is used.
f.refreshTokens.Delete(refreshToken)
case "urn:ietf:params:oauth:grant-type:device_code":
// Device flow
var resp externalauth.ExchangeDeviceCodeResponse
deviceCode := values.Get("device_code")
if deviceCode == "" {
resp.Error = "invalid_request"
resp.ErrorDescription = "missing device_code"
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
return
}
deviceFlow, ok := f.deviceCode.Load(deviceCode)
if !ok {
resp.Error = "invalid_request"
resp.ErrorDescription = "device_code provided not found"
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
return
}
if !deviceFlow.granted {
// Status code ok with the error as pending.
resp.Error = "authorization_pending"
resp.ErrorDescription = ""
httpapi.Write(r.Context(), rw, http.StatusOK, resp)
return
}
// Would be nice to get an actual email here.
claims = jwt.MapClaims{
"email": "unknown-dev-auth",
}
default:
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
@ -807,8 +889,30 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
// Store the claims for the next refresh
f.refreshIDTokenClaims.Store(refreshToken, claims)
rw.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(rw).Encode(token)
mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept"))
if mediaType == "application/x-www-form-urlencoded" {
// This val encode might not work for some data structures.
// It's good enough for now...
rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
vals := url.Values{}
for k, v := range token {
vals.Set(k, fmt.Sprintf("%v", v))
}
_, _ = rw.Write([]byte(vals.Encode()))
return
}
// Default to json since the oauth2 package doesn't use Accept headers.
if mediaType == "application/json" || mediaType == "" {
rw.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(rw).Encode(token)
return
}
// If we get something we don't support, throw an error.
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "'Accept' header contains unsupported media type",
Detail: fmt.Sprintf("Found %q", mediaType),
})
}))
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
@ -886,6 +990,125 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
_ = json.NewEncoder(rw).Encode(set)
}))
mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call device verify")
inputParam := "user_input"
userInput := r.URL.Query().Get(inputParam)
if userInput == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid user input",
Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam),
})
return
}
deviceCode := r.URL.Query().Get("device_code")
if deviceCode == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Hit this url again with ?device_code=<device_code>",
})
return
}
flow, ok := f.deviceCode.Load(deviceCode)
if !ok {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Device code not found.",
})
return
}
if time.Now().After(flow.exp) {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "Device code expired.",
})
return
}
if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid device code",
Detail: "user code does not match",
})
return
}
f.deviceCode.Store(deviceCode, deviceFlow{
userInput: flow.userInput,
exp: flow.exp,
granted: true,
})
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
Message: "Device authenticated!",
})
}))
mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call device auth")
p := httpapi.NewQueryParamParser()
p.Required("client_id")
clientID := p.String(r.URL.Query(), "", "client_id")
_ = p.String(r.URL.Query(), "", "scopes")
if len(p.Errors) > 0 {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query params",
Validations: p.Errors,
})
return
}
if clientID != f.clientID {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid client id",
})
return
}
deviceCode := uuid.NewString()
lifetime := time.Second * 900
flow := deviceFlow{
//nolint:gosec
userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
}
f.deviceCode.Store(deviceCode, deviceFlow{
userInput: flow.userInput,
exp: time.Now().Add(lifetime),
})
verifyURL := f.issuerURL.ResolveReference(&url.URL{
Path: deviceVerify,
RawQuery: url.Values{
"device_code": {deviceCode},
"user_input": {flow.userInput},
}.Encode(),
}).String()
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
"device_code": deviceCode,
"user_code": flow.userInput,
"verification_uri": verifyURL,
"expires_in": int(lifetime.Seconds()),
"interval": 3,
})
return
}
// By default, GitHub form encodes these.
_, _ = fmt.Fprint(rw, url.Values{
"device_code": {deviceCode},
"user_code": {flow.userInput},
"verification_uri": {verifyURL},
"expires_in": {strconv.Itoa(int(lifetime.Seconds()))},
"interval": {"3"},
}.Encode())
}))
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
@ -987,6 +1210,8 @@ type ExternalAuthConfigOptions struct {
// completely customize the response. It captures all routes under the /external-auth-validate/*
// so the caller can do whatever they want and even add routes.
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)
UseDeviceAuth bool
}
func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
@ -1033,9 +1258,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
}
}
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil))
cfg := &externalauth.Config{
DisplayName: id,
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
InstrumentedOAuth2Config: oauthCfg,
ID: id,
// No defaults for these fields by omitting the type
Type: "",
@ -1043,7 +1269,19 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
// Omit the /user for the validate so we can easily append to it when modifying
// the cfg for advanced tests.
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
DeviceAuth: &externalauth.DeviceAuth{
Config: oauthCfg,
ClientID: f.clientID,
TokenURL: f.provider.TokenURL,
Scopes: []string{},
CodeURL: f.provider.DeviceCodeURL,
},
}
if !custom.UseDeviceAuth {
cfg.DeviceAuth = nil
}
for _, opt := range opts {
opt(cfg)
}

View File

@ -6,9 +6,11 @@ import (
"encoding/json"
"fmt"
"io"
"mime"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
@ -321,13 +323,31 @@ func (c *DeviceAuth) AuthorizeDevice(ctx context.Context) (*codersdk.ExternalAut
}
err = json.NewDecoder(resp.Body).Decode(&r)
if err != nil {
// Some status codes do not return json payloads, and we should
// return a better error.
switch resp.StatusCode {
case http.StatusTooManyRequests:
return nil, xerrors.New("rate limit hit, unable to authorize device. please try again later")
mediaType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
mediaType = "unknown"
}
// If the json fails to decode, do a best effort to return a better error.
switch {
case resp.StatusCode == http.StatusTooManyRequests:
retryIn := "please try again later"
resetIn := resp.Header.Get("x-ratelimit-reset")
if resetIn != "" {
// Best effort to tell the user exactly how long they need
// to wait for.
unix, err := strconv.ParseInt(resetIn, 10, 64)
if err == nil {
waitFor := time.Unix(unix, 0).Sub(time.Now().Truncate(time.Second))
retryIn = fmt.Sprintf(" retry after %s", waitFor.Truncate(time.Second))
}
}
// 429 returns a plaintext payload with a message.
return nil, xerrors.New(fmt.Sprintf("rate limit hit, unable to authorize device. %s", retryIn))
case mediaType == "application/x-www-form-urlencoded":
return nil, xerrors.Errorf("status_code=%d, payload response is form-url encoded, expected a json payload", resp.StatusCode)
default:
return nil, xerrors.Errorf("status_code=%d: %w", resp.StatusCode, err)
return nil, fmt.Errorf("status_code=%d, mediaType=%s: %w", resp.StatusCode, mediaType, err)
}
}
if r.ErrorDescription != "" {

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
@ -264,6 +265,27 @@ func TestExternalAuthManagement(t *testing.T) {
func TestExternalAuthDevice(t *testing.T) {
t.Parallel()
// This is an example test on how to do device auth flow using our fake idp.
t.Run("WithFakeIDP", func(t *testing.T) {
t.Parallel()
fake := oidctest.NewFakeIDP(t, oidctest.WithServing())
externalID := "fake-idp"
cfg := fake.ExternalAuthConfig(t, externalID, &oidctest.ExternalAuthConfigOptions{
UseDeviceAuth: true,
})
client := coderdtest.New(t, &coderdtest.Options{
ExternalAuthConfigs: []*externalauth.Config{cfg},
})
coderdtest.CreateFirstUser(t, client)
// Login!
fake.DeviceLogin(t, client, externalID)
extAuth, err := client.ExternalAuthByID(context.Background(), externalID)
require.NoError(t, err)
require.True(t, extAuth.Authenticated)
})
t.Run("NotSupported", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
@ -363,6 +385,30 @@ func TestExternalAuthDevice(t *testing.T) {
_, err := client.ExternalAuthDeviceByID(context.Background(), "test")
require.ErrorContains(t, err, "rate limit hit")
})
// If we forget to add the accept header, we get a form encoded body instead.
t.Run("FormEncodedBody", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
_, _ = w.Write([]byte(url.Values{"access_token": {"hey"}}.Encode()))
}))
defer srv.Close()
client := coderdtest.New(t, &coderdtest.Options{
ExternalAuthConfigs: []*externalauth.Config{{
ID: "test",
DeviceAuth: &externalauth.DeviceAuth{
ClientID: "test",
CodeURL: srv.URL,
Scopes: []string{"repo"},
},
}},
})
coderdtest.CreateFirstUser(t, client)
_, err := client.ExternalAuthDeviceByID(context.Background(), "test")
require.Error(t, err)
require.ErrorContains(t, err, "is form-url encoded")
})
}
// nolint:bodyclose

View File

@ -23,6 +23,7 @@ var (
expiry = flag.Duration("expiry", time.Minute*5, "Token expiry")
clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random")
clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random")
deviceFlow = flag.Bool("device-flow", false, "Enable device flow")
// By default, no regex means it will never match anything. So at least default to matching something.
extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex")
)
@ -66,14 +67,18 @@ func RunIDP() func(t *testing.T) {
id, sec := idp.AppCredentials()
prov := idp.WellknownConfig()
const appID = "fake"
coderCfg := idp.ExternalAuthConfig(t, appID, nil)
coderCfg := idp.ExternalAuthConfig(t, appID, &oidctest.ExternalAuthConfigOptions{
UseDeviceAuth: *deviceFlow,
})
log.Println("IDP Issuer URL", idp.IssuerURL())
log.Println("Coderd Flags")
deviceCodeURL := ""
if coderCfg.DeviceAuth != nil {
deviceCodeURL = coderCfg.DeviceAuth.CodeURL
}
cfg := withClientSecret{
ClientSecret: sec,
ExternalAuthConfig: codersdk.ExternalAuthConfig{
@ -89,13 +94,14 @@ func RunIDP() func(t *testing.T) {
NoRefresh: false,
Scopes: []string{"openid", "email", "profile"},
ExtraTokenKeys: coderCfg.ExtraTokenKeys,
DeviceFlow: coderCfg.DeviceAuth != nil,
DeviceFlow: *deviceFlow,
DeviceCodeURL: deviceCodeURL,
Regex: *extRegex,
DisplayName: coderCfg.DisplayName,
DisplayIcon: coderCfg.DisplayIcon,
},
}
data, err := json.Marshal([]withClientSecret{cfg})
require.NoError(t, err)
log.Printf(`--external-auth-providers='%s'`, string(data))