mirror of https://github.com/coder/coder.git
chore: implement device auth flow for fake idp (#11707)
* chore: implement device auth flow for fake idp
This commit is contained in:
parent
16c6cefde8
commit
8e0a153725
|
@ -10,11 +10,14 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"mime"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -34,9 +37,11 @@ import (
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
"github.com/coder/coder/v2/coderd"
|
"github.com/coder/coder/v2/coderd"
|
||||||
"github.com/coder/coder/v2/coderd/externalauth"
|
"github.com/coder/coder/v2/coderd/externalauth"
|
||||||
|
"github.com/coder/coder/v2/coderd/httpapi"
|
||||||
"github.com/coder/coder/v2/coderd/promoauth"
|
"github.com/coder/coder/v2/coderd/promoauth"
|
||||||
"github.com/coder/coder/v2/coderd/util/syncmap"
|
"github.com/coder/coder/v2/coderd/util/syncmap"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type token struct {
|
type token struct {
|
||||||
|
@ -45,6 +50,13 @@ type token struct {
|
||||||
exp time.Time
|
exp time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type deviceFlow struct {
|
||||||
|
// userInput is the expected input to authenticate the device flow.
|
||||||
|
userInput string
|
||||||
|
exp time.Time
|
||||||
|
granted bool
|
||||||
|
}
|
||||||
|
|
||||||
// FakeIDP is a functional OIDC provider.
|
// FakeIDP is a functional OIDC provider.
|
||||||
// It only supports 1 OIDC client.
|
// It only supports 1 OIDC client.
|
||||||
type FakeIDP struct {
|
type FakeIDP struct {
|
||||||
|
@ -77,6 +89,8 @@ type FakeIDP struct {
|
||||||
refreshTokens *syncmap.Map[string, string]
|
refreshTokens *syncmap.Map[string, string]
|
||||||
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||||
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||||
|
// Device flow
|
||||||
|
deviceCode *syncmap.Map[string, deviceFlow]
|
||||||
|
|
||||||
// hooks
|
// hooks
|
||||||
// hookValidRedirectURL can be used to reject a redirect url from the
|
// hookValidRedirectURL can be used to reject a redirect url from the
|
||||||
|
@ -226,6 +240,8 @@ const (
|
||||||
authorizePath = "/oauth2/authorize"
|
authorizePath = "/oauth2/authorize"
|
||||||
keysPath = "/oauth2/keys"
|
keysPath = "/oauth2/keys"
|
||||||
userInfoPath = "/oauth2/userinfo"
|
userInfoPath = "/oauth2/userinfo"
|
||||||
|
deviceAuth = "/login/device/code"
|
||||||
|
deviceVerify = "/login/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||||
|
@ -246,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||||
refreshTokensUsed: syncmap.New[string, bool](),
|
refreshTokensUsed: syncmap.New[string, bool](),
|
||||||
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||||
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||||
|
deviceCode: syncmap.New[string, deviceFlow](),
|
||||||
hookOnRefresh: func(_ string) error { return nil },
|
hookOnRefresh: func(_ string) error { return nil },
|
||||||
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
|
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
|
||||||
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||||
|
@ -288,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||||
// ProviderJSON is the JSON representation of the OpenID Connect provider
|
// ProviderJSON is the JSON representation of the OpenID Connect provider
|
||||||
// These are all the urls that the IDP will respond to.
|
// These are all the urls that the IDP will respond to.
|
||||||
f.provider = ProviderJSON{
|
f.provider = ProviderJSON{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||||
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||||
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
|
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
|
||||||
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
|
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
|
||||||
|
DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(),
|
||||||
Algorithms: []string{
|
Algorithms: []string{
|
||||||
"RS256",
|
"RS256",
|
||||||
},
|
},
|
||||||
|
@ -467,6 +485,31 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
|
||||||
_ = res.Body.Close()
|
_ = res.Body.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeviceLogin does the oauth2 device flow for external auth providers.
|
||||||
|
func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) {
|
||||||
|
// First we need to initiate the device flow. This will have Coder hit the
|
||||||
|
// fake IDP and get a device code.
|
||||||
|
device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Now the user needs to go to the fake IDP page and click "allow" and enter
|
||||||
|
// the device code input. For our purposes, we just send an http request to
|
||||||
|
// the verification url. No additional user input is needed.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||||
|
defer cancel()
|
||||||
|
resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Now we need to exchange the device code for an access token. We do this
|
||||||
|
// in this method because it is the user that does the polling for the device
|
||||||
|
// auth flow, not the backend.
|
||||||
|
err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{
|
||||||
|
DeviceCode: device.DeviceCode,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
|
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
|
||||||
// unit tests, it's easier to skip this step sometimes. It does make an actual
|
// unit tests, it's easier to skip this step sometimes. It does make an actual
|
||||||
// request to the IDP, so it should be equivalent to doing this "manually" with
|
// request to the IDP, so it should be equivalent to doing this "manually" with
|
||||||
|
@ -536,12 +579,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
|
||||||
|
|
||||||
// ProviderJSON is the .well-known/configuration JSON
|
// ProviderJSON is the .well-known/configuration JSON
|
||||||
type ProviderJSON struct {
|
type ProviderJSON struct {
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
AuthURL string `json:"authorization_endpoint"`
|
AuthURL string `json:"authorization_endpoint"`
|
||||||
TokenURL string `json:"token_endpoint"`
|
TokenURL string `json:"token_endpoint"`
|
||||||
JWKSURL string `json:"jwks_uri"`
|
JWKSURL string `json:"jwks_uri"`
|
||||||
UserInfoURL string `json:"userinfo_endpoint"`
|
UserInfoURL string `json:"userinfo_endpoint"`
|
||||||
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
DeviceCodeURL string `json:"device_authorization_endpoint"`
|
||||||
|
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
||||||
// This is custom
|
// This is custom
|
||||||
ExternalAuthURL string `json:"external_auth_url"`
|
ExternalAuthURL string `json:"external_auth_url"`
|
||||||
}
|
}
|
||||||
|
@ -709,8 +753,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
values, err := f.authenticateOIDCClientRequest(t, r)
|
var values url.Values
|
||||||
|
var err error
|
||||||
|
if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
|
||||||
|
values = r.URL.Query()
|
||||||
|
} else {
|
||||||
|
values, err = f.authenticateOIDCClientRequest(t, r)
|
||||||
|
}
|
||||||
f.logger.Info(r.Context(), "http idp call token",
|
f.logger.Info(r.Context(), "http idp call token",
|
||||||
|
slog.F("url", r.URL.String()),
|
||||||
slog.F("valid", err == nil),
|
slog.F("valid", err == nil),
|
||||||
slog.F("grant_type", values.Get("grant_type")),
|
slog.F("grant_type", values.Get("grant_type")),
|
||||||
slog.F("values", values.Encode()),
|
slog.F("values", values.Encode()),
|
||||||
|
@ -784,6 +835,37 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
f.refreshTokensUsed.Store(refreshToken, true)
|
f.refreshTokensUsed.Store(refreshToken, true)
|
||||||
// Always invalidate the refresh token after it is used.
|
// Always invalidate the refresh token after it is used.
|
||||||
f.refreshTokens.Delete(refreshToken)
|
f.refreshTokens.Delete(refreshToken)
|
||||||
|
case "urn:ietf:params:oauth:grant-type:device_code":
|
||||||
|
// Device flow
|
||||||
|
var resp externalauth.ExchangeDeviceCodeResponse
|
||||||
|
deviceCode := values.Get("device_code")
|
||||||
|
if deviceCode == "" {
|
||||||
|
resp.Error = "invalid_request"
|
||||||
|
resp.ErrorDescription = "missing device_code"
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceFlow, ok := f.deviceCode.Load(deviceCode)
|
||||||
|
if !ok {
|
||||||
|
resp.Error = "invalid_request"
|
||||||
|
resp.ErrorDescription = "device_code provided not found"
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !deviceFlow.granted {
|
||||||
|
// Status code ok with the error as pending.
|
||||||
|
resp.Error = "authorization_pending"
|
||||||
|
resp.ErrorDescription = ""
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusOK, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Would be nice to get an actual email here.
|
||||||
|
claims = jwt.MapClaims{
|
||||||
|
"email": "unknown-dev-auth",
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
|
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
|
||||||
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
|
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
|
||||||
|
@ -807,8 +889,30 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
// Store the claims for the next refresh
|
// Store the claims for the next refresh
|
||||||
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||||
|
|
||||||
rw.Header().Set("Content-Type", "application/json")
|
mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept"))
|
||||||
_ = json.NewEncoder(rw).Encode(token)
|
if mediaType == "application/x-www-form-urlencoded" {
|
||||||
|
// This val encode might not work for some data structures.
|
||||||
|
// It's good enough for now...
|
||||||
|
rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
vals := url.Values{}
|
||||||
|
for k, v := range token {
|
||||||
|
vals.Set(k, fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
_, _ = rw.Write([]byte(vals.Encode()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Default to json since the oauth2 package doesn't use Accept headers.
|
||||||
|
if mediaType == "application/json" || mediaType == "" {
|
||||||
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(rw).Encode(token)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we get something we don't support, throw an error.
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "'Accept' header contains unsupported media type",
|
||||||
|
Detail: fmt.Sprintf("Found %q", mediaType),
|
||||||
|
})
|
||||||
}))
|
}))
|
||||||
|
|
||||||
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
|
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
|
||||||
|
@ -886,6 +990,125 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
_ = json.NewEncoder(rw).Encode(set)
|
_ = json.NewEncoder(rw).Encode(set)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
f.logger.Info(r.Context(), "http call device verify")
|
||||||
|
|
||||||
|
inputParam := "user_input"
|
||||||
|
userInput := r.URL.Query().Get(inputParam)
|
||||||
|
if userInput == "" {
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Invalid user input",
|
||||||
|
Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := r.URL.Query().Get("device_code")
|
||||||
|
if deviceCode == "" {
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Invalid device code",
|
||||||
|
Detail: "Hit this url again with ?device_code=<device_code>",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, ok := f.deviceCode.Load(deviceCode)
|
||||||
|
if !ok {
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Invalid device code",
|
||||||
|
Detail: "Device code not found.",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(flow.exp) {
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Invalid device code",
|
||||||
|
Detail: "Device code expired.",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Invalid device code",
|
||||||
|
Detail: "user code does not match",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.deviceCode.Store(deviceCode, deviceFlow{
|
||||||
|
userInput: flow.userInput,
|
||||||
|
exp: flow.exp,
|
||||||
|
granted: true,
|
||||||
|
})
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
|
||||||
|
Message: "Device authenticated!",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
f.logger.Info(r.Context(), "http call device auth")
|
||||||
|
|
||||||
|
p := httpapi.NewQueryParamParser()
|
||||||
|
p.Required("client_id")
|
||||||
|
clientID := p.String(r.URL.Query(), "", "client_id")
|
||||||
|
_ = p.String(r.URL.Query(), "", "scopes")
|
||||||
|
if len(p.Errors) > 0 {
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Invalid query params",
|
||||||
|
Validations: p.Errors,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientID != f.clientID {
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Invalid client id",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := uuid.NewString()
|
||||||
|
lifetime := time.Second * 900
|
||||||
|
flow := deviceFlow{
|
||||||
|
//nolint:gosec
|
||||||
|
userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
|
||||||
|
}
|
||||||
|
f.deviceCode.Store(deviceCode, deviceFlow{
|
||||||
|
userInput: flow.userInput,
|
||||||
|
exp: time.Now().Add(lifetime),
|
||||||
|
})
|
||||||
|
|
||||||
|
verifyURL := f.issuerURL.ResolveReference(&url.URL{
|
||||||
|
Path: deviceVerify,
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"device_code": {deviceCode},
|
||||||
|
"user_input": {flow.userInput},
|
||||||
|
}.Encode(),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
|
||||||
|
httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
|
||||||
|
"device_code": deviceCode,
|
||||||
|
"user_code": flow.userInput,
|
||||||
|
"verification_uri": verifyURL,
|
||||||
|
"expires_in": int(lifetime.Seconds()),
|
||||||
|
"interval": 3,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// By default, GitHub form encodes these.
|
||||||
|
_, _ = fmt.Fprint(rw, url.Values{
|
||||||
|
"device_code": {deviceCode},
|
||||||
|
"user_code": {flow.userInput},
|
||||||
|
"verification_uri": {verifyURL},
|
||||||
|
"expires_in": {strconv.Itoa(int(lifetime.Seconds()))},
|
||||||
|
"interval": {"3"},
|
||||||
|
}.Encode())
|
||||||
|
}))
|
||||||
|
|
||||||
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
|
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
|
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
|
||||||
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
|
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
|
||||||
|
@ -987,6 +1210,8 @@ type ExternalAuthConfigOptions struct {
|
||||||
// completely customize the response. It captures all routes under the /external-auth-validate/*
|
// completely customize the response. It captures all routes under the /external-auth-validate/*
|
||||||
// so the caller can do whatever they want and even add routes.
|
// so the caller can do whatever they want and even add routes.
|
||||||
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)
|
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)
|
||||||
|
|
||||||
|
UseDeviceAuth bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
|
func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
|
||||||
|
@ -1033,9 +1258,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
|
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
|
||||||
|
oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil))
|
||||||
cfg := &externalauth.Config{
|
cfg := &externalauth.Config{
|
||||||
DisplayName: id,
|
DisplayName: id,
|
||||||
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
|
InstrumentedOAuth2Config: oauthCfg,
|
||||||
ID: id,
|
ID: id,
|
||||||
// No defaults for these fields by omitting the type
|
// No defaults for these fields by omitting the type
|
||||||
Type: "",
|
Type: "",
|
||||||
|
@ -1043,7 +1269,19 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||||
// Omit the /user for the validate so we can easily append to it when modifying
|
// Omit the /user for the validate so we can easily append to it when modifying
|
||||||
// the cfg for advanced tests.
|
// the cfg for advanced tests.
|
||||||
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
|
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
|
||||||
|
DeviceAuth: &externalauth.DeviceAuth{
|
||||||
|
Config: oauthCfg,
|
||||||
|
ClientID: f.clientID,
|
||||||
|
TokenURL: f.provider.TokenURL,
|
||||||
|
Scopes: []string{},
|
||||||
|
CodeURL: f.provider.DeviceCodeURL,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !custom.UseDeviceAuth {
|
||||||
|
cfg.DeviceAuth = nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(cfg)
|
opt(cfg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,9 +6,11 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -321,13 +323,31 @@ func (c *DeviceAuth) AuthorizeDevice(ctx context.Context) (*codersdk.ExternalAut
|
||||||
}
|
}
|
||||||
err = json.NewDecoder(resp.Body).Decode(&r)
|
err = json.NewDecoder(resp.Body).Decode(&r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Some status codes do not return json payloads, and we should
|
mediaType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||||
// return a better error.
|
if err != nil {
|
||||||
switch resp.StatusCode {
|
mediaType = "unknown"
|
||||||
case http.StatusTooManyRequests:
|
}
|
||||||
return nil, xerrors.New("rate limit hit, unable to authorize device. please try again later")
|
|
||||||
|
// If the json fails to decode, do a best effort to return a better error.
|
||||||
|
switch {
|
||||||
|
case resp.StatusCode == http.StatusTooManyRequests:
|
||||||
|
retryIn := "please try again later"
|
||||||
|
resetIn := resp.Header.Get("x-ratelimit-reset")
|
||||||
|
if resetIn != "" {
|
||||||
|
// Best effort to tell the user exactly how long they need
|
||||||
|
// to wait for.
|
||||||
|
unix, err := strconv.ParseInt(resetIn, 10, 64)
|
||||||
|
if err == nil {
|
||||||
|
waitFor := time.Unix(unix, 0).Sub(time.Now().Truncate(time.Second))
|
||||||
|
retryIn = fmt.Sprintf(" retry after %s", waitFor.Truncate(time.Second))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 429 returns a plaintext payload with a message.
|
||||||
|
return nil, xerrors.New(fmt.Sprintf("rate limit hit, unable to authorize device. %s", retryIn))
|
||||||
|
case mediaType == "application/x-www-form-urlencoded":
|
||||||
|
return nil, xerrors.Errorf("status_code=%d, payload response is form-url encoded, expected a json payload", resp.StatusCode)
|
||||||
default:
|
default:
|
||||||
return nil, xerrors.Errorf("status_code=%d: %w", resp.StatusCode, err)
|
return nil, fmt.Errorf("status_code=%d, mediaType=%s: %w", resp.StatusCode, mediaType, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if r.ErrorDescription != "" {
|
if r.ErrorDescription != "" {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -264,6 +265,27 @@ func TestExternalAuthManagement(t *testing.T) {
|
||||||
|
|
||||||
func TestExternalAuthDevice(t *testing.T) {
|
func TestExternalAuthDevice(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
// This is an example test on how to do device auth flow using our fake idp.
|
||||||
|
t.Run("WithFakeIDP", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
fake := oidctest.NewFakeIDP(t, oidctest.WithServing())
|
||||||
|
externalID := "fake-idp"
|
||||||
|
cfg := fake.ExternalAuthConfig(t, externalID, &oidctest.ExternalAuthConfigOptions{
|
||||||
|
UseDeviceAuth: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
|
ExternalAuthConfigs: []*externalauth.Config{cfg},
|
||||||
|
})
|
||||||
|
coderdtest.CreateFirstUser(t, client)
|
||||||
|
// Login!
|
||||||
|
fake.DeviceLogin(t, client, externalID)
|
||||||
|
|
||||||
|
extAuth, err := client.ExternalAuthByID(context.Background(), externalID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, extAuth.Authenticated)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("NotSupported", func(t *testing.T) {
|
t.Run("NotSupported", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
|
@ -363,6 +385,30 @@ func TestExternalAuthDevice(t *testing.T) {
|
||||||
_, err := client.ExternalAuthDeviceByID(context.Background(), "test")
|
_, err := client.ExternalAuthDeviceByID(context.Background(), "test")
|
||||||
require.ErrorContains(t, err, "rate limit hit")
|
require.ErrorContains(t, err, "rate limit hit")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// If we forget to add the accept header, we get a form encoded body instead.
|
||||||
|
t.Run("FormEncodedBody", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
_, _ = w.Write([]byte(url.Values{"access_token": {"hey"}}.Encode()))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
|
ExternalAuthConfigs: []*externalauth.Config{{
|
||||||
|
ID: "test",
|
||||||
|
DeviceAuth: &externalauth.DeviceAuth{
|
||||||
|
ClientID: "test",
|
||||||
|
CodeURL: srv.URL,
|
||||||
|
Scopes: []string{"repo"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
coderdtest.CreateFirstUser(t, client)
|
||||||
|
_, err := client.ExternalAuthDeviceByID(context.Background(), "test")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "is form-url encoded")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:bodyclose
|
// nolint:bodyclose
|
||||||
|
|
|
@ -23,6 +23,7 @@ var (
|
||||||
expiry = flag.Duration("expiry", time.Minute*5, "Token expiry")
|
expiry = flag.Duration("expiry", time.Minute*5, "Token expiry")
|
||||||
clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random")
|
clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random")
|
||||||
clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random")
|
clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random")
|
||||||
|
deviceFlow = flag.Bool("device-flow", false, "Enable device flow")
|
||||||
// By default, no regex means it will never match anything. So at least default to matching something.
|
// By default, no regex means it will never match anything. So at least default to matching something.
|
||||||
extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex")
|
extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex")
|
||||||
)
|
)
|
||||||
|
@ -66,14 +67,18 @@ func RunIDP() func(t *testing.T) {
|
||||||
id, sec := idp.AppCredentials()
|
id, sec := idp.AppCredentials()
|
||||||
prov := idp.WellknownConfig()
|
prov := idp.WellknownConfig()
|
||||||
const appID = "fake"
|
const appID = "fake"
|
||||||
coderCfg := idp.ExternalAuthConfig(t, appID, nil)
|
coderCfg := idp.ExternalAuthConfig(t, appID, &oidctest.ExternalAuthConfigOptions{
|
||||||
|
UseDeviceAuth: *deviceFlow,
|
||||||
|
})
|
||||||
|
|
||||||
log.Println("IDP Issuer URL", idp.IssuerURL())
|
log.Println("IDP Issuer URL", idp.IssuerURL())
|
||||||
log.Println("Coderd Flags")
|
log.Println("Coderd Flags")
|
||||||
|
|
||||||
deviceCodeURL := ""
|
deviceCodeURL := ""
|
||||||
if coderCfg.DeviceAuth != nil {
|
if coderCfg.DeviceAuth != nil {
|
||||||
deviceCodeURL = coderCfg.DeviceAuth.CodeURL
|
deviceCodeURL = coderCfg.DeviceAuth.CodeURL
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := withClientSecret{
|
cfg := withClientSecret{
|
||||||
ClientSecret: sec,
|
ClientSecret: sec,
|
||||||
ExternalAuthConfig: codersdk.ExternalAuthConfig{
|
ExternalAuthConfig: codersdk.ExternalAuthConfig{
|
||||||
|
@ -89,13 +94,14 @@ func RunIDP() func(t *testing.T) {
|
||||||
NoRefresh: false,
|
NoRefresh: false,
|
||||||
Scopes: []string{"openid", "email", "profile"},
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
ExtraTokenKeys: coderCfg.ExtraTokenKeys,
|
ExtraTokenKeys: coderCfg.ExtraTokenKeys,
|
||||||
DeviceFlow: coderCfg.DeviceAuth != nil,
|
DeviceFlow: *deviceFlow,
|
||||||
DeviceCodeURL: deviceCodeURL,
|
DeviceCodeURL: deviceCodeURL,
|
||||||
Regex: *extRegex,
|
Regex: *extRegex,
|
||||||
DisplayName: coderCfg.DisplayName,
|
DisplayName: coderCfg.DisplayName,
|
||||||
DisplayIcon: coderCfg.DisplayIcon,
|
DisplayIcon: coderCfg.DisplayIcon,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal([]withClientSecret{cfg})
|
data, err := json.Marshal([]withClientSecret{cfg})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
log.Printf(`--external-auth-providers='%s'`, string(data))
|
log.Printf(`--external-auth-providers='%s'`, string(data))
|
||||||
|
|
Loading…
Reference in New Issue