mirror of https://github.com/coder/coder.git
chore: improve fake IDP script (#11602)
* chore: testIDP using static defaults for easier reuse
This commit is contained in:
parent
f915bdf26c
commit
5087f7b5f6
|
@ -1,58 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
)
|
||||
|
||||
func main() {
|
||||
testing.Init()
|
||||
_ = flag.Set("test.timeout", "0")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// This is just a way to run tests outside go test
|
||||
testing.Main(func(pat, str string) (bool, error) {
|
||||
return true, nil
|
||||
}, []testing.InternalTest{
|
||||
{
|
||||
Name: "Run Fake IDP",
|
||||
F: RunIDP(),
|
||||
},
|
||||
}, nil, nil)
|
||||
}
|
||||
|
||||
// RunIDP needs the testing.T because our oidctest package requires the
|
||||
// testing.T.
|
||||
func RunIDP() func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
idp := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithServing(),
|
||||
oidctest.WithStaticUserInfo(jwt.MapClaims{}),
|
||||
oidctest.WithDefaultIDClaims(jwt.MapClaims{}),
|
||||
)
|
||||
id, sec := idp.AppCredentials()
|
||||
prov := idp.WellknownConfig()
|
||||
|
||||
log.Println("IDP Issuer URL", idp.IssuerURL())
|
||||
log.Println("Coderd Flags")
|
||||
log.Printf(`--external-auth-providers='[{"type":"fake","client_id":"%s","client_secret":"%s","auth_url":"%s","token_url":"%s","validate_url":"%s","scopes":["openid","email","profile"]}]'`,
|
||||
id, sec, prov.AuthURL, prov.TokenURL, prov.UserInfoURL,
|
||||
)
|
||||
|
||||
log.Println("Press Ctrl+C to exit")
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
|
||||
// Block until ctl+c
|
||||
<-c
|
||||
log.Println("Closing")
|
||||
}
|
||||
}
|
|
@ -39,6 +39,12 @@ import (
|
|||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type token struct {
|
||||
issued time.Time
|
||||
email string
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
// FakeIDP is a functional OIDC provider.
|
||||
// It only supports 1 OIDC client.
|
||||
type FakeIDP struct {
|
||||
|
@ -65,7 +71,7 @@ type FakeIDP struct {
|
|||
// That is the various access tokens, refresh tokens, states, etc.
|
||||
codeToStateMap *syncmap.Map[string, string]
|
||||
// Token -> Email
|
||||
accessTokens *syncmap.Map[string, string]
|
||||
accessTokens *syncmap.Map[string, token]
|
||||
// Refresh Token -> Email
|
||||
refreshTokensUsed *syncmap.Map[string, bool]
|
||||
refreshTokens *syncmap.Map[string, string]
|
||||
|
@ -89,7 +95,8 @@ type FakeIDP struct {
|
|||
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
|
||||
serve bool
|
||||
// optional middlewares
|
||||
middlewares chi.Middlewares
|
||||
middlewares chi.Middlewares
|
||||
defaultExpire time.Duration
|
||||
}
|
||||
|
||||
func StatusError(code int, err error) error {
|
||||
|
@ -134,6 +141,23 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) {
|
|||
}
|
||||
}
|
||||
|
||||
func WithDefaultExpire(d time.Duration) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.defaultExpire = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithStaticCredentials(id, secret string) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
if id != "" {
|
||||
f.clientID = id
|
||||
}
|
||||
if secret != "" {
|
||||
f.clientSecret = secret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithExtra returns extra fields that be accessed on the returned Oauth Token.
|
||||
// These extra fields can override the default fields (id_token, access_token, etc).
|
||||
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
|
||||
|
@ -155,6 +179,12 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
|
|||
}
|
||||
}
|
||||
|
||||
func WithLogger(logger slog.Logger) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithStaticUserInfo is optional, but will return the same user info for
|
||||
// every user on the /userinfo endpoint.
|
||||
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
|
||||
|
@ -211,7 +241,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
|||
clientSecret: uuid.NewString(),
|
||||
logger: slog.Make(),
|
||||
codeToStateMap: syncmap.New[string, string](),
|
||||
accessTokens: syncmap.New[string, string](),
|
||||
accessTokens: syncmap.New[string, token](),
|
||||
refreshTokens: syncmap.New[string, string](),
|
||||
refreshTokensUsed: syncmap.New[string, bool](),
|
||||
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
|
@ -219,6 +249,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
|||
hookOnRefresh: func(_ string) error { return nil },
|
||||
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
|
||||
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||
defaultExpire: time.Minute * 5,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
|
@ -265,6 +296,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
|||
Algorithms: []string{
|
||||
"RS256",
|
||||
},
|
||||
ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -272,8 +304,23 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
|||
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
srvURL := "localhost:0"
|
||||
issURL, err := url.Parse(f.issuer)
|
||||
if err == nil {
|
||||
if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" {
|
||||
srvURL = issURL.Host
|
||||
}
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", srvURL)
|
||||
require.NoError(t, err, "failed to create listener")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
srv := httptest.NewUnstartedServer(f.handler)
|
||||
srv := &httptest.Server{
|
||||
Listener: l,
|
||||
Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5},
|
||||
}
|
||||
|
||||
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
@ -495,6 +542,8 @@ type ProviderJSON struct {
|
|||
JWKSURL string `json:"jwks_uri"`
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
||||
// This is custom
|
||||
ExternalAuthURL string `json:"external_auth_url"`
|
||||
}
|
||||
|
||||
// newCode enforces the code exchanged is actually a valid code
|
||||
|
@ -507,9 +556,13 @@ func (f *FakeIDP) newCode(state string) string {
|
|||
|
||||
// newToken enforces the access token exchanged is actually a valid access token
|
||||
// created by the IDP.
|
||||
func (f *FakeIDP) newToken(email string) string {
|
||||
func (f *FakeIDP) newToken(email string, expires time.Time) string {
|
||||
accessToken := uuid.NewString()
|
||||
f.accessTokens.Store(accessToken, email)
|
||||
f.accessTokens.Store(accessToken, token{
|
||||
issued: time.Now(),
|
||||
email: email,
|
||||
exp: expires,
|
||||
})
|
||||
return accessToken
|
||||
}
|
||||
|
||||
|
@ -525,10 +578,15 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request
|
|||
|
||||
auth := req.Header.Get("Authorization")
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
_, ok := f.accessTokens.Load(token)
|
||||
authToken, ok := f.accessTokens.Load(token)
|
||||
if !ok {
|
||||
return "", xerrors.New("invalid access token")
|
||||
}
|
||||
|
||||
if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
|
||||
return "", xerrors.New("access token expired")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
|
@ -653,7 +711,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
|||
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
values, err := f.authenticateOIDCClientRequest(t, r)
|
||||
f.logger.Info(r.Context(), "http idp call token",
|
||||
slog.Error(err),
|
||||
slog.F("valid", err == nil),
|
||||
slog.F("grant_type", values.Get("grant_type")),
|
||||
slog.F("values", values.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -731,15 +790,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
|||
return
|
||||
}
|
||||
|
||||
exp := time.Now().Add(time.Minute * 5)
|
||||
exp := time.Now().Add(f.defaultExpire)
|
||||
claims["exp"] = exp.UnixMilli()
|
||||
email := getEmail(claims)
|
||||
refreshToken := f.newRefreshTokens(email)
|
||||
token := map[string]interface{}{
|
||||
"access_token": f.newToken(email),
|
||||
"access_token": f.newToken(email, exp),
|
||||
"refresh_token": refreshToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int64((time.Minute * 5).Seconds()),
|
||||
"expires_in": int64((f.defaultExpire).Seconds()),
|
||||
"id_token": f.encodeClaims(t, claims),
|
||||
}
|
||||
if f.hookMutateToken != nil {
|
||||
|
@ -754,25 +813,31 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
|||
|
||||
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
|
||||
token, err := f.authenticateBearerTokenRequest(t, r)
|
||||
f.logger.Info(r.Context(), "http call idp user info",
|
||||
slog.Error(err),
|
||||
slog.F("url", r.URL.String()),
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
|
||||
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized)
|
||||
return "", false
|
||||
}
|
||||
|
||||
email, ok = f.accessTokens.Load(token)
|
||||
authToken, ok := f.accessTokens.Load(token)
|
||||
if !ok {
|
||||
t.Errorf("access token user for user_info has no email to indicate which user")
|
||||
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
|
||||
http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized)
|
||||
return "", false
|
||||
}
|
||||
return email, true
|
||||
|
||||
if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
|
||||
http.Error(rw, "auth token expired", http.StatusUnauthorized)
|
||||
return "", false
|
||||
}
|
||||
|
||||
return authToken.email, true
|
||||
}
|
||||
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
email, ok := validateMW(rw, r)
|
||||
f.logger.Info(r.Context(), "http userinfo endpoint",
|
||||
slog.F("valid", ok),
|
||||
slog.F("email", email),
|
||||
)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
@ -790,6 +855,10 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
|||
// should be strict, and this one needs to handle sub routes.
|
||||
mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
email, ok := validateMW(rw, r)
|
||||
f.logger.Info(r.Context(), "http external auth validate",
|
||||
slog.F("valid", ok),
|
||||
slog.F("email", email),
|
||||
)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
@ -941,7 +1010,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
|||
}
|
||||
f.externalProviderID = id
|
||||
f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) {
|
||||
newPath := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/external-auth-validate/%s", id))
|
||||
newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate")
|
||||
switch newPath {
|
||||
// /user is ALWAYS supported under the `/` path too.
|
||||
case "/user", "/", "":
|
||||
|
@ -965,6 +1034,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
|||
}
|
||||
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
|
||||
cfg := &externalauth.Config{
|
||||
DisplayName: id,
|
||||
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
|
||||
ID: id,
|
||||
// No defaults for these fields by omitting the type
|
||||
|
@ -972,11 +1042,12 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
|||
DisplayIcon: f.WellknownConfig().UserInfoURL,
|
||||
// Omit the /user for the validate so we can easily append to it when modifying
|
||||
// the cfg for advanced tests.
|
||||
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: fmt.Sprintf("/external-auth-validate/%s", id)}).String(),
|
||||
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
f.updateIssuerURL(t, f.issuer)
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ func TestExternalAuthByID(t *testing.T) {
|
|||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
ExternalAuthConfigs: []*externalauth.Config{
|
||||
fake.ExternalAuthConfig(t, providerID, routes, func(cfg *externalauth.Config) {
|
||||
cfg.AppInstallationsURL = cfg.ValidateURL + "/installs"
|
||||
cfg.AppInstallationsURL = strings.TrimSuffix(cfg.ValidateURL, "/") + "/installs"
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
}),
|
||||
},
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Flags
|
||||
var (
|
||||
expiry = flag.Duration("expiry", time.Minute*5, "Token expiry")
|
||||
clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random")
|
||||
clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random")
|
||||
// By default, no regex means it will never match anything. So at least default to matching something.
|
||||
extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex")
|
||||
)
|
||||
|
||||
func main() {
|
||||
testing.Init()
|
||||
_ = flag.Set("test.timeout", "0")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// This is just a way to run tests outside go test
|
||||
testing.Main(func(pat, str string) (bool, error) {
|
||||
return true, nil
|
||||
}, []testing.InternalTest{
|
||||
{
|
||||
Name: "Run Fake IDP",
|
||||
F: RunIDP(),
|
||||
},
|
||||
}, nil, nil)
|
||||
}
|
||||
|
||||
type withClientSecret struct {
|
||||
// We never unmarshal this in prod, but we need this field for testing.
|
||||
ClientSecret string `json:"client_secret"`
|
||||
codersdk.ExternalAuthConfig
|
||||
}
|
||||
|
||||
// RunIDP needs the testing.T because our oidctest package requires the
|
||||
// testing.T.
|
||||
func RunIDP() func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
idp := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithServing(),
|
||||
oidctest.WithStaticUserInfo(jwt.MapClaims{}),
|
||||
oidctest.WithDefaultIDClaims(jwt.MapClaims{}),
|
||||
oidctest.WithDefaultExpire(*expiry),
|
||||
oidctest.WithStaticCredentials(*clientID, *clientSecret),
|
||||
oidctest.WithIssuer("http://localhost:4500"),
|
||||
oidctest.WithLogger(slog.Make(sloghuman.Sink(os.Stderr))),
|
||||
)
|
||||
id, sec := idp.AppCredentials()
|
||||
prov := idp.WellknownConfig()
|
||||
const appID = "fake"
|
||||
coderCfg := idp.ExternalAuthConfig(t, appID, nil)
|
||||
|
||||
log.Println("IDP Issuer URL", idp.IssuerURL())
|
||||
log.Println("Coderd Flags")
|
||||
deviceCodeURL := ""
|
||||
if coderCfg.DeviceAuth != nil {
|
||||
deviceCodeURL = coderCfg.DeviceAuth.CodeURL
|
||||
}
|
||||
cfg := withClientSecret{
|
||||
ClientSecret: sec,
|
||||
ExternalAuthConfig: codersdk.ExternalAuthConfig{
|
||||
Type: appID,
|
||||
ClientID: id,
|
||||
ClientSecret: sec,
|
||||
ID: appID,
|
||||
AuthURL: prov.AuthURL,
|
||||
TokenURL: prov.TokenURL,
|
||||
ValidateURL: prov.ExternalAuthURL,
|
||||
AppInstallURL: coderCfg.AppInstallURL,
|
||||
AppInstallationsURL: coderCfg.AppInstallationsURL,
|
||||
NoRefresh: false,
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
ExtraTokenKeys: coderCfg.ExtraTokenKeys,
|
||||
DeviceFlow: coderCfg.DeviceAuth != nil,
|
||||
DeviceCodeURL: deviceCodeURL,
|
||||
Regex: *extRegex,
|
||||
DisplayName: coderCfg.DisplayName,
|
||||
DisplayIcon: coderCfg.DisplayIcon,
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal([]withClientSecret{cfg})
|
||||
require.NoError(t, err)
|
||||
log.Printf(`--external-auth-providers='%s'`, string(data))
|
||||
|
||||
log.Println("Press Ctrl+C to exit")
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
|
||||
// Block until ctl+c
|
||||
<-c
|
||||
log.Println("Closing")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue