diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 18062a549a..08979a67a8 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -31,15 +31,13 @@ import ( "time" "cloud.google.com/go/compute/metadata" - "github.com/coreos/go-oidc/v3/oidc" "github.com/fullsailor/pkcs7" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/moby/moby/pkg/namesgenerator" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/oauth2" "golang.org/x/xerrors" "google.golang.org/api/idtoken" "google.golang.org/api/option" @@ -1020,152 +1018,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif } } -type OIDCConfig struct { - key *rsa.PrivateKey - issuer string - // These are optional - refreshToken string - oidcTokenExpires func() time.Time - tokenSource func() (*oauth2.Token, error) -} - -func WithRefreshToken(token string) func(cfg *OIDCConfig) { - return func(cfg *OIDCConfig) { - cfg.refreshToken = token - } -} - -func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) { - return func(cfg *OIDCConfig) { - cfg.oidcTokenExpires = expFunc - } -} - -func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) { - return func(cfg *OIDCConfig) { - cfg.tokenSource = src - } -} - -func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig { - t.Helper() - - block, _ := pem.Decode([]byte(testRSAPrivateKey)) - pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - require.NoError(t, err) - - if issuer == "" { - issuer = "https://coder.com" - } - - cfg := &OIDCConfig{ - key: pkey, - issuer: issuer, - } - for _, opt := range opts { - opt(cfg) - } - return cfg -} - -func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string { - return "/?state=" + url.QueryEscape(state) -} - -type tokenSource struct { - src func() (*oauth2.Token, error) -} - -func (s tokenSource) Token() (*oauth2.Token, error) { - return s.src() -} - -func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource { - if cfg.tokenSource == nil { - return nil - } - return tokenSource{ - src: cfg.tokenSource, - } -} - -func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - token, err := base64.StdEncoding.DecodeString(code) - if err != nil { - return nil, xerrors.Errorf("decode code: %w", err) - } - - var exp time.Time - if cfg.oidcTokenExpires != nil { - exp = cfg.oidcTokenExpires() - } - - return (&oauth2.Token{ - AccessToken: "token", - RefreshToken: cfg.refreshToken, - Expiry: exp, - }).WithExtra(map[string]interface{}{ - "id_token": string(token), - }), nil -} - -func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string { - t.Helper() - - if _, ok := claims["exp"]; !ok { - claims["exp"] = time.Now().Add(time.Hour).UnixMilli() - } - - if _, ok := claims["iss"]; !ok { - claims["iss"] = cfg.issuer - } - - if _, ok := claims["sub"]; !ok { - claims["sub"] = "testme" - } - - signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key) - require.NoError(t, err) - - return base64.StdEncoding.EncodeToString([]byte(signed)) -} - -func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { - // By default, the provider can be empty. - // This means it won't support any endpoints! - provider := &oidc.Provider{} - if userInfoClaims != nil { - resp, err := json.Marshal(userInfoClaims) - require.NoError(t, err) - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write(resp) - })) - t.Cleanup(srv.Close) - cfg := &oidc.ProviderConfig{ - UserInfoURL: srv.URL, - } - provider = cfg.NewProvider(context.Background()) - } - newCFG := &coderd.OIDCConfig{ - OAuth2Config: cfg, - Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{ - PublicKeys: []crypto.PublicKey{cfg.key.Public()}, - }, &oidc.Config{ - SkipClientIDCheck: true, - }), - Provider: provider, - UsernameField: "preferred_username", - EmailField: "email", - AuthURLParams: map[string]string{"access_type": "offline"}, - GroupField: "groups", - } - for _, opt := range opts { - opt(newCFG) - } - return newCFG -} - // NewAzureInstanceIdentity returns a metadata client and ID token validator for faking // instance authentication for Azure. func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) { @@ -1254,22 +1106,6 @@ func SDKError(t *testing.T, err error) *codersdk.Error { return cerr } -const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY----- -MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS -v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92 -5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB -AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0 -wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe -rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB -w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9 -pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8 -YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR -Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a -d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf -sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u -QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8 ------END RSA PRIVATE KEY-----` - func DeploymentValues(t testing.TB) *codersdk.DeploymentValues { var cfg codersdk.DeploymentValues opts := cfg.Options() diff --git a/coderd/coderdtest/oidctest/helper.go b/coderd/coderdtest/oidctest/helper.go new file mode 100644 index 0000000000..11d9114be2 --- /dev/null +++ b/coderd/coderdtest/oidctest/helper.go @@ -0,0 +1,103 @@ +package oidctest + +import ( + "net/http" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// LoginHelper helps with logging in a user and refreshing their oauth tokens. +// It is mainly because refreshing oauth tokens is a bit tricky and requires +// some database manipulation. +type LoginHelper struct { + fake *FakeIDP + client *codersdk.Client +} + +func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper { + if client == nil { + panic("client must not be nil") + } + if fake == nil { + panic("fake must not be nil") + } + return &LoginHelper{ + fake: fake, + client: client, + } +} + +// Login just helps by making an unauthenticated client and logging in with +// the given claims. All Logins should be unauthenticated, so this is a +// convenience method. +func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) { + t.Helper() + unauthenticatedClient := codersdk.New(h.client.URL) + + return h.fake.Login(t, unauthenticatedClient, idTokenClaims) +} + +// ExpireOauthToken expires the oauth token for the given user. +func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink { + t.Helper() + + //nolint:gocritic // Testing + ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium)) + + id, _, err := httpmw.SplitAPIToken(user.SessionToken()) + require.NoError(t, err) + + // We need to get the OIDC link and update it in the database to force + // it to be expired. + key, err := db.GetAPIKeyByID(ctx, id) + require.NoError(t, err, "get api key") + + link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ + UserID: key.UserID, + LoginType: database.LoginTypeOIDC, + }) + require.NoError(t, err, "get user link") + + // Expire the oauth link for the given user. + updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{ + OAuthAccessToken: link.OAuthAccessToken, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthExpiry: time.Now().Add(time.Hour * -1), + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.NoError(t, err, "expire user link") + + return updated +} + +// ForceRefresh forces the client to refresh its oauth token. It does this by +// expiring the oauth token, then doing an authenticated call. This will force +// the API Key middleware to refresh the oauth token. +// +// A unit test assertion makes sure the refresh token is used. +func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) { + t.Helper() + + link := h.ExpireOauthToken(t, db, user) + // Updates the claims that the IDP will return. By default, it always + // uses the original claims for the original oauth token. + h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken) + + t.Cleanup(func() { + require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?") + }) + + // Do any authenticated call to force the refresh + _, err := user.User(testutil.Context(t, testutil.WaitShort), "me") + require.NoError(t, err, "user must be able to be fetched") +} diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go new file mode 100644 index 0000000000..912d9acd7c --- /dev/null +++ b/coderd/coderdtest/oidctest/idp.go @@ -0,0 +1,793 @@ +package oidctest + +import ( + "context" + "crypto" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/coder/coder/v2/coderd/util/syncmap" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-chi/chi/v5" + "github.com/go-jose/go-jose/v3" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/codersdk" +) + +// FakeIDP is a functional OIDC provider. +// It only supports 1 OIDC client. +type FakeIDP struct { + issuer string + key *rsa.PrivateKey + provider providerJSON + handler http.Handler + cfg *oauth2.Config + + // clientID to be used by coderd + clientID string + clientSecret string + logger slog.Logger + + // These maps are used to control the state of the IDP. + // That is the various access tokens, refresh tokens, states, etc. + codeToStateMap *syncmap.Map[string, string] + // Token -> Email + accessTokens *syncmap.Map[string, string] + // Refresh Token -> Email + refreshTokensUsed *syncmap.Map[string, bool] + refreshTokens *syncmap.Map[string, string] + stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims] + refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims] + + // hooks + // hookValidRedirectURL can be used to reject a redirect url from the + // IDP -> Application. Almost all IDPs have the concept of + // "Authorized Redirect URLs". This can be used to emulate that. + hookValidRedirectURL func(redirectURL string) error + hookUserInfo func(email string) jwt.MapClaims + fakeCoderd func(req *http.Request) (*http.Response, error) + hookOnRefresh func(email string) error + // Custom authentication for the client. This is useful if you want + // to test something like PKI auth vs a client_secret. + hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error) + serve bool +} + +type FakeIDPOpt func(idp *FakeIDP) + +func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookValidRedirectURL = hook + } +} + +// WithRefreshHook is called when a refresh token is used. The email is +// the email of the user that is being refreshed assuming the claims are correct. +func WithRefreshHook(hook func(email string) error) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookOnRefresh = hook + } +} + +func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookAuthenticateClient = hook + } +} + +// WithLogging is optional, but will log some HTTP calls made to the IDP. +func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) { + return func(f *FakeIDP) { + f.logger = slogtest.Make(t, options) + } +} + +// WithStaticUserInfo is optional, but will return the same user info for +// every user on the /userinfo endpoint. +func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookUserInfo = func(_ string) jwt.MapClaims { + return info + } + } +} + +func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) { + return func(f *FakeIDP) { + f.hookUserInfo = userInfoFunc + } +} + +// WithServing makes the IDP run an actual http server. +func WithServing() func(*FakeIDP) { + return func(f *FakeIDP) { + f.serve = true + } +} + +func WithIssuer(issuer string) func(*FakeIDP) { + return func(f *FakeIDP) { + f.issuer = issuer + } +} + +const ( + // nolint:gosec // It thinks this is a secret lol + tokenPath = "/oauth2/token" + authorizePath = "/oauth2/authorize" + keysPath = "/oauth2/keys" + userInfoPath = "/oauth2/userinfo" +) + +func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP { + t.Helper() + + block, _ := pem.Decode([]byte(testRSAPrivateKey)) + pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + require.NoError(t, err) + + idp := &FakeIDP{ + key: pkey, + clientID: uuid.NewString(), + clientSecret: uuid.NewString(), + logger: slog.Make(), + codeToStateMap: syncmap.New[string, string](), + accessTokens: syncmap.New[string, string](), + refreshTokens: syncmap.New[string, string](), + refreshTokensUsed: syncmap.New[string, bool](), + stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](), + refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](), + hookOnRefresh: func(_ string) error { return nil }, + hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} }, + hookValidRedirectURL: func(redirectURL string) error { return nil }, + } + + for _, opt := range opts { + opt(idp) + } + + if idp.issuer == "" { + idp.issuer = "https://coder.com" + } + + idp.handler = idp.httpHandler(t) + idp.updateIssuerURL(t, idp.issuer) + if idp.serve { + idp.realServer(t) + } + + return idp +} + +func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) { + t.Helper() + + u, err := url.Parse(issuer) + require.NoError(t, err, "invalid issuer URL") + + f.issuer = issuer + // providerJSON is the JSON representation of the OpenID Connect provider + // These are all the urls that the IDP will respond to. + f.provider = providerJSON{ + Issuer: issuer, + AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(), + TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(), + JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(), + UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(), + Algorithms: []string{ + "RS256", + }, + } +} + +// realServer turns the FakeIDP into a real http server. +func (f *FakeIDP) realServer(t testing.TB) *httptest.Server { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + srv := httptest.NewUnstartedServer(f.handler) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + srv.Start() + t.Cleanup(srv.CloseClientConnections) + t.Cleanup(srv.Close) + t.Cleanup(cancel) + + f.updateIssuerURL(t, srv.URL) + return srv +} + +// Login does the full OIDC flow starting at the "LoginButton". +// The client argument is just to get the URL of the Coder instance. +// +// The client passed in is just to get the url of the Coder instance. +// The actual client that is used is 100% unauthenticated and fresh. +func (f *FakeIDP) Login(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { + t.Helper() + + client, resp := f.AttemptLogin(t, client, idTokenClaims, opts...) + require.Equal(t, http.StatusOK, resp.StatusCode, "client failed to login") + return client, resp +} + +func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { + t.Helper() + var err error + + cli := f.HTTPClient(client.HTTPClient) + shallowCpyCli := *cli + + if shallowCpyCli.Jar == nil { + shallowCpyCli.Jar, err = cookiejar.New(nil) + require.NoError(t, err, "failed to create cookie jar") + } + + unauthenticated := codersdk.New(client.URL) + unauthenticated.HTTPClient = &shallowCpyCli + + return f.LoginWithClient(t, unauthenticated, idTokenClaims, opts...) +} + +// LoginWithClient reuses the context of the passed in client. This means the same +// cookies will be used. This should be an unauthenticated client in most cases. +// +// This is a niche case, but it is needed for testing ConvertLoginType. +func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) { + t.Helper() + + coderOauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback") + require.NoError(t, err) + f.SetRedirect(t, coderOauthURL.String()) + + cli := f.HTTPClient(client.HTTPClient) + cli.CheckRedirect = func(req *http.Request, via []*http.Request) error { + // Store the idTokenClaims to the specific state request. This ties + // the claims 1:1 with a given authentication flow. + state := req.URL.Query().Get("state") + f.stateToIDTokenClaims.Store(state, idTokenClaims) + return nil + } + + req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil) + require.NoError(t, err) + if cli.Jar == nil { + cli.Jar, err = cookiejar.New(nil) + require.NoError(t, err, "failed to create cookie jar") + } + + for _, opt := range opts { + opt(req) + } + + res, err := cli.Do(req) + require.NoError(t, err) + + // If the coder session token exists, return the new authed client! + var user *codersdk.Client + cookies := cli.Jar.Cookies(client.URL) + for _, cookie := range cookies { + if cookie.Name == codersdk.SessionTokenCookie { + user = codersdk.New(client.URL) + user.SetSessionToken(cookie.Value) + } + } + + t.Cleanup(func() { + if res.Body != nil { + _ = res.Body.Close() + } + }) + + return user, res +} + +// OIDCCallback will emulate the IDP redirecting back to the Coder callback. +// This is helpful if no Coderd exists because the IDP needs to redirect to +// something. +// Essentially this is used to fake the Coderd side of the exchange. +// The flow starts at the user hitting the OIDC login page. +func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) { + t.Helper() + if f.serve { + panic("cannot use OIDCCallback with WithServing. This is only for the in memory usage") + } + + f.stateToIDTokenClaims.Store(state, idTokenClaims) + + cli := f.HTTPClient(nil) + u := f.cfg.AuthCodeURL(state) + req, err := http.NewRequest("GET", u, nil) + require.NoError(t, err) + + resp, err := cli.Do(req.WithContext(context.Background())) + require.NoError(t, err) + + t.Cleanup(func() { + if resp.Body != nil { + _ = resp.Body.Close() + } + }) + return resp, nil +} + +type providerJSON struct { + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + JWKSURL string `json:"jwks_uri"` + UserInfoURL string `json:"userinfo_endpoint"` + Algorithms []string `json:"id_token_signing_alg_values_supported"` +} + +// newCode enforces the code exchanged is actually a valid code +// created by the IDP. +func (f *FakeIDP) newCode(state string) string { + code := uuid.NewString() + f.codeToStateMap.Store(code, state) + return code +} + +// newToken enforces the access token exchanged is actually a valid access token +// created by the IDP. +func (f *FakeIDP) newToken(email string) string { + accessToken := uuid.NewString() + f.accessTokens.Store(accessToken, email) + return accessToken +} + +func (f *FakeIDP) newRefreshTokens(email string) string { + refreshToken := uuid.NewString() + f.refreshTokens.Store(refreshToken, email) + return refreshToken +} + +// authenticateBearerTokenRequest enforces the access token is valid. +func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) { + t.Helper() + + auth := req.Header.Get("Authorization") + token := strings.TrimPrefix(auth, "Bearer ") + _, ok := f.accessTokens.Load(token) + if !ok { + return "", xerrors.New("invalid access token") + } + return token, nil +} + +// authenticateOIDCClientRequest enforces the client_id and client_secret are valid. +func (f *FakeIDP) authenticateOIDCClientRequest(t testing.TB, req *http.Request) (url.Values, error) { + t.Helper() + + if f.hookAuthenticateClient != nil { + return f.hookAuthenticateClient(t, req) + } + + data, err := io.ReadAll(req.Body) + if !assert.NoError(t, err, "read token request body") { + return nil, xerrors.Errorf("authenticate request, read body: %w", err) + } + values, err := url.ParseQuery(string(data)) + if !assert.NoError(t, err, "parse token request values") { + return nil, xerrors.New("invalid token request") + } + + if !assert.Equal(t, f.clientID, values.Get("client_id"), "client_id mismatch") { + return nil, xerrors.New("client_id mismatch") + } + + if !assert.Equal(t, f.clientSecret, values.Get("client_secret"), "client_secret mismatch") { + return nil, xerrors.New("client_secret mismatch") + } + + return values, nil +} + +// encodeClaims is a helper func to convert claims to a valid JWT. +func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string { + t.Helper() + + if _, ok := claims["exp"]; !ok { + claims["exp"] = time.Now().Add(time.Hour).UnixMilli() + } + + if _, ok := claims["aud"]; !ok { + claims["aud"] = f.clientID + } + + if _, ok := claims["iss"]; !ok { + claims["iss"] = f.issuer + } + + signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key) + require.NoError(t, err) + + return signed +} + +// httpHandler is the IDP http server. +func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { + t.Helper() + + mux := chi.NewMux() + // This endpoint is required to initialize the OIDC provider. + // It is used to get the OIDC configuration. + mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) { + f.logger.Info(r.Context(), "http OIDC config", slog.F("url", r.URL.String())) + + _ = json.NewEncoder(rw).Encode(f.provider) + }) + + // Authorize is called when the user is redirected to the IDP to login. + // This is the browser hitting the IDP and the user logging into Google or + // w/e and clicking "Allow". They will be redirected back to the redirect + // when this is done. + mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + f.logger.Info(r.Context(), "http call authorize", slog.F("url", r.URL.String())) + + clientID := r.URL.Query().Get("client_id") + if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") { + http.Error(rw, "invalid client_id", http.StatusBadRequest) + return + } + + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + + scope := r.URL.Query().Get("scope") + assert.NotEmpty(t, scope, "scope is empty") + + responseType := r.URL.Query().Get("response_type") + switch responseType { + case "code": + case "token": + t.Errorf("response_type %q not supported", responseType) + http.Error(rw, "invalid response_type", http.StatusBadRequest) + return + default: + t.Errorf("unexpected response_type %q", responseType) + http.Error(rw, "invalid response_type", http.StatusBadRequest) + return + } + + err := f.hookValidRedirectURL(redirectURI) + if err != nil { + t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error()) + http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest) + return + } + + ru, err := url.Parse(redirectURI) + if err != nil { + t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error()) + http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest) + return + } + + q := ru.Query() + q.Set("state", state) + q.Set("code", f.newCode(state)) + ru.RawQuery = q.Encode() + + http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect) + })) + + mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + values, err := f.authenticateOIDCClientRequest(t, r) + f.logger.Info(r.Context(), "http idp call token", + slog.Error(err), + slog.F("values", values.Encode()), + ) + if err != nil { + http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest) + return + } + getEmail := func(claims jwt.MapClaims) string { + email, ok := claims["email"] + if !ok { + return "unknown" + } + emailStr, ok := email.(string) + if !ok { + return "wrong-type" + } + return emailStr + } + + var claims jwt.MapClaims + switch values.Get("grant_type") { + case "authorization_code": + code := values.Get("code") + if !assert.NotEmpty(t, code, "code is empty") { + http.Error(rw, "invalid code", http.StatusBadRequest) + return + } + stateStr, ok := f.codeToStateMap.Load(code) + if !assert.True(t, ok, "invalid code") { + http.Error(rw, "invalid code", http.StatusBadRequest) + return + } + // Always invalidate the code after it is used. + f.codeToStateMap.Delete(code) + + idTokenClaims, ok := f.stateToIDTokenClaims.Load(stateStr) + if !ok { + t.Errorf("missing id token claims") + http.Error(rw, "missing id token claims", http.StatusBadRequest) + return + } + claims = idTokenClaims + case "refresh_token": + refreshToken := values.Get("refresh_token") + if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") { + http.Error(rw, "invalid refresh_token", http.StatusBadRequest) + return + } + + _, ok := f.refreshTokens.Load(refreshToken) + if !assert.True(t, ok, "invalid refresh_token") { + http.Error(rw, "invalid refresh_token", http.StatusBadRequest) + return + } + + idTokenClaims, ok := f.refreshIDTokenClaims.Load(refreshToken) + if !ok { + t.Errorf("missing id token claims in refresh") + http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest) + return + } + + claims = idTokenClaims + err := f.hookOnRefresh(getEmail(claims)) + if err != nil { + http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest) + return + } + + f.refreshTokensUsed.Store(refreshToken, true) + // Always invalidate the refresh token after it is used. + f.refreshTokens.Delete(refreshToken) + default: + t.Errorf("unexpected grant_type %q", values.Get("grant_type")) + http.Error(rw, "invalid grant_type", http.StatusBadRequest) + return + } + + exp := time.Now().Add(time.Minute * 5) + claims["exp"] = exp.UnixMilli() + email := getEmail(claims) + refreshToken := f.newRefreshTokens(email) + token := map[string]interface{}{ + "access_token": f.newToken(email), + "refresh_token": refreshToken, + "token_type": "Bearer", + "expires_in": int64((time.Minute * 5).Seconds()), + "id_token": f.encodeClaims(t, claims), + } + // Store the claims for the next refresh + f.refreshIDTokenClaims.Store(refreshToken, claims) + + rw.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(rw).Encode(token) + })) + + mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + token, err := f.authenticateBearerTokenRequest(t, r) + f.logger.Info(r.Context(), "http call idp user info", + slog.Error(err), + slog.F("url", r.URL.String()), + ) + if err != nil { + http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest) + return + } + + email, ok := f.accessTokens.Load(token) + if !ok { + t.Errorf("access token user for user_info has no email to indicate which user") + http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest) + return + } + _ = json.NewEncoder(rw).Encode(f.hookUserInfo(email)) + })) + + mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + f.logger.Info(r.Context(), "http call idp /keys") + set := jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: f.key.Public(), + KeyID: "test-key", + Algorithm: "RSA", + }, + }, + } + _ = json.NewEncoder(rw).Encode(set) + })) + + mux.NotFound(func(rw http.ResponseWriter, r *http.Request) { + f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path)) + t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path) + }) + + return mux +} + +// HTTPClient does nothing if IsServing is used. +// +// If IsServing is not used, then it will return a client that will make requests +// to the IDP all in memory. If a request is not to the IDP, then the passed in +// client will be used. If no client is passed in, then any regular network +// requests will fail. +func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client { + if f.serve { + if rest == nil || rest.Transport == nil { + return &http.Client{} + } + return rest + } + + var jar http.CookieJar + if rest != nil { + jar = rest.Jar + } + return &http.Client{ + Jar: jar, + Transport: fakeRoundTripper{ + roundTrip: func(req *http.Request) (*http.Response, error) { + u, _ := url.Parse(f.issuer) + if req.URL.Host != u.Host { + if f.fakeCoderd != nil { + return f.fakeCoderd(req) + } + if rest == nil || rest.Transport == nil { + return nil, fmt.Errorf("unexpected network request to %q", req.URL.Host) + } + return rest.Transport.RoundTrip(req) + } + resp := httptest.NewRecorder() + f.handler.ServeHTTP(resp, req) + return resp.Result(), nil + }, + }, + } +} + +// RefreshUsed returns if the refresh token has been used. All refresh tokens +// can only be used once, then they are deleted. +func (f *FakeIDP) RefreshUsed(refreshToken string) bool { + used, _ := f.refreshTokensUsed.Load(refreshToken) + return used +} + +// UpdateRefreshClaims allows the caller to change what claims are returned +// for a given refresh token. By default, all refreshes use the same claims as +// the original IDToken issuance. +func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) { + f.refreshIDTokenClaims.Store(refreshToken, claims) +} + +// SetRedirect is required for the IDP to know where to redirect and call +// Coderd. +func (f *FakeIDP) SetRedirect(t testing.TB, u string) { + t.Helper() + + f.cfg.RedirectURL = u +} + +// SetCoderdCallback is optional and only works if not using the IsServing. +// It will setup a fake "Coderd" for the IDP to call when the IDP redirects +// back after authenticating. +func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) { + if f.serve { + panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'") + } + f.fakeCoderd = callback +} + +func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) { + f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) { + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + return resp.Result(), nil + }) +} + +// OIDCConfig returns the OIDC config to use for Coderd. +func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig { + t.Helper() + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + + oauthCfg := &oauth2.Config{ + ClientID: f.clientID, + ClientSecret: f.clientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: f.provider.AuthURL, + TokenURL: f.provider.TokenURL, + AuthStyle: oauth2.AuthStyleInParams, + }, + // If the user is using a real network request, they will need to do + // 'fake.SetRedirect()' + RedirectURL: "https://redirect.com", + Scopes: scopes, + } + + ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil)) + p, err := oidc.NewProvider(ctx, f.provider.Issuer) + require.NoError(t, err, "failed to create OIDC provider") + cfg := &coderd.OIDCConfig{ + OAuth2Config: oauthCfg, + Provider: p, + Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{f.key.Public()}, + }, &oidc.Config{ + ClientID: oauthCfg.ClientID, + SupportedSigningAlgs: []string{ + "RS256", + }, + // Todo: add support for Now() + }), + UsernameField: "preferred_username", + EmailField: "email", + AuthURLParams: map[string]string{"access_type": "offline"}, + } + + for _, opt := range opts { + if opt == nil { + continue + } + opt(cfg) + } + + f.cfg = oauthCfg + + return cfg +} + +type fakeRoundTripper struct { + roundTrip func(req *http.Request) (*http.Response, error) +} + +func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return f.roundTrip(req) +} + +const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS +v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92 +5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB +AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0 +wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe +rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB +w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9 +pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8 +YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR +Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a +d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf +sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u +QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8 +-----END RSA PRIVATE KEY-----` diff --git a/coderd/coderdtest/oidctest/idp_test.go b/coderd/coderdtest/oidctest/idp_test.go new file mode 100644 index 0000000000..0dc1149d93 --- /dev/null +++ b/coderd/coderdtest/oidctest/idp_test.go @@ -0,0 +1,72 @@ +package oidctest_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/assert" + + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +// TestFakeIDPBasicFlow tests the basic flow of the fake IDP. +// It is done all in memory with no actual network requests. +// nolint:bodyclose +func TestFakeIDPBasicFlow(t *testing.T) { + t.Parallel() + + fake := oidctest.NewFakeIDP(t, + oidctest.WithLogging(t, nil), + ) + + var handler http.Handler + srv := httptest.NewServer(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler.ServeHTTP(w, r) + }))) + defer srv.Close() + + cfg := fake.OIDCConfig(t, nil) + cli := fake.HTTPClient(nil) + ctx := oidc.ClientContext(context.Background(), cli) + + const expectedState = "random-state" + var token *oauth2.Token + // This is the Coder callback using an actual network request. + fake.SetCoderdCallbackHandler(func(w http.ResponseWriter, r *http.Request) { + // Emulate OIDC flow + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + assert.Equal(t, expectedState, state, "state mismatch") + + oauthToken, err := cfg.Exchange(ctx, code) + if assert.NoError(t, err, "failed to exchange code") { + assert.NotEmpty(t, oauthToken.AccessToken, "access token is empty") + assert.NotEmpty(t, oauthToken.RefreshToken, "refresh token is empty") + } + token = oauthToken + }) + + resp, err := fake.OIDCCallback(t, expectedState, jwt.MapClaims{}) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Test the user info + _, err = cfg.Provider.UserInfo(ctx, oauth2.StaticTokenSource(token)) + require.NoError(t, err) + + // Now test it can refresh + refreshed, err := cfg.TokenSource(ctx, &oauth2.Token{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + Expiry: time.Now().Add(time.Minute * -1), + }).Token() + require.NoError(t, err, "failed to refresh token") + require.NotEmpty(t, refreshed.AccessToken, "access token is empty on refresh") +} diff --git a/coderd/oauthpki/oidcpki.go b/coderd/oauthpki/oidcpki.go index d5bc625336..c44d130e5b 100644 --- a/coderd/oauthpki/oidcpki.go +++ b/coderd/oauthpki/oidcpki.go @@ -215,7 +215,10 @@ func (src *jwtTokenSource) Token() (*oauth2.Token, error) { } var tokenRes struct { - oauth2.Token + AccessToken string `json:"access_token"` + TokenType string `json:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + // Extra fields returned by the refresh that are needed IDToken string `json:"id_token"` ExpiresIn int64 `json:"expires_in"` // relative seconds from now diff --git a/coderd/oauthpki/okidcpki_test.go b/coderd/oauthpki/okidcpki_test.go index 27593607f2..ab6e3e3a08 100644 --- a/coderd/oauthpki/okidcpki_test.go +++ b/coderd/oauthpki/okidcpki_test.go @@ -12,12 +12,15 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/oauthpki" "github.com/coder/coder/v2/testutil" ) @@ -123,6 +126,58 @@ func TestAzureADPKIOIDC(t *testing.T) { require.Error(t, err, "error expected") } +// TestAzureAKPKIWithCoderd uses a fake IDP and a real Coderd to test PKI auth. +// nolint:bodyclose +func TestAzureAKPKIWithCoderd(t *testing.T) { + t.Parallel() + + scopes := []string{"openid", "email", "profile", "offline_access"} + fake := oidctest.NewFakeIDP(t, + oidctest.WithIssuer("https://login.microsoftonline.com/fake_app"), + oidctest.WithCustomClientAuth(func(t testing.TB, req *http.Request) (url.Values, error) { + values := assertJWTAuth(t, req) + if values == nil { + return nil, xerrors.New("authorizatin failed in request") + } + return values, nil + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) + + oauthCfg := cfg.OAuth2Config.(*oauth2.Config) + // Create the oauthpki config + pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{ + ClientID: oauthCfg.ClientID, + TokenURL: oauthCfg.Endpoint.TokenURL, + Scopes: scopes, + PemEncodedKey: []byte(testClientKey), + PemEncodedCert: []byte(testClientCert), + Config: oauthCfg, + }) + require.NoError(t, err) + cfg.OAuth2Config = pki + + owner, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + OIDCConfig: cfg, + }) + + // Create a user and login + const email = "alice@coder.com" + claims := jwt.MapClaims{ + "email": email, + } + helper := oidctest.NewLoginHelper(owner, fake) + user, _ := helper.Login(t, claims) + + // Try refreshing the token more than once. + for i := 0; i < 2; i++ { + helper.ForceRefresh(t, api.Database, user, claims) + } +} + // TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure // AD instance and saving them to replay, removing some details. // The reason this is done is that this is the only way to assert values @@ -269,7 +324,7 @@ func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // assertJWTAuth will assert the basic JWT auth assertions. It will return the // url.Values from the request body for any additional assertions to be made. -func assertJWTAuth(t *testing.T, r *http.Request) url.Values { +func assertJWTAuth(t testing.TB, r *http.Request) url.Values { body, err := io.ReadAll(r.Body) if !assert.NoError(t, err) { return nil diff --git a/coderd/userauth_test.go b/coderd/userauth_test.go index 10bf7ecf67..1f37a0721a 100644 --- a/coderd/userauth_test.go +++ b/coderd/userauth_test.go @@ -4,28 +4,25 @@ import ( "context" "crypto" "fmt" - "io" "net/http" "net/http/cookiejar" + "net/url" "strings" "testing" - "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v4" "github.com/google/go-github/v43/github" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/oauth2" "golang.org/x/xerrors" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/codersdk" @@ -35,85 +32,42 @@ import ( // This test specifically tests logging in with OIDC when an expired // OIDC session token exists. // The token refreshing should not happen since we are reauthenticating. +// nolint:bodyclose func TestOIDCOauthLoginWithExisting(t *testing.T) { t.Parallel() - conf := coderdtest.NewOIDCConfig(t, "", - // Provide a refresh token so we use the refresh token flow - coderdtest.WithRefreshToken("refresh_token"), - // We need to set the expire in the future for the first api calls. - coderdtest.WithTokenExpires(func() time.Time { - return time.Now().Add(time.Hour).UTC() - }), - // No refresh should actually happen in this test. - coderdtest.WithTokenSource(func() (*oauth2.Token, error) { - return nil, xerrors.New("token should not require refresh") + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") }), + oidctest.WithServing(), ) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - auditor := audit.NewMock() + + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.IgnoreUserInfo = true + }) + + client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ + OIDCConfig: cfg, + }) + const username = "alice" claims := jwt.MapClaims{ "email": "alice@coder.com", "email_verified": true, "preferred_username": username, } - config := conf.OIDCConfig(t, claims) - - config.AllowSignups = true - config.IgnoreUserInfo = true - client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ - Auditor: auditor, - OIDCConfig: config, - Logger: &logger, - }) + helper := oidctest.NewLoginHelper(client, fake) // Signup alice - resp := oidcCallback(t, client, conf.EncodeClaims(t, claims)) - // Set the client to use this OIDC context - authCookie := authCookieValue(resp.Cookies()) - client.SetSessionToken(authCookie) - _ = resp.Body.Close() + userClient, _ := helper.Login(t, claims) - ctx := testutil.Context(t, testutil.WaitLong) - // Verify the user and oauth link - user, err := client.User(ctx, "me") - require.NoError(t, err) - require.Equal(t, username, user.Username) + // Expire the link. This will force the client to refresh the token. + helper.ExpireOauthToken(t, api.Database, userClient) - // nolint:gocritic - link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{ - UserID: user.ID, - LoginType: database.LoginType(user.LoginType), - }) - require.NoError(t, err, "failed to get user link") - - // Expire the link - // nolint:gocritic - _, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{ - OAuthAccessToken: link.OAuthAccessToken, - OAuthRefreshToken: link.OAuthRefreshToken, - OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(), - UserID: link.UserID, - LoginType: link.LoginType, - }) - require.NoError(t, err, "failed to update user link") - - // Log in again with OIDC - loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) { - req.AddCookie(&http.Cookie{ - Name: codersdk.SessionTokenCookie, - Value: authCookie, - Path: "/", - }) - }) - require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode) - _ = loginAgain.Body.Close() - - // Try to use new login - client.SetSessionToken(authCookieValue(resp.Cookies())) - _, err = client.User(ctx, "me") - require.NoError(t, err, "use new session") + // Instead of refreshing, just log in again. + helper.Login(t, claims) } func TestUserLogin(t *testing.T) { @@ -660,7 +614,7 @@ func TestUserOIDC(t *testing.T) { "email": "kyle@kwc.io", }, AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, Username: "kyle", }, { Name: "EmailNotVerified", @@ -685,7 +639,7 @@ func TestUserOIDC(t *testing.T) { "email_verified": false, }, AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, Username: "kyle", IgnoreEmailVerified: true, }, { @@ -709,7 +663,7 @@ func TestUserOIDC(t *testing.T) { EmailDomain: []string{ "kwc.io", }, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "EmptyClaims", IDTokenClaims: jwt.MapClaims{}, @@ -730,7 +684,7 @@ func TestUserOIDC(t *testing.T) { }, Username: "kyle", AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "UsernameFromClaims", IDTokenClaims: jwt.MapClaims{ @@ -740,7 +694,7 @@ func TestUserOIDC(t *testing.T) { }, Username: "hotdog", AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { // Services like Okta return the email as the username: // https://developer.okta.com/docs/reference/api/oidc/#base-claims-always-present @@ -752,7 +706,7 @@ func TestUserOIDC(t *testing.T) { }, Username: "kyle", AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { // See: https://github.com/coder/coder/issues/4472 Name: "UsernameIsEmail", @@ -761,7 +715,7 @@ func TestUserOIDC(t *testing.T) { }, Username: "kyle", AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "WithPicture", IDTokenClaims: jwt.MapClaims{ @@ -773,7 +727,7 @@ func TestUserOIDC(t *testing.T) { Username: "kyle", AllowSignups: true, AvatarURL: "/example.png", - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "WithUserInfoClaims", IDTokenClaims: jwt.MapClaims{ @@ -787,7 +741,7 @@ func TestUserOIDC(t *testing.T) { Username: "potato", AllowSignups: true, AvatarURL: "/example.png", - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "GroupsDoesNothing", IDTokenClaims: jwt.MapClaims{ @@ -795,7 +749,7 @@ func TestUserOIDC(t *testing.T) { "groups": []string{"pingpong"}, }, AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "UserInfoOverridesIDTokenClaims", IDTokenClaims: jwt.MapClaims{ @@ -810,7 +764,7 @@ func TestUserOIDC(t *testing.T) { Username: "user", AllowSignups: true, IgnoreEmailVerified: false, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }, { Name: "InvalidUserInfo", IDTokenClaims: jwt.MapClaims{ @@ -837,36 +791,41 @@ func TestUserOIDC(t *testing.T) { Username: "user", IgnoreUserInfo: true, AllowSignups: true, - StatusCode: http.StatusTemporaryRedirect, + StatusCode: http.StatusOK, }} { tc := tc t.Run(tc.Name, func(t *testing.T) { t.Parallel() + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + oidctest.WithStaticUserInfo(tc.UserInfoClaims), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = tc.AllowSignups + cfg.EmailDomain = tc.EmailDomain + cfg.IgnoreEmailVerified = tc.IgnoreEmailVerified + cfg.IgnoreUserInfo = tc.IgnoreUserInfo + }) + auditor := audit.NewMock() - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, tc.UserInfoClaims) - config.AllowSignups = tc.AllowSignups - config.EmailDomain = tc.EmailDomain - config.IgnoreEmailVerified = tc.IgnoreEmailVerified - config.IgnoreUserInfo = tc.IgnoreUserInfo - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - client := coderdtest.New(t, &coderdtest.Options{ + owner := coderdtest.New(t, &coderdtest.Options{ Auditor: auditor, - OIDCConfig: config, + OIDCConfig: cfg, Logger: &logger, }) numLogs := len(auditor.AuditLogs()) - resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.IDTokenClaims)) + client, resp := fake.AttemptLogin(t, owner, tc.IDTokenClaims) numLogs++ // add an audit log for login - assert.Equal(t, tc.StatusCode, resp.StatusCode) + require.Equal(t, tc.StatusCode, resp.StatusCode) ctx := testutil.Context(t, testutil.WaitLong) if tc.Username != "" { - client.SetSessionToken(authCookieValue(resp.Cookies())) user, err := client.User(ctx, "me") require.NoError(t, err) require.Equal(t, tc.Username, user.Username) @@ -877,7 +836,6 @@ func TestUserOIDC(t *testing.T) { } if tc.AvatarURL != "" { - client.SetSessionToken(authCookieValue(resp.Cookies())) user, err := client.User(ctx, "me") require.NoError(t, err) require.Equal(t, tc.AvatarURL, user.AvatarURL) @@ -890,26 +848,29 @@ func TestUserOIDC(t *testing.T) { t.Run("OIDCConvert", func(t *testing.T) { t.Parallel() + auditor := audit.NewMock() - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, nil) - config.AllowSignups = true - - cfg := coderdtest.DeploymentValues(t) - client := coderdtest.New(t, &coderdtest.Options{ - Auditor: auditor, - OIDCConfig: config, - DeploymentValues: cfg, + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true }) - owner := coderdtest.CreateFirstUser(t, client) + client := coderdtest.New(t, &coderdtest.Options{ + Auditor: auditor, + OIDCConfig: cfg, + }) + + owner := coderdtest.CreateFirstUser(t, client) user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) - code := conf.EncodeClaims(t, jwt.MapClaims{ + claims := jwt.MapClaims{ "email": userData.Email, - }) - + } var err error user.HTTPClient.Jar, err = cookiejar.New(nil) require.NoError(t, err) @@ -921,52 +882,58 @@ func TestUserOIDC(t *testing.T) { }) require.NoError(t, err) - resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + fake.LoginWithClient(t, user, claims, func(r *http.Request) { + r.URL.RawQuery = url.Values{ + "oidc_merge_state": {convertResponse.StateString}, + }.Encode() + r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken()) + cookies := user.HTTPClient.Jar.Cookies(r.URL) + for _, cookie := range cookies { + r.AddCookie(cookie) + } + }) }) t.Run("AlternateUsername", func(t *testing.T) { t.Parallel() auditor := audit.NewMock() - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, nil) - config.AllowSignups = true + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) client := coderdtest.New(t, &coderdtest.Options{ Auditor: auditor, - OIDCConfig: config, + OIDCConfig: cfg, }) - numLogs := len(auditor.AuditLogs()) - code := conf.EncodeClaims(t, jwt.MapClaims{ + numLogs := len(auditor.AuditLogs()) + claims := jwt.MapClaims{ "email": "jon@coder.com", - }) - resp := oidcCallback(t, client, code) + } + + userClient, _ := fake.Login(t, client, claims) numLogs++ // add an audit log for login - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - ctx := testutil.Context(t, testutil.WaitLong) - - client.SetSessionToken(authCookieValue(resp.Cookies())) - user, err := client.User(ctx, "me") + user, err := userClient.User(ctx, "me") require.NoError(t, err) require.Equal(t, "jon", user.Username) // Pass a different subject field so that we prompt creating a - // new user. - code = conf.EncodeClaims(t, jwt.MapClaims{ + // new user + userClient, _ = fake.Login(t, client, jwt.MapClaims{ "email": "jon@example2.com", "sub": "diff", }) - resp = oidcCallback(t, client, code) numLogs++ // add an audit log for login - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - client.SetSessionToken(authCookieValue(resp.Cookies())) - user, err = client.User(ctx, "me") + user, err = userClient.User(ctx, "me") require.NoError(t, err) require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-") @@ -977,45 +944,62 @@ func TestUserOIDC(t *testing.T) { t.Run("Disabled", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) - resp := oidcCallback(t, client, "asdf") + oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback") + require.NoError(t, err) + + req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil) + require.NoError(t, err) + resp, err := client.HTTPClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) t.Run("NoIDToken", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{ - OIDCConfig: &coderd.OIDCConfig{ - OAuth2Config: &testutil.OAuth2Config{}, - }, + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true }) - resp := oidcCallback(t, client, "asdf") + client := coderdtest.New(t, &coderdtest.Options{ + OIDCConfig: cfg, + }) + + _, resp := fake.AttemptLogin(t, client, jwt.MapClaims{}) require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) t.Run("BadVerify", func(t *testing.T) { t.Parallel() - verifier := oidc.NewVerifier("", &oidc.StaticKeySet{ + badVerifier := oidc.NewVerifier("", &oidc.StaticKeySet{ PublicKeys: []crypto.PublicKey{}, }, &oidc.Config{}) - provider := &oidc.Provider{} + badProvider := &oidc.Provider{} - client := coderdtest.New(t, &coderdtest.Options{ - OIDCConfig: &coderd.OIDCConfig{ - OAuth2Config: &testutil.OAuth2Config{ - Token: (&oauth2.Token{ - AccessToken: "token", - }).WithExtra(map[string]interface{}{ - "id_token": "invalid", - }), - }, - Provider: provider, - Verifier: verifier, - }, + fake := oidctest.NewFakeIDP(t, + oidctest.WithRefreshHook(func(_ string) error { + return xerrors.New("refreshing token should never occur") + }), + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.Provider = badProvider + cfg.Verifier = badVerifier }) - resp := oidcCallback(t, client, "asdf") + client := coderdtest.New(t, &coderdtest.Options{ + OIDCConfig: cfg, + }) + _, resp := fake.AttemptLogin(t, client, jwt.MapClaims{}) require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) } @@ -1146,36 +1130,6 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response { return res } -func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response { - return oidcCallbackWithState(t, client, code, "somestate", nil) -} - -func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response { - t.Helper() - - client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=%s", code, state)) - require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil) - require.NoError(t, err) - req.AddCookie(&http.Cookie{ - Name: codersdk.OAuth2StateCookie, - Value: state, - }) - if modify != nil { - modify(req) - } - res, err := client.HTTPClient.Do(req) - require.NoError(t, err) - defer res.Body.Close() - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - t.Log(string(data)) - return res -} - func i64ptr(i int64) *int64 { return &i } diff --git a/coderd/users_test.go b/coderd/users_test.go index c36b4fad98..60e6ddb82a 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -8,7 +8,10 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + + "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -403,6 +406,7 @@ func TestPostLogout(t *testing.T) { }) } +// nolint:bodyclose func TestPostUsers(t *testing.T) { t.Parallel() t.Run("NoAuth", func(t *testing.T) { @@ -593,15 +597,15 @@ func TestPostUsers(t *testing.T) { t.Run("CreateOIDCLoginType", func(t *testing.T) { t.Parallel() email := "another@user.org" - conf := coderdtest.NewOIDCConfig(t, "") - config := conf.OIDCConfig(t, jwt.MapClaims{ - "email": email, + fake := oidctest.NewFakeIDP(t, + oidctest.WithServing(), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true }) - config.AllowSignups = false - config.IgnoreUserInfo = true client := coderdtest.New(t, &coderdtest.Options{ - OIDCConfig: config, + OIDCConfig: cfg, }) first := coderdtest.CreateFirstUser(t, client) @@ -618,15 +622,9 @@ func TestPostUsers(t *testing.T) { require.NoError(t, err) // Try to log in with OIDC. - userClient := codersdk.New(client.URL) - resp := oidcCallback(t, userClient, conf.EncodeClaims(t, jwt.MapClaims{ + userClient, _ := fake.Login(t, client, jwt.MapClaims{ "email": email, - })) - require.Equal(t, resp.StatusCode, http.StatusTemporaryRedirect) - // Set the client to use this OIDC context - authCookie := authCookieValue(resp.Cookies()) - userClient.SetSessionToken(authCookie) - _ = resp.Body.Close() + }) found, err := userClient.User(ctx, "me") require.NoError(t, err) diff --git a/coderd/util/syncmap/map.go b/coderd/util/syncmap/map.go new file mode 100644 index 0000000000..d245973efa --- /dev/null +++ b/coderd/util/syncmap/map.go @@ -0,0 +1,77 @@ +package syncmap + +import "sync" + +// Map is a type safe sync.Map +type Map[K, V any] struct { + m sync.Map +} + +func New[K, V any]() *Map[K, V] { + return &Map[K, V]{ + m: sync.Map{}, + } +} + +func (m *Map[K, V]) Store(k K, v V) { + m.m.Store(k, v) +} + +//nolint:forcetypeassert +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + v, ok := m.m.Load(key) + if !ok { + var empty V + return empty, false + } + return v.(V), ok +} + +func (m *Map[K, V]) Delete(key K) { + m.m.Delete(key) +} + +//nolint:forcetypeassert +func (m *Map[K, V]) LoadAndDelete(key K) (actual V, loaded bool) { + act, loaded := m.m.LoadAndDelete(key) + if !loaded { + var empty V + return empty, loaded + } + return act.(V), loaded +} + +//nolint:forcetypeassert +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + act, loaded := m.m.LoadOrStore(key, value) + if !loaded { + var empty V + return empty, loaded + } + return act.(V), loaded +} + +func (m *Map[K, V]) CompareAndSwap(key K, old V, new V) bool { + return m.m.CompareAndSwap(key, old, new) +} + +func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) { + return m.m.CompareAndDelete(key, old) +} + +//nolint:forcetypeassert +func (m *Map[K, V]) Swap(key K, value V) (previous any, loaded bool) { + previous, loaded = m.m.Swap(key, value) + if !loaded { + var empty V + return empty, loaded + } + return previous.(V), loaded +} + +//nolint:forcetypeassert +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value interface{}) bool { + return f(key.(K), value.(V)) + }) +} diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index d6f6db3cbe..8e76a36b1d 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -1,25 +1,22 @@ package coderd_test import ( - "context" - "fmt" - "io" "net/http" "regexp" "testing" - "github.com/golang-jwt/jwt" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" + "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" + coderden "github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/testutil" @@ -31,128 +28,123 @@ func TestUserOIDC(t *testing.T) { t.Run("RoleSync", func(t *testing.T) { t.Parallel() + // NoRoles is the "control group". It has claims with 0 roles + // assigned, and asserts that the user has no roles. t.Run("NoRoles", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - conf := coderdtest.NewOIDCConfig(t, "") - - oidcRoleName := "TemplateAuthor" - - config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) { - cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}} - }) - config.AllowSignups = true - config.UserRoleField = "roles" - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureUserRoleManagement: 1}, + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ + claims := jwt.MapClaims{ "email": "alice@coder.com", - })) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - user, err := client.User(ctx, "alice") - require.NoError(t, err) - - require.Len(t, user.Roles, 0) - roleNames := []string{} - require.ElementsMatch(t, roleNames, []string{}) + } + // Login a new client that signs up + client, resp := runner.Login(t, claims) + require.Equal(t, http.StatusOK, resp.StatusCode) + // User should be in 0 groups. + runner.AssertRoles(t, "alice", []string{}) + // Force a refresh, and assert nothing has changes + runner.ForceRefresh(t, client, claims) + runner.AssertRoles(t, "alice", []string{}) }) - t.Run("NewUserAndRemoveRoles", func(t *testing.T) { + // A user has some roles, then on an oauth refresh will lose said + // roles from an updated claim. + t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) { + // TODO: Implement new feature to update roles/groups on OIDC + // refresh tokens. https://github.com/coder/coder/issues/9312 + t.Skip("Refreshing tokens does not update roles :(") t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - conf := coderdtest.NewOIDCConfig(t, "") - - oidcRoleName := "TemplateAuthor" - - config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) { - cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}} - }) - config.AllowSignups = true - config.UserRoleField = "roles" - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureUserRoleManagement: 1}, + const oidcRoleName = "TemplateAuthor" + runner := setupOIDCTest(t, oidcTestConfig{ + Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}, + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" + cfg.UserRoleMapping = map[string][]string{ + oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}, + } }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ + // User starts with the owner role + client, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{"random", oidcRoleName, rbac.RoleOwner()}, - })) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - _ = resp.Body.Close() - user, err := client.User(ctx, "alice") - require.NoError(t, err) + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()}) - require.Len(t, user.Roles, 3) - roleNames := []string{user.Roles[0].Name, user.Roles[1].Name, user.Roles[2].Name} - require.ElementsMatch(t, roleNames, []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()}) - - // Now remove the roles with a new oidc login - resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ + // Now refresh the oauth, and check the roles are removed. + // Force a refresh, and assert nothing has changes + runner.ForceRefresh(t, client, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{"random"}, - })) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - _ = resp.Body.Close() - user, err = client.User(ctx, "alice") - require.NoError(t, err) - - require.Len(t, user.Roles, 0) + }) + runner.AssertRoles(t, "alice", []string{}) }) + + // A user has some roles, then on another oauth login will lose said + // roles from an updated claim. + t.Run("NewUserAndRemoveRolesOnReAuth", func(t *testing.T) { + t.Parallel() + + const oidcRoleName = "TemplateAuthor" + runner := setupOIDCTest(t, oidcTestConfig{ + Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}, + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" + cfg.UserRoleMapping = map[string][]string{ + oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}, + } + }, + }) + + // User starts with the owner role + _, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + "roles": []string{"random", oidcRoleName, rbac.RoleOwner()}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()}) + + // Now login with oauth again, and check the roles are removed. + _, resp = runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + "roles": []string{"random"}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + + runner.AssertRoles(t, "alice", []string{}) + }) + + // All manual role updates should fail when role sync is enabled. t.Run("BlockAssignRoles", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, jwt.MapClaims{}) - config.AllowSignups = true - config.UserRoleField = "roles" - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureUserRoleManagement: 1}, + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ + _, resp := runner.Login(t, jwt.MapClaims{ "email": "alice@coder.com", "roles": []string{}, - })) - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + }) + require.Equal(t, http.StatusOK, resp.StatusCode) // Try to manually update user roles, even though controlled by oidc // role sync. - _, err = client.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{ + ctx := testutil.Context(t, testutil.WaitShort) + _, err := runner.AdminClient.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{ Roles: []string{ rbac.RoleTemplateAdmin(), }, @@ -164,199 +156,211 @@ func TestUserOIDC(t *testing.T) { t.Run("Groups", func(t *testing.T) { t.Parallel() + + // Assigns does a simple test of assigning a user to a group based + // on the oidc claims. t.Run("Assigns", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - conf := coderdtest.NewOIDCConfig(t, "") - const groupClaim = "custom-groups" - config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) { - cfg.GroupField = groupClaim - }) - config.AllowSignups = true - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + const groupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - groupName := "bingbong" - group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{ + ctx := testutil.Context(t, testutil.WaitShort) + group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ Name: groupName, }) require.NoError(t, err) require.Len(t, group.Members, 0) - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", + _, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", groupClaim: []string{groupName}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 1) + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) }) + + // Tests the group mapping feature. t.Run("AssignsMapped", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitMedium) - conf := coderdtest.NewOIDCConfig(t, "") + const groupClaim = "custom-groups" - oidcGroupName := "pingpong" - coderGroupName := "bingbong" - - config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) { - cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName} - }) - config.AllowSignups = true - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + const oidcGroupName = "pingpong" + const coderGroupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim + cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName} }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{ + ctx := testutil.Context(t, testutil.WaitShort) + group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ Name: coderGroupName, }) require.NoError(t, err) require.Len(t, group.Members, 0) - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", - "groups": []string{oidcGroupName}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 1) + _, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{oidcGroupName}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{coderGroupName}) }) - t.Run("AddThenRemove", func(t *testing.T) { + // User is in a group, then on an oauth refresh will lose said + // group. + t.Run("AddThenRemoveOnRefresh", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - conf := coderdtest.NewOIDCConfig(t, "") + // TODO: Implement new feature to update roles/groups on OIDC + // refresh tokens. https://github.com/coder/coder/issues/9312 + t.Skip("Refreshing tokens does not update groups :(") - config := conf.OIDCConfig(t, jwt.MapClaims{}) - config.AllowSignups = true - - client, firstUser := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + const groupClaim = "custom-groups" + const groupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim }, }) - // Add some extra users/groups that should be asserted after. - // Adding this user as there was a bug that removing 1 user removed - // all users from the group. - _, extra := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) - groupName := "bingbong" - group, err := client.CreateGroup(ctx, firstUser.OrganizationID, codersdk.CreateGroupRequest{ + ctx := testutil.Context(t, testutil.WaitShort) + group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ Name: groupName, }) - require.NoError(t, err, "create group") + require.NoError(t, err) + require.Len(t, group.Members, 0) - group, err = client.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{ - AddUsers: []string{ - firstUser.UserID.String(), - extra.ID.String(), - }, + client, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{groupName}, }) - require.NoError(t, err, "patch group") - require.Len(t, group.Members, 2, "expect both members") + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) - // Now add OIDC user into the group - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", - "groups": []string{groupName}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 3) - - // Login to remove the OIDC user from the group - resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", - "groups": []string{}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 2) - var expected []uuid.UUID - for _, mem := range group.Members { - expected = append(expected, mem.ID) - } - require.ElementsMatchf(t, expected, []uuid.UUID{firstUser.UserID, extra.ID}, "expected members") + // Refresh without the group claim + runner.ForceRefresh(t, client, jwt.MapClaims{ + "email": "alice@coder.com", + }) + runner.AssertGroups(t, "alice", []string{}) }) + t.Run("AddThenRemoveOnReAuth", func(t *testing.T) { + t.Parallel() + + const groupClaim = "custom-groups" + const groupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim + }, + }) + + ctx := testutil.Context(t, testutil.WaitShort) + group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{ + Name: groupName, + }) + require.NoError(t, err) + require.Len(t, group.Members, 0) + + _, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{groupName}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) + + // Refresh without the group claim + _, resp = runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{}) + }) + + // Updating groups where the claimed group does not exist. t.Run("NoneMatch", func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, jwt.MapClaims{}) - config.AllowSignups = true - - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + const groupClaim = "custom-groups" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - - groupName := "bingbong" - group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{ - Name: groupName, + _, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{"not-exists"}, }) - require.NoError(t, err) - require.Len(t, group.Members, 0) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{}) + }) - resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{ - "email": "colin@coder.com", - "groups": []string{"coolin"}, - })) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + // Updating groups where the claimed group does not exist creates + // the group. + t.Run("AutoCreate", func(t *testing.T) { + t.Parallel() - group, err = client.Group(ctx, group.ID) - require.NoError(t, err) - require.Len(t, group.Members, 0) + const groupClaim = "custom-groups" + const groupName = "make-me" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim + cfg.CreateMissingGroups = true + }, + }) + + _, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + groupClaim: []string{groupName}, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) + }) + }) + + t.Run("Refresh", func(t *testing.T) { + t.Run("RefreshTokensMultiple", func(t *testing.T) { + t.Parallel() + + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" + }, + }) + + claims := jwt.MapClaims{ + "email": "alice@coder.com", + } + // Login a new client that signs up + client, resp := runner.Login(t, claims) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Refresh multiple times. + for i := 0; i < 3; i++ { + runner.ForceRefresh(t, client, claims) + } }) }) } +// nolint:bodyclose func TestGroupSync(t *testing.T) { t.Parallel() @@ -470,28 +474,20 @@ func TestGroupSync(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitLong) - conf := coderdtest.NewOIDCConfig(t, "") - - config := conf.OIDCConfig(t, jwt.MapClaims{}, tc.modCfg) - - client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - OIDCConfig: config, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{codersdk.FeatureTemplateRBAC: 1}, + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.GroupField = "groups" + tc.modCfg(cfg) }, }) - admin, err := client.User(ctx, "me") - require.NoError(t, err) - require.Len(t, admin.OrganizationIDs, 1) - // Setup + ctx := testutil.Context(t, testutil.WaitLong) + org := runner.AdminUser.OrganizationIDs[0] + initialGroups := make(map[string]codersdk.Group) for _, group := range tc.initialOrgGroups { - newGroup, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{ + newGroup, err := runner.AdminClient.CreateGroup(ctx, org, codersdk.CreateGroupRequest{ Name: group, }) require.NoError(t, err) @@ -500,16 +496,16 @@ func TestGroupSync(t *testing.T) { } // Create the user and add them to their initial groups - _, user := coderdtest.CreateAnotherUser(t, client, admin.OrganizationIDs[0]) + _, user := coderdtest.CreateAnotherUser(t, runner.AdminClient, org) for _, group := range tc.initialUserGroups { - _, err := client.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{ + _, err := runner.AdminClient.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{ AddUsers: []string{user.ID.String()}, }) require.NoError(t, err) } // nolint:gocritic - _, err = api.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{ + _, err := runner.API.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{ NewLoginType: database.LoginTypeOIDC, UserID: user.ID, }) @@ -517,11 +513,11 @@ func TestGroupSync(t *testing.T) { // Log in the new user tc.claims["email"] = user.Email - resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.claims)) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - _ = resp.Body.Close() + _, resp := runner.Login(t, tc.claims) + require.Equal(t, http.StatusOK, resp.StatusCode) - orgGroups, err := client.GroupsByOrganization(ctx, admin.OrganizationIDs[0]) + // Check group sources + orgGroups, err := runner.AdminClient.GroupsByOrganization(ctx, org) require.NoError(t, err) for _, group := range orgGroups { @@ -567,24 +563,107 @@ func TestGroupSync(t *testing.T) { } } -func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response { - t.Helper() - client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code)) - require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil) - require.NoError(t, err) - req.AddCookie(&http.Cookie{ - Name: codersdk.OAuth2StateCookie, - Value: "somestate", - }) - res, err := client.HTTPClient.Do(req) - require.NoError(t, err) - defer res.Body.Close() - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - t.Log(string(data)) - return res +// oidcTestRunner is just a helper to setup and run oidc tests. +// An actual Coderd instance is used to run the tests. +type oidcTestRunner struct { + AdminClient *codersdk.Client + AdminUser codersdk.User + API *coderden.API + + // Login will call the OIDC flow with an unauthenticated client. + // The IDP will return the idToken claims. + Login func(t *testing.T, idToken jwt.MapClaims) (*codersdk.Client, *http.Response) + // ForceRefresh will use an authenticated codersdk.Client, and force their + // OIDC token to be expired and require a refresh. The refresh will use the claims provided. + // It just calls the /users/me endpoint to trigger the refresh. + ForceRefresh func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) +} + +type oidcTestConfig struct { + Userinfo jwt.MapClaims + + // Config allows modifying the Coderd OIDC configuration. + Config func(cfg *coderd.OIDCConfig) +} + +func (r *oidcTestRunner) AssertRoles(t *testing.T, userIdent string, roles []string) { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitMedium) + user, err := r.AdminClient.User(ctx, userIdent) + require.NoError(t, err) + + roleNames := []string{} + for _, role := range user.Roles { + roleNames = append(roleNames, role.Name) + } + require.ElementsMatch(t, roles, roleNames, "expected roles") +} + +func (r *oidcTestRunner) AssertGroups(t *testing.T, userIdent string, groups []string) { + t.Helper() + + if !slice.Contains(groups, database.EveryoneGroup) { + var cpy []string + cpy = append(cpy, groups...) + // always include everyone group + cpy = append(cpy, database.EveryoneGroup) + groups = cpy + } + ctx := testutil.Context(t, testutil.WaitMedium) + user, err := r.AdminClient.User(ctx, userIdent) + require.NoError(t, err) + + allGroups, err := r.AdminClient.GroupsByOrganization(ctx, user.OrganizationIDs[0]) + require.NoError(t, err) + + userInGroups := []string{} + for _, g := range allGroups { + for _, mem := range g.Members { + if mem.ID == user.ID { + userInGroups = append(userInGroups, g.Name) + } + } + } + + require.ElementsMatch(t, groups, userInGroups, "expected groups") +} + +func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner { + t.Helper() + + fake := oidctest.NewFakeIDP(t, + oidctest.WithStaticUserInfo(settings.Userinfo), + oidctest.WithLogging(t, nil), + // Run fake IDP on a real webserver + oidctest.WithServing(), + ) + + ctx := testutil.Context(t, testutil.WaitMedium) + cfg := fake.OIDCConfig(t, nil, settings.Config) + owner, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + OIDCConfig: cfg, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureUserRoleManagement: 1, + codersdk.FeatureTemplateRBAC: 1, + }, + }, + }) + admin, err := owner.User(ctx, "me") + require.NoError(t, err) + + helper := oidctest.NewLoginHelper(owner, fake) + + return &oidcTestRunner{ + AdminClient: owner, + AdminUser: admin, + API: api, + Login: helper.Login, + ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) { + helper.ForceRefresh(t, api.Database, client, idToken) + }, + } }