mirror of https://github.com/coder/coder.git
chore: improve fake IDP script (#11602)
* chore: testIDP using static defaults for easier reuse
This commit is contained in:
parent
f915bdf26c
commit
5087f7b5f6
|
@ -1,58 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"flag"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
|
||||||
|
|
||||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
testing.Init()
|
|
||||||
_ = flag.Set("test.timeout", "0")
|
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
// This is just a way to run tests outside go test
|
|
||||||
testing.Main(func(pat, str string) (bool, error) {
|
|
||||||
return true, nil
|
|
||||||
}, []testing.InternalTest{
|
|
||||||
{
|
|
||||||
Name: "Run Fake IDP",
|
|
||||||
F: RunIDP(),
|
|
||||||
},
|
|
||||||
}, nil, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunIDP needs the testing.T because our oidctest package requires the
|
|
||||||
// testing.T.
|
|
||||||
func RunIDP() func(t *testing.T) {
|
|
||||||
return func(t *testing.T) {
|
|
||||||
idp := oidctest.NewFakeIDP(t,
|
|
||||||
oidctest.WithServing(),
|
|
||||||
oidctest.WithStaticUserInfo(jwt.MapClaims{}),
|
|
||||||
oidctest.WithDefaultIDClaims(jwt.MapClaims{}),
|
|
||||||
)
|
|
||||||
id, sec := idp.AppCredentials()
|
|
||||||
prov := idp.WellknownConfig()
|
|
||||||
|
|
||||||
log.Println("IDP Issuer URL", idp.IssuerURL())
|
|
||||||
log.Println("Coderd Flags")
|
|
||||||
log.Printf(`--external-auth-providers='[{"type":"fake","client_id":"%s","client_secret":"%s","auth_url":"%s","token_url":"%s","validate_url":"%s","scopes":["openid","email","profile"]}]'`,
|
|
||||||
id, sec, prov.AuthURL, prov.TokenURL, prov.UserInfoURL,
|
|
||||||
)
|
|
||||||
|
|
||||||
log.Println("Press Ctrl+C to exit")
|
|
||||||
c := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(c, os.Interrupt)
|
|
||||||
|
|
||||||
// Block until ctl+c
|
|
||||||
<-c
|
|
||||||
log.Println("Closing")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -39,6 +39,12 @@ import (
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type token struct {
|
||||||
|
issued time.Time
|
||||||
|
email string
|
||||||
|
exp time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// FakeIDP is a functional OIDC provider.
|
// FakeIDP is a functional OIDC provider.
|
||||||
// It only supports 1 OIDC client.
|
// It only supports 1 OIDC client.
|
||||||
type FakeIDP struct {
|
type FakeIDP struct {
|
||||||
|
@ -65,7 +71,7 @@ type FakeIDP struct {
|
||||||
// That is the various access tokens, refresh tokens, states, etc.
|
// That is the various access tokens, refresh tokens, states, etc.
|
||||||
codeToStateMap *syncmap.Map[string, string]
|
codeToStateMap *syncmap.Map[string, string]
|
||||||
// Token -> Email
|
// Token -> Email
|
||||||
accessTokens *syncmap.Map[string, string]
|
accessTokens *syncmap.Map[string, token]
|
||||||
// Refresh Token -> Email
|
// Refresh Token -> Email
|
||||||
refreshTokensUsed *syncmap.Map[string, bool]
|
refreshTokensUsed *syncmap.Map[string, bool]
|
||||||
refreshTokens *syncmap.Map[string, string]
|
refreshTokens *syncmap.Map[string, string]
|
||||||
|
@ -89,7 +95,8 @@ type FakeIDP struct {
|
||||||
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
|
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
|
||||||
serve bool
|
serve bool
|
||||||
// optional middlewares
|
// optional middlewares
|
||||||
middlewares chi.Middlewares
|
middlewares chi.Middlewares
|
||||||
|
defaultExpire time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func StatusError(code int, err error) error {
|
func StatusError(code int, err error) error {
|
||||||
|
@ -134,6 +141,23 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithDefaultExpire(d time.Duration) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.defaultExpire = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStaticCredentials(id, secret string) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
if id != "" {
|
||||||
|
f.clientID = id
|
||||||
|
}
|
||||||
|
if secret != "" {
|
||||||
|
f.clientSecret = secret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithExtra returns extra fields that be accessed on the returned Oauth Token.
|
// WithExtra returns extra fields that be accessed on the returned Oauth Token.
|
||||||
// These extra fields can override the default fields (id_token, access_token, etc).
|
// These extra fields can override the default fields (id_token, access_token, etc).
|
||||||
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
|
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
|
||||||
|
@ -155,6 +179,12 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithLogger(logger slog.Logger) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithStaticUserInfo is optional, but will return the same user info for
|
// WithStaticUserInfo is optional, but will return the same user info for
|
||||||
// every user on the /userinfo endpoint.
|
// every user on the /userinfo endpoint.
|
||||||
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
|
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
|
||||||
|
@ -211,7 +241,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||||
clientSecret: uuid.NewString(),
|
clientSecret: uuid.NewString(),
|
||||||
logger: slog.Make(),
|
logger: slog.Make(),
|
||||||
codeToStateMap: syncmap.New[string, string](),
|
codeToStateMap: syncmap.New[string, string](),
|
||||||
accessTokens: syncmap.New[string, string](),
|
accessTokens: syncmap.New[string, token](),
|
||||||
refreshTokens: syncmap.New[string, string](),
|
refreshTokens: syncmap.New[string, string](),
|
||||||
refreshTokensUsed: syncmap.New[string, bool](),
|
refreshTokensUsed: syncmap.New[string, bool](),
|
||||||
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||||
|
@ -219,6 +249,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||||
hookOnRefresh: func(_ string) error { return nil },
|
hookOnRefresh: func(_ string) error { return nil },
|
||||||
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
|
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
|
||||||
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||||
|
defaultExpire: time.Minute * 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
|
@ -265,6 +296,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||||
Algorithms: []string{
|
Algorithms: []string{
|
||||||
"RS256",
|
"RS256",
|
||||||
},
|
},
|
||||||
|
ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -272,8 +304,23 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||||
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
srvURL := "localhost:0"
|
||||||
|
issURL, err := url.Parse(f.issuer)
|
||||||
|
if err == nil {
|
||||||
|
if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" {
|
||||||
|
srvURL = issURL.Host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := net.Listen("tcp", srvURL)
|
||||||
|
require.NoError(t, err, "failed to create listener")
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
srv := httptest.NewUnstartedServer(f.handler)
|
srv := &httptest.Server{
|
||||||
|
Listener: l,
|
||||||
|
Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5},
|
||||||
|
}
|
||||||
|
|
||||||
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
@ -495,6 +542,8 @@ type ProviderJSON struct {
|
||||||
JWKSURL string `json:"jwks_uri"`
|
JWKSURL string `json:"jwks_uri"`
|
||||||
UserInfoURL string `json:"userinfo_endpoint"`
|
UserInfoURL string `json:"userinfo_endpoint"`
|
||||||
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
||||||
|
// This is custom
|
||||||
|
ExternalAuthURL string `json:"external_auth_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// newCode enforces the code exchanged is actually a valid code
|
// newCode enforces the code exchanged is actually a valid code
|
||||||
|
@ -507,9 +556,13 @@ func (f *FakeIDP) newCode(state string) string {
|
||||||
|
|
||||||
// newToken enforces the access token exchanged is actually a valid access token
|
// newToken enforces the access token exchanged is actually a valid access token
|
||||||
// created by the IDP.
|
// created by the IDP.
|
||||||
func (f *FakeIDP) newToken(email string) string {
|
func (f *FakeIDP) newToken(email string, expires time.Time) string {
|
||||||
accessToken := uuid.NewString()
|
accessToken := uuid.NewString()
|
||||||
f.accessTokens.Store(accessToken, email)
|
f.accessTokens.Store(accessToken, token{
|
||||||
|
issued: time.Now(),
|
||||||
|
email: email,
|
||||||
|
exp: expires,
|
||||||
|
})
|
||||||
return accessToken
|
return accessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -525,10 +578,15 @@ func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request
|
||||||
|
|
||||||
auth := req.Header.Get("Authorization")
|
auth := req.Header.Get("Authorization")
|
||||||
token := strings.TrimPrefix(auth, "Bearer ")
|
token := strings.TrimPrefix(auth, "Bearer ")
|
||||||
_, ok := f.accessTokens.Load(token)
|
authToken, ok := f.accessTokens.Load(token)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", xerrors.New("invalid access token")
|
return "", xerrors.New("invalid access token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
|
||||||
|
return "", xerrors.New("access token expired")
|
||||||
|
}
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -653,7 +711,8 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
values, err := f.authenticateOIDCClientRequest(t, r)
|
values, err := f.authenticateOIDCClientRequest(t, r)
|
||||||
f.logger.Info(r.Context(), "http idp call token",
|
f.logger.Info(r.Context(), "http idp call token",
|
||||||
slog.Error(err),
|
slog.F("valid", err == nil),
|
||||||
|
slog.F("grant_type", values.Get("grant_type")),
|
||||||
slog.F("values", values.Encode()),
|
slog.F("values", values.Encode()),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -731,15 +790,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
exp := time.Now().Add(time.Minute * 5)
|
exp := time.Now().Add(f.defaultExpire)
|
||||||
claims["exp"] = exp.UnixMilli()
|
claims["exp"] = exp.UnixMilli()
|
||||||
email := getEmail(claims)
|
email := getEmail(claims)
|
||||||
refreshToken := f.newRefreshTokens(email)
|
refreshToken := f.newRefreshTokens(email)
|
||||||
token := map[string]interface{}{
|
token := map[string]interface{}{
|
||||||
"access_token": f.newToken(email),
|
"access_token": f.newToken(email, exp),
|
||||||
"refresh_token": refreshToken,
|
"refresh_token": refreshToken,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
"expires_in": int64((time.Minute * 5).Seconds()),
|
"expires_in": int64((f.defaultExpire).Seconds()),
|
||||||
"id_token": f.encodeClaims(t, claims),
|
"id_token": f.encodeClaims(t, claims),
|
||||||
}
|
}
|
||||||
if f.hookMutateToken != nil {
|
if f.hookMutateToken != nil {
|
||||||
|
@ -754,25 +813,31 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
|
|
||||||
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
|
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
|
||||||
token, err := f.authenticateBearerTokenRequest(t, r)
|
token, err := f.authenticateBearerTokenRequest(t, r)
|
||||||
f.logger.Info(r.Context(), "http call idp user info",
|
|
||||||
slog.Error(err),
|
|
||||||
slog.F("url", r.URL.String()),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
|
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusUnauthorized)
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
email, ok = f.accessTokens.Load(token)
|
authToken, ok := f.accessTokens.Load(token)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("access token user for user_info has no email to indicate which user")
|
t.Errorf("access token user for user_info has no email to indicate which user")
|
||||||
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
|
http.Error(rw, "invalid access token, missing user info", http.StatusUnauthorized)
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
return email, true
|
|
||||||
|
if !authToken.exp.IsZero() && authToken.exp.Before(time.Now()) {
|
||||||
|
http.Error(rw, "auth token expired", http.StatusUnauthorized)
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return authToken.email, true
|
||||||
}
|
}
|
||||||
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
email, ok := validateMW(rw, r)
|
email, ok := validateMW(rw, r)
|
||||||
|
f.logger.Info(r.Context(), "http userinfo endpoint",
|
||||||
|
slog.F("valid", ok),
|
||||||
|
slog.F("email", email),
|
||||||
|
)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -790,6 +855,10 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
// should be strict, and this one needs to handle sub routes.
|
// should be strict, and this one needs to handle sub routes.
|
||||||
mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
mux.Mount("/external-auth-validate/", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
email, ok := validateMW(rw, r)
|
email, ok := validateMW(rw, r)
|
||||||
|
f.logger.Info(r.Context(), "http external auth validate",
|
||||||
|
slog.F("valid", ok),
|
||||||
|
slog.F("email", email),
|
||||||
|
)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -941,7 +1010,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||||
}
|
}
|
||||||
f.externalProviderID = id
|
f.externalProviderID = id
|
||||||
f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) {
|
f.externalAuthValidate = func(email string, rw http.ResponseWriter, r *http.Request) {
|
||||||
newPath := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/external-auth-validate/%s", id))
|
newPath := strings.TrimPrefix(r.URL.Path, "/external-auth-validate")
|
||||||
switch newPath {
|
switch newPath {
|
||||||
// /user is ALWAYS supported under the `/` path too.
|
// /user is ALWAYS supported under the `/` path too.
|
||||||
case "/user", "/", "":
|
case "/user", "/", "":
|
||||||
|
@ -965,6 +1034,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||||
}
|
}
|
||||||
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
|
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
|
||||||
cfg := &externalauth.Config{
|
cfg := &externalauth.Config{
|
||||||
|
DisplayName: id,
|
||||||
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
|
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
|
||||||
ID: id,
|
ID: id,
|
||||||
// No defaults for these fields by omitting the type
|
// No defaults for these fields by omitting the type
|
||||||
|
@ -972,11 +1042,12 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||||
DisplayIcon: f.WellknownConfig().UserInfoURL,
|
DisplayIcon: f.WellknownConfig().UserInfoURL,
|
||||||
// Omit the /user for the validate so we can easily append to it when modifying
|
// Omit the /user for the validate so we can easily append to it when modifying
|
||||||
// the cfg for advanced tests.
|
// the cfg for advanced tests.
|
||||||
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: fmt.Sprintf("/external-auth-validate/%s", id)}).String(),
|
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(cfg)
|
opt(cfg)
|
||||||
}
|
}
|
||||||
|
f.updateIssuerURL(t, f.issuer)
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -126,7 +126,7 @@ func TestExternalAuthByID(t *testing.T) {
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
ExternalAuthConfigs: []*externalauth.Config{
|
ExternalAuthConfigs: []*externalauth.Config{
|
||||||
fake.ExternalAuthConfig(t, providerID, routes, func(cfg *externalauth.Config) {
|
fake.ExternalAuthConfig(t, providerID, routes, func(cfg *externalauth.Config) {
|
||||||
cfg.AppInstallationsURL = cfg.ValidateURL + "/installs"
|
cfg.AppInstallationsURL = strings.TrimSuffix(cfg.ValidateURL, "/") + "/installs"
|
||||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"cdr.dev/slog/sloggers/sloghuman"
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Flags
|
||||||
|
var (
|
||||||
|
expiry = flag.Duration("expiry", time.Minute*5, "Token expiry")
|
||||||
|
clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random")
|
||||||
|
clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random")
|
||||||
|
// By default, no regex means it will never match anything. So at least default to matching something.
|
||||||
|
extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
testing.Init()
|
||||||
|
_ = flag.Set("test.timeout", "0")
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// This is just a way to run tests outside go test
|
||||||
|
testing.Main(func(pat, str string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}, []testing.InternalTest{
|
||||||
|
{
|
||||||
|
Name: "Run Fake IDP",
|
||||||
|
F: RunIDP(),
|
||||||
|
},
|
||||||
|
}, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
type withClientSecret struct {
|
||||||
|
// We never unmarshal this in prod, but we need this field for testing.
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
codersdk.ExternalAuthConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunIDP needs the testing.T because our oidctest package requires the
|
||||||
|
// testing.T.
|
||||||
|
func RunIDP() func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
idp := oidctest.NewFakeIDP(t,
|
||||||
|
oidctest.WithServing(),
|
||||||
|
oidctest.WithStaticUserInfo(jwt.MapClaims{}),
|
||||||
|
oidctest.WithDefaultIDClaims(jwt.MapClaims{}),
|
||||||
|
oidctest.WithDefaultExpire(*expiry),
|
||||||
|
oidctest.WithStaticCredentials(*clientID, *clientSecret),
|
||||||
|
oidctest.WithIssuer("http://localhost:4500"),
|
||||||
|
oidctest.WithLogger(slog.Make(sloghuman.Sink(os.Stderr))),
|
||||||
|
)
|
||||||
|
id, sec := idp.AppCredentials()
|
||||||
|
prov := idp.WellknownConfig()
|
||||||
|
const appID = "fake"
|
||||||
|
coderCfg := idp.ExternalAuthConfig(t, appID, nil)
|
||||||
|
|
||||||
|
log.Println("IDP Issuer URL", idp.IssuerURL())
|
||||||
|
log.Println("Coderd Flags")
|
||||||
|
deviceCodeURL := ""
|
||||||
|
if coderCfg.DeviceAuth != nil {
|
||||||
|
deviceCodeURL = coderCfg.DeviceAuth.CodeURL
|
||||||
|
}
|
||||||
|
cfg := withClientSecret{
|
||||||
|
ClientSecret: sec,
|
||||||
|
ExternalAuthConfig: codersdk.ExternalAuthConfig{
|
||||||
|
Type: appID,
|
||||||
|
ClientID: id,
|
||||||
|
ClientSecret: sec,
|
||||||
|
ID: appID,
|
||||||
|
AuthURL: prov.AuthURL,
|
||||||
|
TokenURL: prov.TokenURL,
|
||||||
|
ValidateURL: prov.ExternalAuthURL,
|
||||||
|
AppInstallURL: coderCfg.AppInstallURL,
|
||||||
|
AppInstallationsURL: coderCfg.AppInstallationsURL,
|
||||||
|
NoRefresh: false,
|
||||||
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
|
ExtraTokenKeys: coderCfg.ExtraTokenKeys,
|
||||||
|
DeviceFlow: coderCfg.DeviceAuth != nil,
|
||||||
|
DeviceCodeURL: deviceCodeURL,
|
||||||
|
Regex: *extRegex,
|
||||||
|
DisplayName: coderCfg.DisplayName,
|
||||||
|
DisplayIcon: coderCfg.DisplayIcon,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal([]withClientSecret{cfg})
|
||||||
|
require.NoError(t, err)
|
||||||
|
log.Printf(`--external-auth-providers='%s'`, string(data))
|
||||||
|
|
||||||
|
log.Println("Press Ctrl+C to exit")
|
||||||
|
c := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(c, os.Interrupt)
|
||||||
|
|
||||||
|
// Block until ctl+c
|
||||||
|
<-c
|
||||||
|
log.Println("Closing")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue