chore: improve fake IDP script (#11602)

* chore: testIDP using static defaults for easier reuse
This commit is contained in:
Steven Masley 2024-01-15 10:01:41 -06:00 committed by GitHub
parent f915bdf26c
commit 5087f7b5f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 204 additions and 80 deletions

View File

@ -1,58 +0,0 @@
package main
import (
"flag"
"log"
"os"
"os/signal"
"testing"
"github.com/golang-jwt/jwt/v4"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
)
func main() {
testing.Init()
_ = flag.Set("test.timeout", "0")
flag.Parse()
// This is just a way to run tests outside go test
testing.Main(func(pat, str string) (bool, error) {
return true, nil
}, []testing.InternalTest{
{
Name: "Run Fake IDP",
F: RunIDP(),
},
}, nil, nil)
}
// RunIDP needs the testing.T because our oidctest package requires the
// testing.T.
func RunIDP() func(t *testing.T) {
return func(t *testing.T) {
idp := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
oidctest.WithStaticUserInfo(jwt.MapClaims{}),
oidctest.WithDefaultIDClaims(jwt.MapClaims{}),
)
id, sec := idp.AppCredentials()
prov := idp.WellknownConfig()
log.Println("IDP Issuer URL", idp.IssuerURL())
log.Println("Coderd Flags")
log.Printf(`--external-auth-providers='[{"type":"fake","client_id":"%s","client_secret":"%s","auth_url":"%s","token_url":"%s","validate_url":"%s","scopes":["openid","email","profile"]}]'`,
id, sec, prov.AuthURL, prov.TokenURL, prov.UserInfoURL,
)
log.Println("Press Ctrl+C to exit")
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
// Block until ctl+c
<-c
log.Println("Closing")
}
}

View File

@ -39,6 +39,12 @@ import (
"github.com/coder/coder/v2/codersdk"
)
type token struct {
issued time.Time
email string
exp time.Time
}
// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
@ -65,7 +71,7 @@ type FakeIDP struct {
// That is the various access tokens, refresh tokens, states, etc.
codeToStateMap *syncmap.Map[string, string]
// Token -> Email
accessTokens *syncmap.Map[string, string]
accessTokens *syncmap.Map[string, token]
// Refresh Token -> Email
refreshTokensUsed *syncmap.Map[string, bool]
refreshTokens *syncmap.Map[string, string]
@ -89,7 +95,8 @@ type FakeIDP struct {
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
serve bool
// optional middlewares
middlewares chi.Middlewares
middlewares chi.Middlewares
defaultExpire time.Duration
}
func StatusError(code int, err error) error {
@ -134,6 +141,23 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) {
}
}
func WithDefaultExpire(d time.Duration) func(*FakeIDP) {
return func(f *FakeIDP) {
f.defaultExpire = d
}
}
func WithStaticCredentials(id, secret string) func(*FakeIDP) {
return func(f *FakeIDP) {
if id != "" {
f.clientID = id
}
if secret != "" {
f.clientSecret = secret
}
}
}
// WithExtra returns extra fields that be accessed on the returned Oauth Token.
// These extra fields can override the default fields (id_token, access_token, etc).
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
@ -155,6 +179,12 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
}
}
func WithLogger(logger slog.Logger) func(*FakeIDP) {
return func(f *FakeIDP) {
f.logger = logger
}
}
// WithStaticUserInfo is optional, but will return the same user info for
// every user on the /userinfo endpoint.
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
@ -211,7 +241,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
clientSecret: uuid.NewString(),
logger: slog.Make(),
codeToStateMap: syncmap.New[string, string](),
accessTokens: syncmap.New[string, string](),
accessTokens: syncmap.New[string, token](),
refreshTokens: syncmap.New[string, string](),
refreshTokensUsed: syncmap.New[string, bool](),
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
@ -219,6 +249,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
hookOnRefresh: func(_ string) error { return nil },
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
hookValidRedirectURL: func(redirectURL string) error { return nil },
defaultExpire: time.Minute * 5,
}
for _, opt := range opts {
@ -265,6 +296,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
Algorithms: []string{
"RS256",
},
ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(),
}
}
@ -272,8 +304,23 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
t.Helper()
srvURL := "localhost:0"
issURL, err := url.Parse(f.issuer)
if err == nil {
if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" {
srvURL = issURL.Host
}
}
l, err := net.Listen("tcp", srvURL)
require.NoError(t, err, "failed to create listener")
ctx, cancel := context.WithCancel(context.Background())
srv := httptest.NewUnstartedServer(f.handler)
srv := &httptest.Server{
Listener: l,
Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5},
}
srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
@ -495,6 +542,8 @@ type ProviderJSON struct {
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
// This is custom
ExternalAuthURL string `json:"external_auth_url"`
}
// newCode enforces the code exchanged is actually a valid code
@ -507,9 +556,13 @@ func (f *FakeIDP) newCode(state string) string {
// newToken enforces the access token exchanged is actually a valid access token
// created by the IDP.
func (f *FakeIDP) newToken(email string) string {
func (f *FakeIDP) newToken(email string, expires time.Time) string {
accessToken := uuid.NewString()
f.accessTokens.Store(accessToken, email)
f.accessTokens.Store(accessToken, token{
issued: time.Now(),
email: email,
exp: expires,
})
return accessToken
}
@ -525,10 +578,15 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request
auth := req.Header.Get("Authorization")
token := strings.TrimPrefix(auth, "Bearer ")
_, ok := f.accessTokens.Load(token)
authToken, ok := f.accessTokens.Load(token)
if !ok {
return "", xerrors.New("invalid access token")
}
if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
return "", xerrors.New("access token expired")
}
return token, nil
}
@ -653,7 +711,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
values, err := f.authenticateOIDCClientRequest(t, r)
f.logger.Info(r.Context(), "http idp call token",
slog.Error(err),
slog.F("valid", err == nil),
slog.F("grant_type", values.Get("grant_type")),
slog.F("values", values.Encode()),
)
if err != nil {
@ -731,15 +790,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
return
}
exp := time.Now().Add(time.Minute * 5)
exp := time.Now().Add(f.defaultExpire)
claims["exp"] = exp.UnixMilli()
email := getEmail(claims)
refreshToken := f.newRefreshTokens(email)
token := map[string]interface{}{
"access_token": f.newToken(email),
"access_token": f.newToken(email, exp),
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": int64((time.Minute * 5).Seconds()),
"expires_in": int64((f.defaultExpire).Seconds()),
"id_token": f.encodeClaims(t, claims),
}
if f.hookMutateToken != nil {
@ -754,25 +813,31 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
token, err := f.authenticateBearerTokenRequest(t, r)
f.logger.Info(r.Context(), "http call idp user info",
slog.Error(err),
slog.F("url", r.URL.String()),
)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized)
return "", false
}
email, ok = f.accessTokens.Load(token)
authToken, ok := f.accessTokens.Load(token)
if !ok {
t.Errorf("access token user for user_info has no email to indicate which user")
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized)
return "", false
}
return email, true
if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
http.Error(rw, "auth token expired", http.StatusUnauthorized)
return "", false
}
return authToken.email, true
}
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
email, ok := validateMW(rw, r)
f.logger.Info(r.Context(), "http userinfo endpoint",
slog.F("valid", ok),
slog.F("email", email),
)
if !ok {
return
}
@ -790,6 +855,10 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
// should be strict, and this one needs to handle sub routes.
mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
email, ok := validateMW(rw, r)
f.logger.Info(r.Context(), "http external auth validate",
slog.F("valid", ok),
slog.F("email", email),
)
if !ok {
return
}
@ -941,7 +1010,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
}
f.externalProviderID = id
f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) {
newPath := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/external-auth-validate/%s", id))
newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate")
switch newPath {
// /user is ALWAYS supported under the `/` path too.
case "/user", "/", "":
@ -965,6 +1034,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
}
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
cfg := &externalauth.Config{
DisplayName: id,
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
ID: id,
// No defaults for these fields by omitting the type
@ -972,11 +1042,12 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
DisplayIcon: f.WellknownConfig().UserInfoURL,
// Omit the /user for the validate so we can easily append to it when modifying
// the cfg for advanced tests.
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: fmt.Sprintf("/external-auth-validate/%s", id)}).String(),
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
}
for _, opt := range opts {
opt(cfg)
}
f.updateIssuerURL(t, f.issuer)
return cfg
}

View File

@ -126,7 +126,7 @@ func TestExternalAuthByID(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
ExternalAuthConfigs: []*externalauth.Config{
fake.ExternalAuthConfig(t, providerID, routes, func(cfg *externalauth.Config) {
cfg.AppInstallationsURL = cfg.ValidateURL + "/installs"
cfg.AppInstallationsURL = strings.TrimSuffix(cfg.ValidateURL, "/") + "/installs"
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
}),
},

111
scripts/testidp/main.go Normal file
View File

@ -0,0 +1,111 @@
package main
import (
"encoding/json"
"flag"
"log"
"os"
"os/signal"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/codersdk"
)
// Flags
var (
expiry = flag.Duration("expiry", time.Minute*5, "Token expiry")
clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random")
clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random")
// By default, no regex means it will never match anything. So at least default to matching something.
extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex")
)
func main() {
testing.Init()
_ = flag.Set("test.timeout", "0")
flag.Parse()
// This is just a way to run tests outside go test
testing.Main(func(pat, str string) (bool, error) {
return true, nil
}, []testing.InternalTest{
{
Name: "Run Fake IDP",
F: RunIDP(),
},
}, nil, nil)
}
type withClientSecret struct {
// We never unmarshal this in prod, but we need this field for testing.
ClientSecret string `json:"client_secret"`
codersdk.ExternalAuthConfig
}
// RunIDP needs the testing.T because our oidctest package requires the
// testing.T.
func RunIDP() func(t *testing.T) {
return func(t *testing.T) {
idp := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
oidctest.WithStaticUserInfo(jwt.MapClaims{}),
oidctest.WithDefaultIDClaims(jwt.MapClaims{}),
oidctest.WithDefaultExpire(*expiry),
oidctest.WithStaticCredentials(*clientID, *clientSecret),
oidctest.WithIssuer("http://localhost:4500"),
oidctest.WithLogger(slog.Make(sloghuman.Sink(os.Stderr))),
)
id, sec := idp.AppCredentials()
prov := idp.WellknownConfig()
const appID = "fake"
coderCfg := idp.ExternalAuthConfig(t, appID, nil)
log.Println("IDP Issuer URL", idp.IssuerURL())
log.Println("Coderd Flags")
deviceCodeURL := ""
if coderCfg.DeviceAuth != nil {
deviceCodeURL = coderCfg.DeviceAuth.CodeURL
}
cfg := withClientSecret{
ClientSecret: sec,
ExternalAuthConfig: codersdk.ExternalAuthConfig{
Type: appID,
ClientID: id,
ClientSecret: sec,
ID: appID,
AuthURL: prov.AuthURL,
TokenURL: prov.TokenURL,
ValidateURL: prov.ExternalAuthURL,
AppInstallURL: coderCfg.AppInstallURL,
AppInstallationsURL: coderCfg.AppInstallationsURL,
NoRefresh: false,
Scopes: []string{"openid", "email", "profile"},
ExtraTokenKeys: coderCfg.ExtraTokenKeys,
DeviceFlow: coderCfg.DeviceAuth != nil,
DeviceCodeURL: deviceCodeURL,
Regex: *extRegex,
DisplayName: coderCfg.DisplayName,
DisplayIcon: coderCfg.DisplayIcon,
},
}
data, err := json.Marshal([]withClientSecret{cfg})
require.NoError(t, err)
log.Printf(`--external-auth-providers='%s'`, string(data))
log.Println("Press Ctrl+C to exit")
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
// Block until ctl+c
<-c
log.Println("Closing")
}
}