From 5087f7b5f691b9e8508a5f58ce9c960bb157751f Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 15 Jan 2024 10:01:41 -0600 Subject: [PATCH] chore: improve fake IDP script (#11602) * chore: testIDP using static defaults for easier reuse --- cmd/testidp/main.go | 58 --------------- coderd/coderdtest/oidctest/idp.go | 113 +++++++++++++++++++++++------ coderd/externalauth_test.go | 2 +- {cmd => scripts}/testidp/README.md | 0 scripts/testidp/main.go | 111 ++++++++++++++++++++++++++++ 5 files changed, 204 insertions(+), 80 deletions(-) delete mode 100644 cmd/testidp/main.go rename {cmd => scripts}/testidp/README.md (100%) create mode 100644 scripts/testidp/main.go diff --git a/cmd/testidp/main.go b/cmd/testidp/main.go deleted file mode 100644 index fd96d0b84a..0000000000 --- a/cmd/testidp/main.go +++ /dev/null @@ -1,58 +0,0 @@ -package main - -import ( - "flag" - "log" - "os" - "os/signal" - "testing" - - "github.com/golang-jwt/jwt/v4" - - "github.com/coder/coder/v2/coderd/coderdtest/oidctest" -) - -func main() { - testing.Init() - _ = flag.Set("test.timeout", "0") - - flag.Parse() - - // This is just a way to run tests outside go test - testing.Main(func(pat, str string) (bool, error) { - return true, nil - }, []testing.InternalTest{ - { - Name: "Run Fake IDP", - F: RunIDP(), - }, - }, nil, nil) -} - -// RunIDP needs the testing.T because our oidctest package requires the -// testing.T. -func RunIDP() func(t *testing.T) { - return func(t *testing.T) { - idp := oidctest.NewFakeIDP(t, - oidctest.WithServing(), - oidctest.WithStaticUserInfo(jwt.MapClaims{}), - oidctest.WithDefaultIDClaims(jwt.MapClaims{}), - ) - id, sec := idp.AppCredentials() - prov := idp.WellknownConfig() - - log.Println("IDP Issuer URL", idp.IssuerURL()) - log.Println("Coderd Flags") - log.Printf(`--external-auth-providers='[{"type":"fake","client_id":"%s","client_secret":"%s","auth_url":"%s","token_url":"%s","validate_url":"%s","scopes":["openid","email","profile"]}]'`, - id, sec, prov.AuthURL, prov.TokenURL, prov.UserInfoURL, - ) - - log.Println("Press Ctrl+C to exit") - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - - // Block until ctl+c - <-c - log.Println("Closing") - } -} diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index 6b6936e346..e830bb0511 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -39,6 +39,12 @@ import ( "github.com/coder/coder/v2/codersdk" ) +type token struct { + issued time.Time + email string + exp time.Time +} + // FakeIDP is a functional OIDC provider. // It only supports 1 OIDC client. type FakeIDP struct { @@ -65,7 +71,7 @@ type FakeIDP struct { // That is the various access tokens, refresh tokens, states, etc. codeToStateMap *syncmap.Map[string, string] // Token -> Email - accessTokens *syncmap.Map[string, string] + accessTokens *syncmap.Map[string, token] // Refresh Token -> Email refreshTokensUsed *syncmap.Map[string, bool] refreshTokens *syncmap.Map[string, string] @@ -89,7 +95,8 @@ type FakeIDP struct { hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error) serve bool // optional middlewares - middlewares chi.Middlewares + middlewares chi.Middlewares + defaultExpire time.Duration } func StatusError(code int, err error) error { @@ -134,6 +141,23 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) { } } +func WithDefaultExpire(d time.Duration) func(*FakeIDP) { + return func(f *FakeIDP) { + f.defaultExpire = d + } +} + +func WithStaticCredentials(id, secret string) func(*FakeIDP) { + return func(f *FakeIDP) { + if id != "" { + f.clientID = id + } + if secret != "" { + f.clientSecret = secret + } + } +} + // WithExtra returns extra fields that be accessed on the returned Oauth Token. // These extra fields can override the default fields (id_token, access_token, etc). func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) { @@ -155,6 +179,12 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { } } +func WithLogger(logger slog.Logger) func(*FakeIDP) { + return func(f *FakeIDP) { + f.logger = logger + } +} + // WithStaticUserInfo is optional, but will return the same user info for // every user on the /userinfo endpoint. func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { @@ -211,7 +241,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { clientSecret: uuid.NewString(), logger: slog.Make(), codeToStateMap: syncmap.New[string, string](), - accessTokens: syncmap.New[string, string](), + accessTokens: syncmap.New[string, token](), refreshTokens: syncmap.New[string, string](), refreshTokensUsed: syncmap.New[string, bool](), stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), @@ -219,6 +249,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { hookOnRefresh: func(_ string) error { return nil }, hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil }, hookValidRedirectURL: func(redirectURL string) error { return nil }, + defaultExpire: time.Minute * 5, } for _, opt := range opts { @@ -265,6 +296,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { Algorithms: []string{ "RS256", }, + ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(), } } @@ -272,8 +304,23 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { t.Helper() + srvURL := "localhost:0" + issURL, err := url.Parse(f.issuer) + if err == nil { + if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" { + srvURL = issURL.Host + } + } + + l, err := net.Listen("tcp", srvURL) + require.NoError(t, err, "failed to create listener") + ctx, cancel := context.WithCancel(context.Background()) - srv := httptest.NewUnstartedServer(f.handler) + srv := &httptest.Server{ + Listener: l, + Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5}, + } + srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx } @@ -495,6 +542,8 @@ type ProviderJSON struct { JWKSURL string `json:"jwks_uri"` UserInfoURL string `json:"userinfo_endpoint"` Algorithms []string `json:"id_token_signing_alg_values_supported"` + // This is custom + ExternalAuthURL string `json:"external_auth_url"` } // newCode enforces the code exchanged is actually a valid code @@ -507,9 +556,13 @@ func (f *FakeIDP) newCode(state string) string { // newToken enforces the access token exchanged is actually a valid access token // created by the IDP. -func (f *FakeIDP) newToken(email string) string { +func (f *FakeIDP) newToken(email string, expires time.Time) string { accessToken := uuid.NewString() - f.accessTokens.Store(accessToken, email) + f.accessTokens.Store(accessToken, token{ + issued: time.Now(), + email: email, + exp: expires, + }) return accessToken } @@ -525,10 +578,15 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request auth := req.Header.Get("Authorization") token := strings.TrimPrefix(auth, "Bearer ") - _, ok := f.accessTokens.Load(token) + authToken, ok := f.accessTokens.Load(token) if !ok { return "", xerrors.New("invalid access token") } + + if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) { + return "", xerrors.New("access token expired") + } + return token, nil } @@ -653,7 +711,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { values, err := f.authenticateOIDCClientRequest(t, r) f.logger.Info(r.Context(), "http idp call token", - slog.Error(err), + slog.F("valid", err == nil), + slog.F("grant_type", values.Get("grant_type")), slog.F("values", values.Encode()), ) if err != nil { @@ -731,15 +790,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { return } - exp := time.Now().Add(time.Minute * 5) + exp := time.Now().Add(f.defaultExpire) claims["exp"] = exp.UnixMilli() email := getEmail(claims) refreshToken := f.newRefreshTokens(email) token := map[string]interface{}{ - "access_token": f.newToken(email), + "access_token": f.newToken(email, exp), "refresh_token": refreshToken, "token_type": "Bearer", - "expires_in": int64((time.Minute * 5).Seconds()), + "expires_in": int64((f.defaultExpire).Seconds()), "id_token": f.encodeClaims(t, claims), } if f.hookMutateToken != nil { @@ -754,25 +813,31 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) { token, err := f.authenticateBearerTokenRequest(t, r) - f.logger.Info(r.Context(), "http call idp user info", - slog.Error(err), - slog.F("url", r.URL.String()), - ) if err != nil { - http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest) + http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized) return "", false } - email, ok = f.accessTokens.Load(token) + authToken, ok := f.accessTokens.Load(token) if !ok { t.Errorf("access token user for user_info has no email to indicate which user") - http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest) + http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized) return "", false } - return email, true + + if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) { + http.Error(rw, "auth token expired", http.StatusUnauthorized) + return "", false + } + + return authToken.email, true } mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { email, ok := validateMW(rw, r) + f.logger.Info(r.Context(), "http userinfo endpoint", + slog.F("valid", ok), + slog.F("email", email), + ) if !ok { return } @@ -790,6 +855,10 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { // should be strict, and this one needs to handle sub routes. mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { email, ok := validateMW(rw, r) + f.logger.Info(r.Context(), "http external auth validate", + slog.F("valid", ok), + slog.F("email", email), + ) if !ok { return } @@ -941,7 +1010,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu } f.externalProviderID = id f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) { - newPath := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/external-auth-validate/%s", id)) + newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate") switch newPath { // /user is ALWAYS supported under the `/` path too. case "/user", "/", "": @@ -965,6 +1034,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu } instrumentF := promoauth.NewFactory(prometheus.NewRegistry()) cfg := &externalauth.Config{ + DisplayName: id, InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)), ID: id, // No defaults for these fields by omitting the type @@ -972,11 +1042,12 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu DisplayIcon: f.WellknownConfig().UserInfoURL, // Omit the /user for the validate so we can easily append to it when modifying // the cfg for advanced tests. - ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: fmt.Sprintf("/external-auth-validate/%s", id)}).String(), + ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(), } for _, opt := range opts { opt(cfg) } + f.updateIssuerURL(t, f.issuer) return cfg } diff --git a/coderd/externalauth_test.go b/coderd/externalauth_test.go index e109405c4e..4892ad6598 100644 --- a/coderd/externalauth_test.go +++ b/coderd/externalauth_test.go @@ -126,7 +126,7 @@ func TestExternalAuthByID(t *testing.T) { client := coderdtest.New(t, &coderdtest.Options{ ExternalAuthConfigs: []*externalauth.Config{ fake.ExternalAuthConfig(t, providerID, routes, func(cfg *externalauth.Config) { - cfg.AppInstallationsURL = cfg.ValidateURL + "/installs" + cfg.AppInstallationsURL = strings.TrimSuffix(cfg.ValidateURL, "/") + "/installs" cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() }), }, diff --git a/cmd/testidp/README.md b/scripts/testidp/README.md similarity index 100% rename from cmd/testidp/README.md rename to scripts/testidp/README.md diff --git a/scripts/testidp/main.go b/scripts/testidp/main.go new file mode 100644 index 0000000000..49902eca17 --- /dev/null +++ b/scripts/testidp/main.go @@ -0,0 +1,111 @@ +package main + +import ( + "encoding/json" + "flag" + "log" + "os" + "os/signal" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coder/coder/v2/codersdk" +) + +// Flags +var ( + expiry = flag.Duration("expiry", time.Minute*5, "Token expiry") + clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random") + clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random") + // By default, no regex means it will never match anything. So at least default to matching something. + extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex") +) + +func main() { + testing.Init() + _ = flag.Set("test.timeout", "0") + + flag.Parse() + + // This is just a way to run tests outside go test + testing.Main(func(pat, str string) (bool, error) { + return true, nil + }, []testing.InternalTest{ + { + Name: "Run Fake IDP", + F: RunIDP(), + }, + }, nil, nil) +} + +type withClientSecret struct { + // We never unmarshal this in prod, but we need this field for testing. + ClientSecret string `json:"client_secret"` + codersdk.ExternalAuthConfig +} + +// RunIDP needs the testing.T because our oidctest package requires the +// testing.T. +func RunIDP() func(t *testing.T) { + return func(t *testing.T) { + idp := oidctest.NewFakeIDP(t, + oidctest.WithServing(), + oidctest.WithStaticUserInfo(jwt.MapClaims{}), + oidctest.WithDefaultIDClaims(jwt.MapClaims{}), + oidctest.WithDefaultExpire(*expiry), + oidctest.WithStaticCredentials(*clientID, *clientSecret), + oidctest.WithIssuer("http://localhost:4500"), + oidctest.WithLogger(slog.Make(sloghuman.Sink(os.Stderr))), + ) + id, sec := idp.AppCredentials() + prov := idp.WellknownConfig() + const appID = "fake" + coderCfg := idp.ExternalAuthConfig(t, appID, nil) + + log.Println("IDP Issuer URL", idp.IssuerURL()) + log.Println("Coderd Flags") + deviceCodeURL := "" + if coderCfg.DeviceAuth != nil { + deviceCodeURL = coderCfg.DeviceAuth.CodeURL + } + cfg := withClientSecret{ + ClientSecret: sec, + ExternalAuthConfig: codersdk.ExternalAuthConfig{ + Type: appID, + ClientID: id, + ClientSecret: sec, + ID: appID, + AuthURL: prov.AuthURL, + TokenURL: prov.TokenURL, + ValidateURL: prov.ExternalAuthURL, + AppInstallURL: coderCfg.AppInstallURL, + AppInstallationsURL: coderCfg.AppInstallationsURL, + NoRefresh: false, + Scopes: []string{"openid", "email", "profile"}, + ExtraTokenKeys: coderCfg.ExtraTokenKeys, + DeviceFlow: coderCfg.DeviceAuth != nil, + DeviceCodeURL: deviceCodeURL, + Regex: *extRegex, + DisplayName: coderCfg.DisplayName, + DisplayIcon: coderCfg.DisplayIcon, + }, + } + data, err := json.Marshal([]withClientSecret{cfg}) + require.NoError(t, err) + log.Printf(`--external-auth-providers='%s'`, string(data)) + + log.Println("Press Ctrl+C to exit") + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + // Block until ctl+c + <-c + log.Println("Closing") + } +}