test: add full OIDC fake IDP (#9317)

* test: implement fake OIDC provider with full functionality
* Refactor existing tests
This commit is contained in:
Steven Masley 2023-08-25 14:34:07 -05:00 committed by GitHub
parent 0a213a6ac3
commit d9d4d74f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1617 additions and 647 deletions

View File

@ -31,15 +31,13 @@ import (
"time"
"cloud.google.com/go/compute/metadata"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/fullsailor/pkcs7"
"github.com/golang-jwt/jwt"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
@ -1020,152 +1018,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
}
}
type OIDCConfig struct {
key *rsa.PrivateKey
issuer string
// These are optional
refreshToken string
oidcTokenExpires func() time.Time
tokenSource func() (*oauth2.Token, error)
}
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.refreshToken = token
}
}
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.oidcTokenExpires = expFunc
}
}
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.tokenSource = src
}
}
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
t.Helper()
block, _ := pem.Decode([]byte(testRSAPrivateKey))
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)
if issuer == "" {
issuer = "https://coder.com"
}
cfg := &OIDCConfig{
key: pkey,
issuer: issuer,
}
for _, opt := range opts {
opt(cfg)
}
return cfg
}
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "/?state=" + url.QueryEscape(state)
}
type tokenSource struct {
src func() (*oauth2.Token, error)
}
func (s tokenSource) Token() (*oauth2.Token, error) {
return s.src()
}
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
if cfg.tokenSource == nil {
return nil
}
return tokenSource{
src: cfg.tokenSource,
}
}
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
token, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return nil, xerrors.Errorf("decode code: %w", err)
}
var exp time.Time
if cfg.oidcTokenExpires != nil {
exp = cfg.oidcTokenExpires()
}
return (&oauth2.Token{
AccessToken: "token",
RefreshToken: cfg.refreshToken,
Expiry: exp,
}).WithExtra(map[string]interface{}{
"id_token": string(token),
}), nil
}
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
t.Helper()
if _, ok := claims["exp"]; !ok {
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
}
if _, ok := claims["iss"]; !ok {
claims["iss"] = cfg.issuer
}
if _, ok := claims["sub"]; !ok {
claims["sub"] = "testme"
}
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString([]byte(signed))
}
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
// By default, the provider can be empty.
// This means it won't support any endpoints!
provider := &oidc.Provider{}
if userInfoClaims != nil {
resp, err := json.Marshal(userInfoClaims)
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(resp)
}))
t.Cleanup(srv.Close)
cfg := &oidc.ProviderConfig{
UserInfoURL: srv.URL,
}
provider = cfg.NewProvider(context.Background())
}
newCFG := &coderd.OIDCConfig{
OAuth2Config: cfg,
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
}, &oidc.Config{
SkipClientIDCheck: true,
}),
Provider: provider,
UsernameField: "preferred_username",
EmailField: "email",
AuthURLParams: map[string]string{"access_type": "offline"},
GroupField: "groups",
}
for _, opt := range opts {
opt(newCFG)
}
return newCFG
}
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
// instance authentication for Azure.
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
@ -1254,22 +1106,6 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
return cerr
}
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
-----END RSA PRIVATE KEY-----`
func DeploymentValues(t testing.TB) *codersdk.DeploymentValues {
var cfg codersdk.DeploymentValues
opts := cfg.Options()

View File

@ -0,0 +1,103 @@
package oidctest
import (
"net/http"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// LoginHelper helps with logging in a user and refreshing their oauth tokens.
// It is mainly because refreshing oauth tokens is a bit tricky and requires
// some database manipulation.
type LoginHelper struct {
fake *FakeIDP
client *codersdk.Client
}
func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper {
if client == nil {
panic("client must not be nil")
}
if fake == nil {
panic("fake must not be nil")
}
return &LoginHelper{
fake: fake,
client: client,
}
}
// Login just helps by making an unauthenticated client and logging in with
// the given claims. All Logins should be unauthenticated, so this is a
// convenience method.
func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) {
t.Helper()
unauthenticatedClient := codersdk.New(h.client.URL)
return h.fake.Login(t, unauthenticatedClient, idTokenClaims)
}
// ExpireOauthToken expires the oauth token for the given user.
func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink {
t.Helper()
//nolint:gocritic // Testing
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
id, _, err := httpmw.SplitAPIToken(user.SessionToken())
require.NoError(t, err)
// We need to get the OIDC link and update it in the database to force
// it to be expired.
key, err := db.GetAPIKeyByID(ctx, id)
require.NoError(t, err, "get api key")
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: database.LoginTypeOIDC,
})
require.NoError(t, err, "get user link")
// Expire the oauth link for the given user.
updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
OAuthAccessToken: link.OAuthAccessToken,
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthExpiry: time.Now().Add(time.Hour * -1),
UserID: link.UserID,
LoginType: link.LoginType,
})
require.NoError(t, err, "expire user link")
return updated
}
// ForceRefresh forces the client to refresh its oauth token. It does this by
// expiring the oauth token, then doing an authenticated call. This will force
// the API Key middleware to refresh the oauth token.
//
// A unit test assertion makes sure the refresh token is used.
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) {
t.Helper()
link := h.ExpireOauthToken(t, db, user)
// Updates the claims that the IDP will return. By default, it always
// uses the original claims for the original oauth token.
h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken)
t.Cleanup(func() {
require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
})
// Do any authenticated call to force the refresh
_, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
require.NoError(t, err, "user must be able to be fetched")
}

View File

@ -0,0 +1,793 @@
package oidctest
import (
"context"
"crypto"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/coder/coder/v2/coderd/util/syncmap"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-chi/chi/v5"
"github.com/go-jose/go-jose/v3"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/codersdk"
)
// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
issuer string
key *rsa.PrivateKey
provider providerJSON
handler http.Handler
cfg *oauth2.Config
// clientID to be used by coderd
clientID string
clientSecret string
logger slog.Logger
// These maps are used to control the state of the IDP.
// That is the various access tokens, refresh tokens, states, etc.
codeToStateMap *syncmap.Map[string, string]
// Token -> Email
accessTokens *syncmap.Map[string, string]
// Refresh Token -> Email
refreshTokensUsed *syncmap.Map[string, bool]
refreshTokens *syncmap.Map[string, string]
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
// hooks
// hookValidRedirectURL can be used to reject a redirect url from the
// IDP -> Application. Almost all IDPs have the concept of
// "Authorized Redirect URLs". This can be used to emulate that.
hookValidRedirectURL func(redirectURL string) error
hookUserInfo func(email string) jwt.MapClaims
fakeCoderd func(req *http.Request) (*http.Response, error)
hookOnRefresh func(email string) error
// Custom authentication for the client. This is useful if you want
// to test something like PKI auth vs a client_secret.
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
serve bool
}
type FakeIDPOpt func(idp *FakeIDP)
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookValidRedirectURL = hook
}
}
// WithRefreshHook is called when a refresh token is used. The email is
// the email of the user that is being refreshed assuming the claims are correct.
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookOnRefresh = hook
}
}
func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookAuthenticateClient = hook
}
}
// WithLogging is optional, but will log some HTTP calls made to the IDP.
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
return func(f *FakeIDP) {
f.logger = slogtest.Make(t, options)
}
}
// WithStaticUserInfo is optional, but will return the same user info for
// every user on the /userinfo endpoint.
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookUserInfo = func(_ string) jwt.MapClaims {
return info
}
}
}
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookUserInfo = userInfoFunc
}
}
// WithServing makes the IDP run an actual http server.
func WithServing() func(*FakeIDP) {
return func(f *FakeIDP) {
f.serve = true
}
}
func WithIssuer(issuer string) func(*FakeIDP) {
return func(f *FakeIDP) {
f.issuer = issuer
}
}
const (
// nolint:gosec // It thinks this is a secret lol
tokenPath = "/oauth2/token"
authorizePath = "/oauth2/authorize"
keysPath = "/oauth2/keys"
userInfoPath = "/oauth2/userinfo"
)
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
t.Helper()
block, _ := pem.Decode([]byte(testRSAPrivateKey))
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)
idp := &FakeIDP{
key: pkey,
clientID: uuid.NewString(),
clientSecret: uuid.NewString(),
logger: slog.Make(),
codeToStateMap: syncmap.New[string, string](),
accessTokens: syncmap.New[string, string](),
refreshTokens: syncmap.New[string, string](),
refreshTokensUsed: syncmap.New[string, bool](),
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
hookOnRefresh: func(_ string) error { return nil },
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
hookValidRedirectURL: func(redirectURL string) error { return nil },
}
for _, opt := range opts {
opt(idp)
}
if idp.issuer == "" {
idp.issuer = "https://coder.com"
}
idp.handler = idp.httpHandler(t)
idp.updateIssuerURL(t, idp.issuer)
if idp.serve {
idp.realServer(t)
}
return idp
}
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
t.Helper()
u, err := url.Parse(issuer)
require.NoError(t, err, "invalid issuer URL")
f.issuer = issuer
// providerJSON is the JSON representation of the OpenID Connect provider
// These are all the urls that the IDP will respond to.
f.provider = providerJSON{
Issuer: issuer,
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
Algorithms: []string{
"RS256",
},
}
}
// realServer turns the FakeIDP into a real http server.
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
srv := httptest.NewUnstartedServer(f.handler)
srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
srv.Start()
t.Cleanup(srv.CloseClientConnections)
t.Cleanup(srv.Close)
t.Cleanup(cancel)
f.updateIssuerURL(t, srv.URL)
return srv
}
// Login does the full OIDC flow starting at the "LoginButton".
// The client argument is just to get the URL of the Coder instance.
//
// The client passed in is just to get the url of the Coder instance.
// The actual client that is used is 100% unauthenticated and fresh.
func (f *FakeIDP) Login(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
t.Helper()
client, resp := f.AttemptLogin(t, client, idTokenClaims, opts...)
require.Equal(t, http.StatusOK, resp.StatusCode, "client failed to login")
return client, resp
}
func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
t.Helper()
var err error
cli := f.HTTPClient(client.HTTPClient)
shallowCpyCli := *cli
if shallowCpyCli.Jar == nil {
shallowCpyCli.Jar, err = cookiejar.New(nil)
require.NoError(t, err, "failed to create cookie jar")
}
unauthenticated := codersdk.New(client.URL)
unauthenticated.HTTPClient = &shallowCpyCli
return f.LoginWithClient(t, unauthenticated, idTokenClaims, opts...)
}
// LoginWithClient reuses the context of the passed in client. This means the same
// cookies will be used. This should be an unauthenticated client in most cases.
//
// This is a niche case, but it is needed for testing ConvertLoginType.
func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
t.Helper()
coderOauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback")
require.NoError(t, err)
f.SetRedirect(t, coderOauthURL.String())
cli := f.HTTPClient(client.HTTPClient)
cli.CheckRedirect = func(req *http.Request, via []*http.Request) error {
// Store the idTokenClaims to the specific state request. This ties
// the claims 1:1 with a given authentication flow.
state := req.URL.Query().Get("state")
f.stateToIDTokenClaims.Store(state, idTokenClaims)
return nil
}
req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil)
require.NoError(t, err)
if cli.Jar == nil {
cli.Jar, err = cookiejar.New(nil)
require.NoError(t, err, "failed to create cookie jar")
}
for _, opt := range opts {
opt(req)
}
res, err := cli.Do(req)
require.NoError(t, err)
// If the coder session token exists, return the new authed client!
var user *codersdk.Client
cookies := cli.Jar.Cookies(client.URL)
for _, cookie := range cookies {
if cookie.Name == codersdk.SessionTokenCookie {
user = codersdk.New(client.URL)
user.SetSessionToken(cookie.Value)
}
}
t.Cleanup(func() {
if res.Body != nil {
_ = res.Body.Close()
}
})
return user, res
}
// OIDCCallback will emulate the IDP redirecting back to the Coder callback.
// This is helpful if no Coderd exists because the IDP needs to redirect to
// something.
// Essentially this is used to fake the Coderd side of the exchange.
// The flow starts at the user hitting the OIDC login page.
func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) {
t.Helper()
if f.serve {
panic("cannot use OIDCCallback with WithServing. This is only for the in memory usage")
}
f.stateToIDTokenClaims.Store(state, idTokenClaims)
cli := f.HTTPClient(nil)
u := f.cfg.AuthCodeURL(state)
req, err := http.NewRequest("GET", u, nil)
require.NoError(t, err)
resp, err := cli.Do(req.WithContext(context.Background()))
require.NoError(t, err)
t.Cleanup(func() {
if resp.Body != nil {
_ = resp.Body.Close()
}
})
return resp, nil
}
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
}
// newCode enforces the code exchanged is actually a valid code
// created by the IDP.
func (f *FakeIDP) newCode(state string) string {
code := uuid.NewString()
f.codeToStateMap.Store(code, state)
return code
}
// newToken enforces the access token exchanged is actually a valid access token
// created by the IDP.
func (f *FakeIDP) newToken(email string) string {
accessToken := uuid.NewString()
f.accessTokens.Store(accessToken, email)
return accessToken
}
func (f *FakeIDP) newRefreshTokens(email string) string {
refreshToken := uuid.NewString()
f.refreshTokens.Store(refreshToken, email)
return refreshToken
}
// authenticateBearerTokenRequest enforces the access token is valid.
func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) {
t.Helper()
auth := req.Header.Get("Authorization")
token := strings.TrimPrefix(auth, "Bearer ")
_, ok := f.accessTokens.Load(token)
if !ok {
return "", xerrors.New("invalid access token")
}
return token, nil
}
// authenticateOIDCClientRequest enforces the client_id and client_secret are valid.
func (f *FakeIDP) authenticateOIDCClientRequest(t testing.TB, req *http.Request) (url.Values, error) {
t.Helper()
if f.hookAuthenticateClient != nil {
return f.hookAuthenticateClient(t, req)
}
data, err := io.ReadAll(req.Body)
if !assert.NoError(t, err, "read token request body") {
return nil, xerrors.Errorf("authenticate request, read body: %w", err)
}
values, err := url.ParseQuery(string(data))
if !assert.NoError(t, err, "parse token request values") {
return nil, xerrors.New("invalid token request")
}
if !assert.Equal(t, f.clientID, values.Get("client_id"), "client_id mismatch") {
return nil, xerrors.New("client_id mismatch")
}
if !assert.Equal(t, f.clientSecret, values.Get("client_secret"), "client_secret mismatch") {
return nil, xerrors.New("client_secret mismatch")
}
return values, nil
}
// encodeClaims is a helper func to convert claims to a valid JWT.
func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string {
t.Helper()
if _, ok := claims["exp"]; !ok {
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
}
if _, ok := claims["aud"]; !ok {
claims["aud"] = f.clientID
}
if _, ok := claims["iss"]; !ok {
claims["iss"] = f.issuer
}
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key)
require.NoError(t, err)
return signed
}
// httpHandler is the IDP http server.
func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
t.Helper()
mux := chi.NewMux()
// This endpoint is required to initialize the OIDC provider.
// It is used to get the OIDC configuration.
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http OIDC config", slog.F("url", r.URL.String()))
_ = json.NewEncoder(rw).Encode(f.provider)
})
// Authorize is called when the user is redirected to the IDP to login.
// This is the browser hitting the IDP and the user logging into Google or
// w/e and clicking "Allow". They will be redirected back to the redirect
// when this is done.
mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call authorize", slog.F("url", r.URL.String()))
clientID := r.URL.Query().Get("client_id")
if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") {
http.Error(rw, "invalid client_id", http.StatusBadRequest)
return
}
redirectURI := r.URL.Query().Get("redirect_uri")
state := r.URL.Query().Get("state")
scope := r.URL.Query().Get("scope")
assert.NotEmpty(t, scope, "scope is empty")
responseType := r.URL.Query().Get("response_type")
switch responseType {
case "code":
case "token":
t.Errorf("response_type %q not supported", responseType)
http.Error(rw, "invalid response_type", http.StatusBadRequest)
return
default:
t.Errorf("unexpected response_type %q", responseType)
http.Error(rw, "invalid response_type", http.StatusBadRequest)
return
}
err := f.hookValidRedirectURL(redirectURI)
if err != nil {
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
return
}
ru, err := url.Parse(redirectURI)
if err != nil {
t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error())
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
return
}
q := ru.Query()
q.Set("state", state)
q.Set("code", f.newCode(state))
ru.RawQuery = q.Encode()
http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect)
}))
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
values, err := f.authenticateOIDCClientRequest(t, r)
f.logger.Info(r.Context(), "http idp call token",
slog.Error(err),
slog.F("values", values.Encode()),
)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
return
}
getEmail := func(claims jwt.MapClaims) string {
email, ok := claims["email"]
if !ok {
return "unknown"
}
emailStr, ok := email.(string)
if !ok {
return "wrong-type"
}
return emailStr
}
var claims jwt.MapClaims
switch values.Get("grant_type") {
case "authorization_code":
code := values.Get("code")
if !assert.NotEmpty(t, code, "code is empty") {
http.Error(rw, "invalid code", http.StatusBadRequest)
return
}
stateStr, ok := f.codeToStateMap.Load(code)
if !assert.True(t, ok, "invalid code") {
http.Error(rw, "invalid code", http.StatusBadRequest)
return
}
// Always invalidate the code after it is used.
f.codeToStateMap.Delete(code)
idTokenClaims, ok := f.stateToIDTokenClaims.Load(stateStr)
if !ok {
t.Errorf("missing id token claims")
http.Error(rw, "missing id token claims", http.StatusBadRequest)
return
}
claims = idTokenClaims
case "refresh_token":
refreshToken := values.Get("refresh_token")
if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") {
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
return
}
_, ok := f.refreshTokens.Load(refreshToken)
if !assert.True(t, ok, "invalid refresh_token") {
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
return
}
idTokenClaims, ok := f.refreshIDTokenClaims.Load(refreshToken)
if !ok {
t.Errorf("missing id token claims in refresh")
http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest)
return
}
claims = idTokenClaims
err := f.hookOnRefresh(getEmail(claims))
if err != nil {
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
return
}
f.refreshTokensUsed.Store(refreshToken, true)
// Always invalidate the refresh token after it is used.
f.refreshTokens.Delete(refreshToken)
default:
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
return
}
exp := time.Now().Add(time.Minute * 5)
claims["exp"] = exp.UnixMilli()
email := getEmail(claims)
refreshToken := f.newRefreshTokens(email)
token := map[string]interface{}{
"access_token": f.newToken(email),
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": int64((time.Minute * 5).Seconds()),
"id_token": f.encodeClaims(t, claims),
}
// Store the claims for the next refresh
f.refreshIDTokenClaims.Store(refreshToken, claims)
rw.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(rw).Encode(token)
}))
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
token, err := f.authenticateBearerTokenRequest(t, r)
f.logger.Info(r.Context(), "http call idp user info",
slog.Error(err),
slog.F("url", r.URL.String()),
)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
return
}
email, ok := f.accessTokens.Load(token)
if !ok {
t.Errorf("access token user for user_info has no email to indicate which user")
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
return
}
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
}))
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http call idp /keys")
set := jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: f.key.Public(),
KeyID: "test-key",
Algorithm: "RSA",
},
},
}
_ = json.NewEncoder(rw).Encode(set)
}))
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
})
return mux
}
// HTTPClient does nothing if IsServing is used.
//
// If IsServing is not used, then it will return a client that will make requests
// to the IDP all in memory. If a request is not to the IDP, then the passed in
// client will be used. If no client is passed in, then any regular network
// requests will fail.
func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
if f.serve {
if rest == nil || rest.Transport == nil {
return &http.Client{}
}
return rest
}
var jar http.CookieJar
if rest != nil {
jar = rest.Jar
}
return &http.Client{
Jar: jar,
Transport: fakeRoundTripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
u, _ := url.Parse(f.issuer)
if req.URL.Host != u.Host {
if f.fakeCoderd != nil {
return f.fakeCoderd(req)
}
if rest == nil || rest.Transport == nil {
return nil, fmt.Errorf("unexpected network request to %q", req.URL.Host)
}
return rest.Transport.RoundTrip(req)
}
resp := httptest.NewRecorder()
f.handler.ServeHTTP(resp, req)
return resp.Result(), nil
},
},
}
}
// RefreshUsed returns if the refresh token has been used. All refresh tokens
// can only be used once, then they are deleted.
func (f *FakeIDP) RefreshUsed(refreshToken string) bool {
used, _ := f.refreshTokensUsed.Load(refreshToken)
return used
}
// UpdateRefreshClaims allows the caller to change what claims are returned
// for a given refresh token. By default, all refreshes use the same claims as
// the original IDToken issuance.
func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) {
f.refreshIDTokenClaims.Store(refreshToken, claims)
}
// SetRedirect is required for the IDP to know where to redirect and call
// Coderd.
func (f *FakeIDP) SetRedirect(t testing.TB, u string) {
t.Helper()
f.cfg.RedirectURL = u
}
// SetCoderdCallback is optional and only works if not using the IsServing.
// It will setup a fake "Coderd" for the IDP to call when the IDP redirects
// back after authenticating.
func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) {
if f.serve {
panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'")
}
f.fakeCoderd = callback
}
func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) {
f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) {
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
return resp.Result(), nil
})
}
// OIDCConfig returns the OIDC config to use for Coderd.
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
t.Helper()
if len(scopes) == 0 {
scopes = []string{"openid", "email", "profile"}
}
oauthCfg := &oauth2.Config{
ClientID: f.clientID,
ClientSecret: f.clientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: f.provider.AuthURL,
TokenURL: f.provider.TokenURL,
AuthStyle: oauth2.AuthStyleInParams,
},
// If the user is using a real network request, they will need to do
// 'fake.SetRedirect()'
RedirectURL: "https://redirect.com",
Scopes: scopes,
}
ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil))
p, err := oidc.NewProvider(ctx, f.provider.Issuer)
require.NoError(t, err, "failed to create OIDC provider")
cfg := &coderd.OIDCConfig{
OAuth2Config: oauthCfg,
Provider: p,
Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{f.key.Public()},
}, &oidc.Config{
ClientID: oauthCfg.ClientID,
SupportedSigningAlgs: []string{
"RS256",
},
// Todo: add support for Now()
}),
UsernameField: "preferred_username",
EmailField: "email",
AuthURLParams: map[string]string{"access_type": "offline"},
}
for _, opt := range opts {
if opt == nil {
continue
}
opt(cfg)
}
f.cfg = oauthCfg
return cfg
}
type fakeRoundTripper struct {
roundTrip func(req *http.Request) (*http.Response, error)
}
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return f.roundTrip(req)
}
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
-----END RSA PRIVATE KEY-----`

View File

@ -0,0 +1,72 @@
package oidctest_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
// TestFakeIDPBasicFlow tests the basic flow of the fake IDP.
// It is done all in memory with no actual network requests.
// nolint:bodyclose
func TestFakeIDPBasicFlow(t *testing.T) {
t.Parallel()
fake := oidctest.NewFakeIDP(t,
oidctest.WithLogging(t, nil),
)
var handler http.Handler
srv := httptest.NewServer(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r)
})))
defer srv.Close()
cfg := fake.OIDCConfig(t, nil)
cli := fake.HTTPClient(nil)
ctx := oidc.ClientContext(context.Background(), cli)
const expectedState = "random-state"
var token *oauth2.Token
// This is the Coder callback using an actual network request.
fake.SetCoderdCallbackHandler(func(w http.ResponseWriter, r *http.Request) {
// Emulate OIDC flow
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
assert.Equal(t, expectedState, state, "state mismatch")
oauthToken, err := cfg.Exchange(ctx, code)
if assert.NoError(t, err, "failed to exchange code") {
assert.NotEmpty(t, oauthToken.AccessToken, "access token is empty")
assert.NotEmpty(t, oauthToken.RefreshToken, "refresh token is empty")
}
token = oauthToken
})
resp, err := fake.OIDCCallback(t, expectedState, jwt.MapClaims{})
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
// Test the user info
_, err = cfg.Provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
require.NoError(t, err)
// Now test it can refresh
refreshed, err := cfg.TokenSource(ctx, &oauth2.Token{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
Expiry: time.Now().Add(time.Minute * -1),
}).Token()
require.NoError(t, err, "failed to refresh token")
require.NotEmpty(t, refreshed.AccessToken, "access token is empty on refresh")
}

View File

@ -215,7 +215,10 @@ func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
}
var tokenRes struct {
oauth2.Token
AccessToken string `json:"access_token"`
TokenType string `json:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
// Extra fields returned by the refresh that are needed
IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now

View File

@ -12,12 +12,15 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/oauthpki"
"github.com/coder/coder/v2/testutil"
)
@ -123,6 +126,58 @@ func TestAzureADPKIOIDC(t *testing.T) {
require.Error(t, err, "error expected")
}
// TestAzureAKPKIWithCoderd uses a fake IDP and a real Coderd to test PKI auth.
// nolint:bodyclose
func TestAzureAKPKIWithCoderd(t *testing.T) {
t.Parallel()
scopes := []string{"openid", "email", "profile", "offline_access"}
fake := oidctest.NewFakeIDP(t,
oidctest.WithIssuer("https://login.microsoftonline.com/fake_app"),
oidctest.WithCustomClientAuth(func(t testing.TB, req *http.Request) (url.Values, error) {
values := assertJWTAuth(t, req)
if values == nil {
return nil, xerrors.New("authorizatin failed in request")
}
return values, nil
}),
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
oauthCfg := cfg.OAuth2Config.(*oauth2.Config)
// Create the oauthpki config
pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
ClientID: oauthCfg.ClientID,
TokenURL: oauthCfg.Endpoint.TokenURL,
Scopes: scopes,
PemEncodedKey: []byte(testClientKey),
PemEncodedCert: []byte(testClientCert),
Config: oauthCfg,
})
require.NoError(t, err)
cfg.OAuth2Config = pki
owner, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
OIDCConfig: cfg,
})
// Create a user and login
const email = "alice@coder.com"
claims := jwt.MapClaims{
"email": email,
}
helper := oidctest.NewLoginHelper(owner, fake)
user, _ := helper.Login(t, claims)
// Try refreshing the token more than once.
for i := 0; i < 2; i++ {
helper.ForceRefresh(t, api.Database, user, claims)
}
}
// TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
// AD instance and saving them to replay, removing some details.
// The reason this is done is that this is the only way to assert values
@ -269,7 +324,7 @@ func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// assertJWTAuth will assert the basic JWT auth assertions. It will return the
// url.Values from the request body for any additional assertions to be made.
func assertJWTAuth(t *testing.T, r *http.Request) url.Values {
func assertJWTAuth(t testing.TB, r *http.Request) url.Values {
body, err := io.ReadAll(r.Body)
if !assert.NoError(t, err) {
return nil

View File

@ -4,28 +4,25 @@ import (
"context"
"crypto"
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
"github.com/golang-jwt/jwt/v4"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
@ -35,85 +32,42 @@ import (
// This test specifically tests logging in with OIDC when an expired
// OIDC session token exists.
// The token refreshing should not happen since we are reauthenticating.
// nolint:bodyclose
func TestOIDCOauthLoginWithExisting(t *testing.T) {
t.Parallel()
conf := coderdtest.NewOIDCConfig(t, "",
// Provide a refresh token so we use the refresh token flow
coderdtest.WithRefreshToken("refresh_token"),
// We need to set the expire in the future for the first api calls.
coderdtest.WithTokenExpires(func() time.Time {
return time.Now().Add(time.Hour).UTC()
}),
// No refresh should actually happen in this test.
coderdtest.WithTokenSource(func() (*oauth2.Token, error) {
return nil, xerrors.New("token should not require refresh")
fake := oidctest.NewFakeIDP(t,
oidctest.WithRefreshHook(func(_ string) error {
return xerrors.New("refreshing token should never occur")
}),
oidctest.WithServing(),
)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
auditor := audit.NewMock()
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.IgnoreUserInfo = true
})
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
OIDCConfig: cfg,
})
const username = "alice"
claims := jwt.MapClaims{
"email": "alice@coder.com",
"email_verified": true,
"preferred_username": username,
}
config := conf.OIDCConfig(t, claims)
config.AllowSignups = true
config.IgnoreUserInfo = true
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: config,
Logger: &logger,
})
helper := oidctest.NewLoginHelper(client, fake)
// Signup alice
resp := oidcCallback(t, client, conf.EncodeClaims(t, claims))
// Set the client to use this OIDC context
authCookie := authCookieValue(resp.Cookies())
client.SetSessionToken(authCookie)
_ = resp.Body.Close()
userClient, _ := helper.Login(t, claims)
ctx := testutil.Context(t, testutil.WaitLong)
// Verify the user and oauth link
user, err := client.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, username, user.Username)
// Expire the link. This will force the client to refresh the token.
helper.ExpireOauthToken(t, api.Database, userClient)
// nolint:gocritic
link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: user.ID,
LoginType: database.LoginType(user.LoginType),
})
require.NoError(t, err, "failed to get user link")
// Expire the link
// nolint:gocritic
_, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
OAuthAccessToken: link.OAuthAccessToken,
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(),
UserID: link.UserID,
LoginType: link.LoginType,
})
require.NoError(t, err, "failed to update user link")
// Log in again with OIDC
loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) {
req.AddCookie(&http.Cookie{
Name: codersdk.SessionTokenCookie,
Value: authCookie,
Path: "/",
})
})
require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode)
_ = loginAgain.Body.Close()
// Try to use new login
client.SetSessionToken(authCookieValue(resp.Cookies()))
_, err = client.User(ctx, "me")
require.NoError(t, err, "use new session")
// Instead of refreshing, just log in again.
helper.Login(t, claims)
}
func TestUserLogin(t *testing.T) {
@ -660,7 +614,7 @@ func TestUserOIDC(t *testing.T) {
"email": "kyle@kwc.io",
},
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
Username: "kyle",
}, {
Name: "EmailNotVerified",
@ -685,7 +639,7 @@ func TestUserOIDC(t *testing.T) {
"email_verified": false,
},
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
Username: "kyle",
IgnoreEmailVerified: true,
}, {
@ -709,7 +663,7 @@ func TestUserOIDC(t *testing.T) {
EmailDomain: []string{
"kwc.io",
},
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
Name: "EmptyClaims",
IDTokenClaims: jwt.MapClaims{},
@ -730,7 +684,7 @@ func TestUserOIDC(t *testing.T) {
},
Username: "kyle",
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
Name: "UsernameFromClaims",
IDTokenClaims: jwt.MapClaims{
@ -740,7 +694,7 @@ func TestUserOIDC(t *testing.T) {
},
Username: "hotdog",
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
// Services like Okta return the email as the username:
// https://developer.okta.com/docs/reference/api/oidc/#base-claims-always-present
@ -752,7 +706,7 @@ func TestUserOIDC(t *testing.T) {
},
Username: "kyle",
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
// See: https://github.com/coder/coder/issues/4472
Name: "UsernameIsEmail",
@ -761,7 +715,7 @@ func TestUserOIDC(t *testing.T) {
},
Username: "kyle",
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
Name: "WithPicture",
IDTokenClaims: jwt.MapClaims{
@ -773,7 +727,7 @@ func TestUserOIDC(t *testing.T) {
Username: "kyle",
AllowSignups: true,
AvatarURL: "/example.png",
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
Name: "WithUserInfoClaims",
IDTokenClaims: jwt.MapClaims{
@ -787,7 +741,7 @@ func TestUserOIDC(t *testing.T) {
Username: "potato",
AllowSignups: true,
AvatarURL: "/example.png",
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
Name: "GroupsDoesNothing",
IDTokenClaims: jwt.MapClaims{
@ -795,7 +749,7 @@ func TestUserOIDC(t *testing.T) {
"groups": []string{"pingpong"},
},
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
Name: "UserInfoOverridesIDTokenClaims",
IDTokenClaims: jwt.MapClaims{
@ -810,7 +764,7 @@ func TestUserOIDC(t *testing.T) {
Username: "user",
AllowSignups: true,
IgnoreEmailVerified: false,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}, {
Name: "InvalidUserInfo",
IDTokenClaims: jwt.MapClaims{
@ -837,36 +791,41 @@ func TestUserOIDC(t *testing.T) {
Username: "user",
IgnoreUserInfo: true,
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
StatusCode: http.StatusOK,
}} {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
fake := oidctest.NewFakeIDP(t,
oidctest.WithRefreshHook(func(_ string) error {
return xerrors.New("refreshing token should never occur")
}),
oidctest.WithServing(),
oidctest.WithStaticUserInfo(tc.UserInfoClaims),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = tc.AllowSignups
cfg.EmailDomain = tc.EmailDomain
cfg.IgnoreEmailVerified = tc.IgnoreEmailVerified
cfg.IgnoreUserInfo = tc.IgnoreUserInfo
})
auditor := audit.NewMock()
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig(t, tc.UserInfoClaims)
config.AllowSignups = tc.AllowSignups
config.EmailDomain = tc.EmailDomain
config.IgnoreEmailVerified = tc.IgnoreEmailVerified
config.IgnoreUserInfo = tc.IgnoreUserInfo
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
client := coderdtest.New(t, &coderdtest.Options{
owner := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: config,
OIDCConfig: cfg,
Logger: &logger,
})
numLogs := len(auditor.AuditLogs())
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.IDTokenClaims))
client, resp := fake.AttemptLogin(t, owner, tc.IDTokenClaims)
numLogs++ // add an audit log for login
assert.Equal(t, tc.StatusCode, resp.StatusCode)
require.Equal(t, tc.StatusCode, resp.StatusCode)
ctx := testutil.Context(t, testutil.WaitLong)
if tc.Username != "" {
client.SetSessionToken(authCookieValue(resp.Cookies()))
user, err := client.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, tc.Username, user.Username)
@ -877,7 +836,6 @@ func TestUserOIDC(t *testing.T) {
}
if tc.AvatarURL != "" {
client.SetSessionToken(authCookieValue(resp.Cookies()))
user, err := client.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, tc.AvatarURL, user.AvatarURL)
@ -890,26 +848,29 @@ func TestUserOIDC(t *testing.T) {
t.Run("OIDCConvert", func(t *testing.T) {
t.Parallel()
auditor := audit.NewMock()
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig(t, nil)
config.AllowSignups = true
cfg := coderdtest.DeploymentValues(t)
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: config,
DeploymentValues: cfg,
fake := oidctest.NewFakeIDP(t,
oidctest.WithRefreshHook(func(_ string) error {
return xerrors.New("refreshing token should never occur")
}),
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
owner := coderdtest.CreateFirstUser(t, client)
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: cfg,
})
owner := coderdtest.CreateFirstUser(t, client)
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
code := conf.EncodeClaims(t, jwt.MapClaims{
claims := jwt.MapClaims{
"email": userData.Email,
})
}
var err error
user.HTTPClient.Jar, err = cookiejar.New(nil)
require.NoError(t, err)
@ -921,52 +882,58 @@ func TestUserOIDC(t *testing.T) {
})
require.NoError(t, err)
resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
fake.LoginWithClient(t, user, claims, func(r *http.Request) {
r.URL.RawQuery = url.Values{
"oidc_merge_state": {convertResponse.StateString},
}.Encode()
r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken())
cookies := user.HTTPClient.Jar.Cookies(r.URL)
for _, cookie := range cookies {
r.AddCookie(cookie)
}
})
})
t.Run("AlternateUsername", func(t *testing.T) {
t.Parallel()
auditor := audit.NewMock()
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig(t, nil)
config.AllowSignups = true
fake := oidctest.NewFakeIDP(t,
oidctest.WithRefreshHook(func(_ string) error {
return xerrors.New("refreshing token should never occur")
}),
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: config,
OIDCConfig: cfg,
})
numLogs := len(auditor.AuditLogs())
code := conf.EncodeClaims(t, jwt.MapClaims{
numLogs := len(auditor.AuditLogs())
claims := jwt.MapClaims{
"email": "jon@coder.com",
})
resp := oidcCallback(t, client, code)
}
userClient, _ := fake.Login(t, client, claims)
numLogs++ // add an audit log for login
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
ctx := testutil.Context(t, testutil.WaitLong)
client.SetSessionToken(authCookieValue(resp.Cookies()))
user, err := client.User(ctx, "me")
user, err := userClient.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, "jon", user.Username)
// Pass a different subject field so that we prompt creating a
// new user.
code = conf.EncodeClaims(t, jwt.MapClaims{
// new user
userClient, _ = fake.Login(t, client, jwt.MapClaims{
"email": "jon@example2.com",
"sub": "diff",
})
resp = oidcCallback(t, client, code)
numLogs++ // add an audit log for login
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
client.SetSessionToken(authCookieValue(resp.Cookies()))
user, err = client.User(ctx, "me")
user, err = userClient.User(ctx, "me")
require.NoError(t, err)
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")
@ -977,45 +944,62 @@ func TestUserOIDC(t *testing.T) {
t.Run("Disabled", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
resp := oidcCallback(t, client, "asdf")
oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback")
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
require.NoError(t, err)
resp, err := client.HTTPClient.Do(req)
require.NoError(t, err)
resp.Body.Close()
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
t.Run("NoIDToken", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: &coderd.OIDCConfig{
OAuth2Config: &testutil.OAuth2Config{},
},
fake := oidctest.NewFakeIDP(t,
oidctest.WithRefreshHook(func(_ string) error {
return xerrors.New("refreshing token should never occur")
}),
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
resp := oidcCallback(t, client, "asdf")
client := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: cfg,
})
_, resp := fake.AttemptLogin(t, client, jwt.MapClaims{})
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
t.Run("BadVerify", func(t *testing.T) {
t.Parallel()
verifier := oidc.NewVerifier("", &oidc.StaticKeySet{
badVerifier := oidc.NewVerifier("", &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{},
}, &oidc.Config{})
provider := &oidc.Provider{}
badProvider := &oidc.Provider{}
client := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: &coderd.OIDCConfig{
OAuth2Config: &testutil.OAuth2Config{
Token: (&oauth2.Token{
AccessToken: "token",
}).WithExtra(map[string]interface{}{
"id_token": "invalid",
}),
},
Provider: provider,
Verifier: verifier,
},
fake := oidctest.NewFakeIDP(t,
oidctest.WithRefreshHook(func(_ string) error {
return xerrors.New("refreshing token should never occur")
}),
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.Provider = badProvider
cfg.Verifier = badVerifier
})
resp := oidcCallback(t, client, "asdf")
client := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: cfg,
})
_, resp := fake.AttemptLogin(t, client, jwt.MapClaims{})
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
}
@ -1146,36 +1130,6 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
return res
}
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
return oidcCallbackWithState(t, client, code, "somestate", nil)
}
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response {
t.Helper()
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=%s", code, state))
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: codersdk.OAuth2StateCookie,
Value: state,
})
if modify != nil {
modify(req)
}
res, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
data, err := io.ReadAll(res.Body)
require.NoError(t, err)
t.Log(string(data))
return res
}
func i64ptr(i int64) *int64 {
return &i
}

View File

@ -8,7 +8,10 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -403,6 +406,7 @@ func TestPostLogout(t *testing.T) {
})
}
// nolint:bodyclose
func TestPostUsers(t *testing.T) {
t.Parallel()
t.Run("NoAuth", func(t *testing.T) {
@ -593,15 +597,15 @@ func TestPostUsers(t *testing.T) {
t.Run("CreateOIDCLoginType", func(t *testing.T) {
t.Parallel()
email := "another@user.org"
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig(t, jwt.MapClaims{
"email": email,
fake := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
config.AllowSignups = false
config.IgnoreUserInfo = true
client := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: config,
OIDCConfig: cfg,
})
first := coderdtest.CreateFirstUser(t, client)
@ -618,15 +622,9 @@ func TestPostUsers(t *testing.T) {
require.NoError(t, err)
// Try to log in with OIDC.
userClient := codersdk.New(client.URL)
resp := oidcCallback(t, userClient, conf.EncodeClaims(t, jwt.MapClaims{
userClient, _ := fake.Login(t, client, jwt.MapClaims{
"email": email,
}))
require.Equal(t, resp.StatusCode, http.StatusTemporaryRedirect)
// Set the client to use this OIDC context
authCookie := authCookieValue(resp.Cookies())
userClient.SetSessionToken(authCookie)
_ = resp.Body.Close()
})
found, err := userClient.User(ctx, "me")
require.NoError(t, err)

View File

@ -0,0 +1,77 @@
package syncmap
import "sync"
// Map is a type safe sync.Map
type Map[K, V any] struct {
m sync.Map
}
func New[K, V any]() *Map[K, V] {
return &Map[K, V]{
m: sync.Map{},
}
}
func (m *Map[K, V]) Store(k K, v V) {
m.m.Store(k, v)
}
//nolint:forcetypeassert
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
v, ok := m.m.Load(key)
if !ok {
var empty V
return empty, false
}
return v.(V), ok
}
func (m *Map[K, V]) Delete(key K) {
m.m.Delete(key)
}
//nolint:forcetypeassert
func (m *Map[K, V]) LoadAndDelete(key K) (actual V, loaded bool) {
act, loaded := m.m.LoadAndDelete(key)
if !loaded {
var empty V
return empty, loaded
}
return act.(V), loaded
}
//nolint:forcetypeassert
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
act, loaded := m.m.LoadOrStore(key, value)
if !loaded {
var empty V
return empty, loaded
}
return act.(V), loaded
}
func (m *Map[K, V]) CompareAndSwap(key K, old V, new V) bool {
return m.m.CompareAndSwap(key, old, new)
}
func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
return m.m.CompareAndDelete(key, old)
}
//nolint:forcetypeassert
func (m *Map[K, V]) Swap(key K, value V) (previous any, loaded bool) {
previous, loaded = m.m.Swap(key, value)
if !loaded {
var empty V
return empty, loaded
}
return previous.(V), loaded
}
//nolint:forcetypeassert
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
m.m.Range(func(key, value interface{}) bool {
return f(key.(K), value.(V))
})
}

View File

@ -1,25 +1,22 @@
package coderd_test
import (
"context"
"fmt"
"io"
"net/http"
"regexp"
"testing"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
coderden "github.com/coder/coder/v2/enterprise/coderd"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/testutil"
@ -31,128 +28,123 @@ func TestUserOIDC(t *testing.T) {
t.Run("RoleSync", func(t *testing.T) {
t.Parallel()
// NoRoles is the "control group". It has claims with 0 roles
// assigned, and asserts that the user has no roles.
t.Run("NoRoles", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
conf := coderdtest.NewOIDCConfig(t, "")
oidcRoleName := "TemplateAuthor"
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}
})
config.AllowSignups = true
config.UserRoleField = "roles"
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: config,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.UserRoleField = "roles"
},
})
admin, err := client.User(ctx, "me")
require.NoError(t, err)
require.Len(t, admin.OrganizationIDs, 1)
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
claims := jwt.MapClaims{
"email": "alice@coder.com",
}))
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
user, err := client.User(ctx, "alice")
require.NoError(t, err)
require.Len(t, user.Roles, 0)
roleNames := []string{}
require.ElementsMatch(t, roleNames, []string{})
}
// Login a new client that signs up
client, resp := runner.Login(t, claims)
require.Equal(t, http.StatusOK, resp.StatusCode)
// User should be in 0 groups.
runner.AssertRoles(t, "alice", []string{})
// Force a refresh, and assert nothing has changes
runner.ForceRefresh(t, client, claims)
runner.AssertRoles(t, "alice", []string{})
})
t.Run("NewUserAndRemoveRoles", func(t *testing.T) {
// A user has some roles, then on an oauth refresh will lose said
// roles from an updated claim.
t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) {
// TODO: Implement new feature to update roles/groups on OIDC
// refresh tokens. https://github.com/coder/coder/issues/9312
t.Skip("Refreshing tokens does not update roles :(")
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
conf := coderdtest.NewOIDCConfig(t, "")
oidcRoleName := "TemplateAuthor"
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}
})
config.AllowSignups = true
config.UserRoleField = "roles"
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: config,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
const oidcRoleName = "TemplateAuthor"
runner := setupOIDCTest(t, oidcTestConfig{
Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}},
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.UserRoleField = "roles"
cfg.UserRoleMapping = map[string][]string{
oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()},
}
},
})
admin, err := client.User(ctx, "me")
require.NoError(t, err)
require.Len(t, admin.OrganizationIDs, 1)
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
// User starts with the owner role
client, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
"roles": []string{"random", oidcRoleName, rbac.RoleOwner()},
}))
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
_ = resp.Body.Close()
user, err := client.User(ctx, "alice")
require.NoError(t, err)
})
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
require.Len(t, user.Roles, 3)
roleNames := []string{user.Roles[0].Name, user.Roles[1].Name, user.Roles[2].Name}
require.ElementsMatch(t, roleNames, []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
// Now remove the roles with a new oidc login
resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
// Now refresh the oauth, and check the roles are removed.
// Force a refresh, and assert nothing has changes
runner.ForceRefresh(t, client, jwt.MapClaims{
"email": "alice@coder.com",
"roles": []string{"random"},
}))
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
_ = resp.Body.Close()
user, err = client.User(ctx, "alice")
require.NoError(t, err)
require.Len(t, user.Roles, 0)
})
runner.AssertRoles(t, "alice", []string{})
})
// A user has some roles, then on another oauth login will lose said
// roles from an updated claim.
t.Run("NewUserAndRemoveRolesOnReAuth", func(t *testing.T) {
t.Parallel()
const oidcRoleName = "TemplateAuthor"
runner := setupOIDCTest(t, oidcTestConfig{
Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}},
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.UserRoleField = "roles"
cfg.UserRoleMapping = map[string][]string{
oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()},
}
},
})
// User starts with the owner role
_, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
"roles": []string{"random", oidcRoleName, rbac.RoleOwner()},
})
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
// Now login with oauth again, and check the roles are removed.
_, resp = runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
"roles": []string{"random"},
})
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertRoles(t, "alice", []string{})
})
// All manual role updates should fail when role sync is enabled.
t.Run("BlockAssignRoles", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig(t, jwt.MapClaims{})
config.AllowSignups = true
config.UserRoleField = "roles"
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: config,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.UserRoleField = "roles"
},
})
admin, err := client.User(ctx, "me")
require.NoError(t, err)
require.Len(t, admin.OrganizationIDs, 1)
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
_, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
"roles": []string{},
}))
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
require.Equal(t, http.StatusOK, resp.StatusCode)
// Try to manually update user roles, even though controlled by oidc
// role sync.
_, err = client.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{
ctx := testutil.Context(t, testutil.WaitShort)
_, err := runner.AdminClient.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{
Roles: []string{
rbac.RoleTemplateAdmin(),
},
@ -164,199 +156,211 @@ func TestUserOIDC(t *testing.T) {
t.Run("Groups", func(t *testing.T) {
t.Parallel()
// Assigns does a simple test of assigning a user to a group based
// on the oidc claims.
t.Run("Assigns", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
conf := coderdtest.NewOIDCConfig(t, "")
const groupClaim = "custom-groups"
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
cfg.GroupField = groupClaim
})
config.AllowSignups = true
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: config,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
const groupName = "bingbong"
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.GroupField = groupClaim
},
})
admin, err := client.User(ctx, "me")
require.NoError(t, err)
require.Len(t, admin.OrganizationIDs, 1)
groupName := "bingbong"
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
ctx := testutil.Context(t, testutil.WaitShort)
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
Name: groupName,
})
require.NoError(t, err)
require.Len(t, group.Members, 0)
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
"email": "colin@coder.com",
_, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
groupClaim: []string{groupName},
}))
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
group, err = client.Group(ctx, group.ID)
require.NoError(t, err)
require.Len(t, group.Members, 1)
})
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertGroups(t, "alice", []string{groupName})
})
// Tests the group mapping feature.
t.Run("AssignsMapped", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
conf := coderdtest.NewOIDCConfig(t, "")
const groupClaim = "custom-groups"
oidcGroupName := "pingpong"
coderGroupName := "bingbong"
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName}
})
config.AllowSignups = true
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: config,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
const oidcGroupName = "pingpong"
const coderGroupName = "bingbong"
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.GroupField = groupClaim
cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName}
},
})
admin, err := client.User(ctx, "me")
require.NoError(t, err)
require.Len(t, admin.OrganizationIDs, 1)
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
ctx := testutil.Context(t, testutil.WaitShort)
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
Name: coderGroupName,
})
require.NoError(t, err)
require.Len(t, group.Members, 0)
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
"email": "colin@coder.com",
"groups": []string{oidcGroupName},
}))
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
group, err = client.Group(ctx, group.ID)
require.NoError(t, err)
require.Len(t, group.Members, 1)
_, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
groupClaim: []string{oidcGroupName},
})
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertGroups(t, "alice", []string{coderGroupName})
})
t.Run("AddThenRemove", func(t *testing.T) {
// User is in a group, then on an oauth refresh will lose said
// group.
t.Run("AddThenRemoveOnRefresh", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
conf := coderdtest.NewOIDCConfig(t, "")
// TODO: Implement new feature to update roles/groups on OIDC
// refresh tokens. https://github.com/coder/coder/issues/9312
t.Skip("Refreshing tokens does not update groups :(")
config := conf.OIDCConfig(t, jwt.MapClaims{})
config.AllowSignups = true
client, firstUser := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: config,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
const groupClaim = "custom-groups"
const groupName = "bingbong"
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.GroupField = groupClaim
},
})
// Add some extra users/groups that should be asserted after.
// Adding this user as there was a bug that removing 1 user removed
// all users from the group.
_, extra := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
groupName := "bingbong"
group, err := client.CreateGroup(ctx, firstUser.OrganizationID, codersdk.CreateGroupRequest{
ctx := testutil.Context(t, testutil.WaitShort)
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
Name: groupName,
})
require.NoError(t, err, "create group")
require.NoError(t, err)
require.Len(t, group.Members, 0)
group, err = client.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{
AddUsers: []string{
firstUser.UserID.String(),
extra.ID.String(),
},
client, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
groupClaim: []string{groupName},
})
require.NoError(t, err, "patch group")
require.Len(t, group.Members, 2, "expect both members")
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertGroups(t, "alice", []string{groupName})
// Now add OIDC user into the group
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
"email": "colin@coder.com",
"groups": []string{groupName},
}))
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
group, err = client.Group(ctx, group.ID)
require.NoError(t, err)
require.Len(t, group.Members, 3)
// Login to remove the OIDC user from the group
resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
"email": "colin@coder.com",
"groups": []string{},
}))
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
group, err = client.Group(ctx, group.ID)
require.NoError(t, err)
require.Len(t, group.Members, 2)
var expected []uuid.UUID
for _, mem := range group.Members {
expected = append(expected, mem.ID)
}
require.ElementsMatchf(t, expected, []uuid.UUID{firstUser.UserID, extra.ID}, "expected members")
// Refresh without the group claim
runner.ForceRefresh(t, client, jwt.MapClaims{
"email": "alice@coder.com",
})
runner.AssertGroups(t, "alice", []string{})
})
t.Run("AddThenRemoveOnReAuth", func(t *testing.T) {
t.Parallel()
const groupClaim = "custom-groups"
const groupName = "bingbong"
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.GroupField = groupClaim
},
})
ctx := testutil.Context(t, testutil.WaitShort)
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
Name: groupName,
})
require.NoError(t, err)
require.Len(t, group.Members, 0)
_, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
groupClaim: []string{groupName},
})
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertGroups(t, "alice", []string{groupName})
// Refresh without the group claim
_, resp = runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
})
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertGroups(t, "alice", []string{})
})
// Updating groups where the claimed group does not exist.
t.Run("NoneMatch", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig(t, jwt.MapClaims{})
config.AllowSignups = true
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: config,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
const groupClaim = "custom-groups"
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.GroupField = groupClaim
},
})
admin, err := client.User(ctx, "me")
require.NoError(t, err)
require.Len(t, admin.OrganizationIDs, 1)
groupName := "bingbong"
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
Name: groupName,
_, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
groupClaim: []string{"not-exists"},
})
require.NoError(t, err)
require.Len(t, group.Members, 0)
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertGroups(t, "alice", []string{})
})
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
"email": "colin@coder.com",
"groups": []string{"coolin"},
}))
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// Updating groups where the claimed group does not exist creates
// the group.
t.Run("AutoCreate", func(t *testing.T) {
t.Parallel()
group, err = client.Group(ctx, group.ID)
require.NoError(t, err)
require.Len(t, group.Members, 0)
const groupClaim = "custom-groups"
const groupName = "make-me"
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.GroupField = groupClaim
cfg.CreateMissingGroups = true
},
})
_, resp := runner.Login(t, jwt.MapClaims{
"email": "alice@coder.com",
groupClaim: []string{groupName},
})
require.Equal(t, http.StatusOK, resp.StatusCode)
runner.AssertGroups(t, "alice", []string{groupName})
})
})
t.Run("Refresh", func(t *testing.T) {
t.Run("RefreshTokensMultiple", func(t *testing.T) {
t.Parallel()
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
cfg.UserRoleField = "roles"
},
})
claims := jwt.MapClaims{
"email": "alice@coder.com",
}
// Login a new client that signs up
client, resp := runner.Login(t, claims)
require.Equal(t, http.StatusOK, resp.StatusCode)
// Refresh multiple times.
for i := 0; i < 3; i++ {
runner.ForceRefresh(t, client, claims)
}
})
})
}
// nolint:bodyclose
func TestGroupSync(t *testing.T) {
t.Parallel()
@ -470,28 +474,20 @@ func TestGroupSync(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig(t, jwt.MapClaims{}, tc.modCfg)
client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: config,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) {
cfg.GroupField = "groups"
tc.modCfg(cfg)
},
})
admin, err := client.User(ctx, "me")
require.NoError(t, err)
require.Len(t, admin.OrganizationIDs, 1)
// Setup
ctx := testutil.Context(t, testutil.WaitLong)
org := runner.AdminUser.OrganizationIDs[0]
initialGroups := make(map[string]codersdk.Group)
for _, group := range tc.initialOrgGroups {
newGroup, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
newGroup, err := runner.AdminClient.CreateGroup(ctx, org, codersdk.CreateGroupRequest{
Name: group,
})
require.NoError(t, err)
@ -500,16 +496,16 @@ func TestGroupSync(t *testing.T) {
}
// Create the user and add them to their initial groups
_, user := coderdtest.CreateAnotherUser(t, client, admin.OrganizationIDs[0])
_, user := coderdtest.CreateAnotherUser(t, runner.AdminClient, org)
for _, group := range tc.initialUserGroups {
_, err := client.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{
_, err := runner.AdminClient.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{
AddUsers: []string{user.ID.String()},
})
require.NoError(t, err)
}
// nolint:gocritic
_, err = api.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{
_, err := runner.API.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{
NewLoginType: database.LoginTypeOIDC,
UserID: user.ID,
})
@ -517,11 +513,11 @@ func TestGroupSync(t *testing.T) {
// Log in the new user
tc.claims["email"] = user.Email
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.claims))
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
_ = resp.Body.Close()
_, resp := runner.Login(t, tc.claims)
require.Equal(t, http.StatusOK, resp.StatusCode)
orgGroups, err := client.GroupsByOrganization(ctx, admin.OrganizationIDs[0])
// Check group sources
orgGroups, err := runner.AdminClient.GroupsByOrganization(ctx, org)
require.NoError(t, err)
for _, group := range orgGroups {
@ -567,24 +563,107 @@ func TestGroupSync(t *testing.T) {
}
}
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
t.Helper()
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code))
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: codersdk.OAuth2StateCookie,
Value: "somestate",
})
res, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
data, err := io.ReadAll(res.Body)
require.NoError(t, err)
t.Log(string(data))
return res
// oidcTestRunner is just a helper to setup and run oidc tests.
// An actual Coderd instance is used to run the tests.
type oidcTestRunner struct {
AdminClient *codersdk.Client
AdminUser codersdk.User
API *coderden.API
// Login will call the OIDC flow with an unauthenticated client.
// The IDP will return the idToken claims.
Login func(t *testing.T, idToken jwt.MapClaims) (*codersdk.Client, *http.Response)
// ForceRefresh will use an authenticated codersdk.Client, and force their
// OIDC token to be expired and require a refresh. The refresh will use the claims provided.
// It just calls the /users/me endpoint to trigger the refresh.
ForceRefresh func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims)
}
type oidcTestConfig struct {
Userinfo jwt.MapClaims
// Config allows modifying the Coderd OIDC configuration.
Config func(cfg *coderd.OIDCConfig)
}
func (r *oidcTestRunner) AssertRoles(t *testing.T, userIdent string, roles []string) {
t.Helper()
ctx := testutil.Context(t, testutil.WaitMedium)
user, err := r.AdminClient.User(ctx, userIdent)
require.NoError(t, err)
roleNames := []string{}
for _, role := range user.Roles {
roleNames = append(roleNames, role.Name)
}
require.ElementsMatch(t, roles, roleNames, "expected roles")
}
func (r *oidcTestRunner) AssertGroups(t *testing.T, userIdent string, groups []string) {
t.Helper()
if !slice.Contains(groups, database.EveryoneGroup) {
var cpy []string
cpy = append(cpy, groups...)
// always include everyone group
cpy = append(cpy, database.EveryoneGroup)
groups = cpy
}
ctx := testutil.Context(t, testutil.WaitMedium)
user, err := r.AdminClient.User(ctx, userIdent)
require.NoError(t, err)
allGroups, err := r.AdminClient.GroupsByOrganization(ctx, user.OrganizationIDs[0])
require.NoError(t, err)
userInGroups := []string{}
for _, g := range allGroups {
for _, mem := range g.Members {
if mem.ID == user.ID {
userInGroups = append(userInGroups, g.Name)
}
}
}
require.ElementsMatch(t, groups, userInGroups, "expected groups")
}
func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner {
t.Helper()
fake := oidctest.NewFakeIDP(t,
oidctest.WithStaticUserInfo(settings.Userinfo),
oidctest.WithLogging(t, nil),
// Run fake IDP on a real webserver
oidctest.WithServing(),
)
ctx := testutil.Context(t, testutil.WaitMedium)
cfg := fake.OIDCConfig(t, nil, settings.Config)
owner, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
OIDCConfig: cfg,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureUserRoleManagement: 1,
codersdk.FeatureTemplateRBAC: 1,
},
},
})
admin, err := owner.User(ctx, "me")
require.NoError(t, err)
helper := oidctest.NewLoginHelper(owner, fake)
return &oidcTestRunner{
AdminClient: owner,
AdminUser: admin,
API: api,
Login: helper.Login,
ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) {
helper.ForceRefresh(t, api.Database, client, idToken)
},
}
}