mirror of https://github.com/coder/coder.git
test: add full OIDC fake IDP (#9317)
* test: implement fake OIDC provider with full functionality * Refactor existing tests
This commit is contained in:
parent
0a213a6ac3
commit
d9d4d74f99
|
@ -31,15 +31,13 @@ import (
|
|||
"time"
|
||||
|
||||
"cloud.google.com/go/compute/metadata"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/fullsailor/pkcs7"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/api/idtoken"
|
||||
"google.golang.org/api/option"
|
||||
|
@ -1020,152 +1018,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
|
|||
}
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
key *rsa.PrivateKey
|
||||
issuer string
|
||||
// These are optional
|
||||
refreshToken string
|
||||
oidcTokenExpires func() time.Time
|
||||
tokenSource func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.refreshToken = token
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.oidcTokenExpires = expFunc
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.tokenSource = src
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
|
||||
t.Helper()
|
||||
|
||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
if issuer == "" {
|
||||
issuer = "https://coder.com"
|
||||
}
|
||||
|
||||
cfg := &OIDCConfig{
|
||||
key: pkey,
|
||||
issuer: issuer,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
type tokenSource struct {
|
||||
src func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func (s tokenSource) Token() (*oauth2.Token, error) {
|
||||
return s.src()
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
if cfg.tokenSource == nil {
|
||||
return nil
|
||||
}
|
||||
return tokenSource{
|
||||
src: cfg.tokenSource,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
token, err := base64.StdEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode code: %w", err)
|
||||
}
|
||||
|
||||
var exp time.Time
|
||||
if cfg.oidcTokenExpires != nil {
|
||||
exp = cfg.oidcTokenExpires()
|
||||
}
|
||||
|
||||
return (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
RefreshToken: cfg.refreshToken,
|
||||
Expiry: exp,
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": string(token),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = cfg.issuer
|
||||
}
|
||||
|
||||
if _, ok := claims["sub"]; !ok {
|
||||
claims["sub"] = "testme"
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return base64.StdEncoding.EncodeToString([]byte(signed))
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
// By default, the provider can be empty.
|
||||
// This means it won't support any endpoints!
|
||||
provider := &oidc.Provider{}
|
||||
if userInfoClaims != nil {
|
||||
resp, err := json.Marshal(userInfoClaims)
|
||||
require.NoError(t, err)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(resp)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
cfg := &oidc.ProviderConfig{
|
||||
UserInfoURL: srv.URL,
|
||||
}
|
||||
provider = cfg.NewProvider(context.Background())
|
||||
}
|
||||
newCFG := &coderd.OIDCConfig{
|
||||
OAuth2Config: cfg,
|
||||
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
}),
|
||||
Provider: provider,
|
||||
UsernameField: "preferred_username",
|
||||
EmailField: "email",
|
||||
AuthURLParams: map[string]string{"access_type": "offline"},
|
||||
GroupField: "groups",
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(newCFG)
|
||||
}
|
||||
return newCFG
|
||||
}
|
||||
|
||||
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
|
||||
// instance authentication for Azure.
|
||||
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
|
||||
|
@ -1254,22 +1106,6 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
|
|||
return cerr
|
||||
}
|
||||
|
||||
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
|
||||
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
|
||||
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
|
||||
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
|
||||
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
|
||||
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
|
||||
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
|
||||
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
|
||||
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
|
||||
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
|
||||
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
|
||||
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
|
||||
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
func DeploymentValues(t testing.TB) *codersdk.DeploymentValues {
|
||||
var cfg codersdk.DeploymentValues
|
||||
opts := cfg.Options()
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
package oidctest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// LoginHelper helps with logging in a user and refreshing their oauth tokens.
|
||||
// It is mainly because refreshing oauth tokens is a bit tricky and requires
|
||||
// some database manipulation.
|
||||
type LoginHelper struct {
|
||||
fake *FakeIDP
|
||||
client *codersdk.Client
|
||||
}
|
||||
|
||||
func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper {
|
||||
if client == nil {
|
||||
panic("client must not be nil")
|
||||
}
|
||||
if fake == nil {
|
||||
panic("fake must not be nil")
|
||||
}
|
||||
return &LoginHelper{
|
||||
fake: fake,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Login just helps by making an unauthenticated client and logging in with
|
||||
// the given claims. All Logins should be unauthenticated, so this is a
|
||||
// convenience method.
|
||||
func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) {
|
||||
t.Helper()
|
||||
unauthenticatedClient := codersdk.New(h.client.URL)
|
||||
|
||||
return h.fake.Login(t, unauthenticatedClient, idTokenClaims)
|
||||
}
|
||||
|
||||
// ExpireOauthToken expires the oauth token for the given user.
|
||||
func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink {
|
||||
t.Helper()
|
||||
|
||||
//nolint:gocritic // Testing
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||
|
||||
id, _, err := httpmw.SplitAPIToken(user.SessionToken())
|
||||
require.NoError(t, err)
|
||||
|
||||
// We need to get the OIDC link and update it in the database to force
|
||||
// it to be expired.
|
||||
key, err := db.GetAPIKeyByID(ctx, id)
|
||||
require.NoError(t, err, "get api key")
|
||||
|
||||
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
require.NoError(t, err, "get user link")
|
||||
|
||||
// Expire the oauth link for the given user.
|
||||
updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthExpiry: time.Now().Add(time.Hour * -1),
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.NoError(t, err, "expire user link")
|
||||
|
||||
return updated
|
||||
}
|
||||
|
||||
// ForceRefresh forces the client to refresh its oauth token. It does this by
|
||||
// expiring the oauth token, then doing an authenticated call. This will force
|
||||
// the API Key middleware to refresh the oauth token.
|
||||
//
|
||||
// A unit test assertion makes sure the refresh token is used.
|
||||
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) {
|
||||
t.Helper()
|
||||
|
||||
link := h.ExpireOauthToken(t, db, user)
|
||||
// Updates the claims that the IDP will return. By default, it always
|
||||
// uses the original claims for the original oauth token.
|
||||
h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
|
||||
})
|
||||
|
||||
// Do any authenticated call to force the refresh
|
||||
_, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
|
||||
require.NoError(t, err, "user must be able to be fetched")
|
||||
}
|
|
@ -0,0 +1,793 @@
|
|||
package oidctest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/util/syncmap"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// FakeIDP is a functional OIDC provider.
|
||||
// It only supports 1 OIDC client.
|
||||
type FakeIDP struct {
|
||||
issuer string
|
||||
key *rsa.PrivateKey
|
||||
provider providerJSON
|
||||
handler http.Handler
|
||||
cfg *oauth2.Config
|
||||
|
||||
// clientID to be used by coderd
|
||||
clientID string
|
||||
clientSecret string
|
||||
logger slog.Logger
|
||||
|
||||
// These maps are used to control the state of the IDP.
|
||||
// That is the various access tokens, refresh tokens, states, etc.
|
||||
codeToStateMap *syncmap.Map[string, string]
|
||||
// Token -> Email
|
||||
accessTokens *syncmap.Map[string, string]
|
||||
// Refresh Token -> Email
|
||||
refreshTokensUsed *syncmap.Map[string, bool]
|
||||
refreshTokens *syncmap.Map[string, string]
|
||||
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||
|
||||
// hooks
|
||||
// hookValidRedirectURL can be used to reject a redirect url from the
|
||||
// IDP -> Application. Almost all IDPs have the concept of
|
||||
// "Authorized Redirect URLs". This can be used to emulate that.
|
||||
hookValidRedirectURL func(redirectURL string) error
|
||||
hookUserInfo func(email string) jwt.MapClaims
|
||||
fakeCoderd func(req *http.Request) (*http.Response, error)
|
||||
hookOnRefresh func(email string) error
|
||||
// Custom authentication for the client. This is useful if you want
|
||||
// to test something like PKI auth vs a client_secret.
|
||||
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
|
||||
serve bool
|
||||
}
|
||||
|
||||
type FakeIDPOpt func(idp *FakeIDP)
|
||||
|
||||
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookValidRedirectURL = hook
|
||||
}
|
||||
}
|
||||
|
||||
// WithRefreshHook is called when a refresh token is used. The email is
|
||||
// the email of the user that is being refreshed assuming the claims are correct.
|
||||
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookOnRefresh = hook
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookAuthenticateClient = hook
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogging is optional, but will log some HTTP calls made to the IDP.
|
||||
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.logger = slogtest.Make(t, options)
|
||||
}
|
||||
}
|
||||
|
||||
// WithStaticUserInfo is optional, but will return the same user info for
|
||||
// every user on the /userinfo endpoint.
|
||||
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookUserInfo = func(_ string) jwt.MapClaims {
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookUserInfo = userInfoFunc
|
||||
}
|
||||
}
|
||||
|
||||
// WithServing makes the IDP run an actual http server.
|
||||
func WithServing() func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.serve = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithIssuer(issuer string) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.issuer = issuer
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// nolint:gosec // It thinks this is a secret lol
|
||||
tokenPath = "/oauth2/token"
|
||||
authorizePath = "/oauth2/authorize"
|
||||
keysPath = "/oauth2/keys"
|
||||
userInfoPath = "/oauth2/userinfo"
|
||||
)
|
||||
|
||||
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
t.Helper()
|
||||
|
||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
idp := &FakeIDP{
|
||||
key: pkey,
|
||||
clientID: uuid.NewString(),
|
||||
clientSecret: uuid.NewString(),
|
||||
logger: slog.Make(),
|
||||
codeToStateMap: syncmap.New[string, string](),
|
||||
accessTokens: syncmap.New[string, string](),
|
||||
refreshTokens: syncmap.New[string, string](),
|
||||
refreshTokensUsed: syncmap.New[string, bool](),
|
||||
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
hookOnRefresh: func(_ string) error { return nil },
|
||||
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
|
||||
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(idp)
|
||||
}
|
||||
|
||||
if idp.issuer == "" {
|
||||
idp.issuer = "https://coder.com"
|
||||
}
|
||||
|
||||
idp.handler = idp.httpHandler(t)
|
||||
idp.updateIssuerURL(t, idp.issuer)
|
||||
if idp.serve {
|
||||
idp.realServer(t)
|
||||
}
|
||||
|
||||
return idp
|
||||
}
|
||||
|
||||
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
t.Helper()
|
||||
|
||||
u, err := url.Parse(issuer)
|
||||
require.NoError(t, err, "invalid issuer URL")
|
||||
|
||||
f.issuer = issuer
|
||||
// providerJSON is the JSON representation of the OpenID Connect provider
|
||||
// These are all the urls that the IDP will respond to.
|
||||
f.provider = providerJSON{
|
||||
Issuer: issuer,
|
||||
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
|
||||
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
|
||||
Algorithms: []string{
|
||||
"RS256",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// realServer turns the FakeIDP into a real http server.
|
||||
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
srv := httptest.NewUnstartedServer(f.handler)
|
||||
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
||||
return ctx
|
||||
}
|
||||
srv.Start()
|
||||
t.Cleanup(srv.CloseClientConnections)
|
||||
t.Cleanup(srv.Close)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
f.updateIssuerURL(t, srv.URL)
|
||||
return srv
|
||||
}
|
||||
|
||||
// Login does the full OIDC flow starting at the "LoginButton".
|
||||
// The client argument is just to get the URL of the Coder instance.
|
||||
//
|
||||
// The client passed in is just to get the url of the Coder instance.
|
||||
// The actual client that is used is 100% unauthenticated and fresh.
|
||||
func (f *FakeIDP) Login(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||
t.Helper()
|
||||
|
||||
client, resp := f.AttemptLogin(t, client, idTokenClaims, opts...)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode, "client failed to login")
|
||||
return client, resp
|
||||
}
|
||||
|
||||
func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||
t.Helper()
|
||||
var err error
|
||||
|
||||
cli := f.HTTPClient(client.HTTPClient)
|
||||
shallowCpyCli := *cli
|
||||
|
||||
if shallowCpyCli.Jar == nil {
|
||||
shallowCpyCli.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err, "failed to create cookie jar")
|
||||
}
|
||||
|
||||
unauthenticated := codersdk.New(client.URL)
|
||||
unauthenticated.HTTPClient = &shallowCpyCli
|
||||
|
||||
return f.LoginWithClient(t, unauthenticated, idTokenClaims, opts...)
|
||||
}
|
||||
|
||||
// LoginWithClient reuses the context of the passed in client. This means the same
|
||||
// cookies will be used. This should be an unauthenticated client in most cases.
|
||||
//
|
||||
// This is a niche case, but it is needed for testing ConvertLoginType.
|
||||
func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||
t.Helper()
|
||||
|
||||
coderOauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback")
|
||||
require.NoError(t, err)
|
||||
f.SetRedirect(t, coderOauthURL.String())
|
||||
|
||||
cli := f.HTTPClient(client.HTTPClient)
|
||||
cli.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
// Store the idTokenClaims to the specific state request. This ties
|
||||
// the claims 1:1 with a given authentication flow.
|
||||
state := req.URL.Query().Get("state")
|
||||
f.stateToIDTokenClaims.Store(state, idTokenClaims)
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
if cli.Jar == nil {
|
||||
cli.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err, "failed to create cookie jar")
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(req)
|
||||
}
|
||||
|
||||
res, err := cli.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// If the coder session token exists, return the new authed client!
|
||||
var user *codersdk.Client
|
||||
cookies := cli.Jar.Cookies(client.URL)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == codersdk.SessionTokenCookie {
|
||||
user = codersdk.New(client.URL)
|
||||
user.SetSessionToken(cookie.Value)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if res.Body != nil {
|
||||
_ = res.Body.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return user, res
|
||||
}
|
||||
|
||||
// OIDCCallback will emulate the IDP redirecting back to the Coder callback.
|
||||
// This is helpful if no Coderd exists because the IDP needs to redirect to
|
||||
// something.
|
||||
// Essentially this is used to fake the Coderd side of the exchange.
|
||||
// The flow starts at the user hitting the OIDC login page.
|
||||
func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) {
|
||||
t.Helper()
|
||||
if f.serve {
|
||||
panic("cannot use OIDCCallback with WithServing. This is only for the in memory usage")
|
||||
}
|
||||
|
||||
f.stateToIDTokenClaims.Store(state, idTokenClaims)
|
||||
|
||||
cli := f.HTTPClient(nil)
|
||||
u := f.cfg.AuthCodeURL(state)
|
||||
req, err := http.NewRequest("GET", u, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := cli.Do(req.WithContext(context.Background()))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
})
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type providerJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
||||
}
|
||||
|
||||
// newCode enforces the code exchanged is actually a valid code
|
||||
// created by the IDP.
|
||||
func (f *FakeIDP) newCode(state string) string {
|
||||
code := uuid.NewString()
|
||||
f.codeToStateMap.Store(code, state)
|
||||
return code
|
||||
}
|
||||
|
||||
// newToken enforces the access token exchanged is actually a valid access token
|
||||
// created by the IDP.
|
||||
func (f *FakeIDP) newToken(email string) string {
|
||||
accessToken := uuid.NewString()
|
||||
f.accessTokens.Store(accessToken, email)
|
||||
return accessToken
|
||||
}
|
||||
|
||||
func (f *FakeIDP) newRefreshTokens(email string) string {
|
||||
refreshToken := uuid.NewString()
|
||||
f.refreshTokens.Store(refreshToken, email)
|
||||
return refreshToken
|
||||
}
|
||||
|
||||
// authenticateBearerTokenRequest enforces the access token is valid.
|
||||
func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
auth := req.Header.Get("Authorization")
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
_, ok := f.accessTokens.Load(token)
|
||||
if !ok {
|
||||
return "", xerrors.New("invalid access token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// authenticateOIDCClientRequest enforces the client_id and client_secret are valid.
|
||||
func (f *FakeIDP) authenticateOIDCClientRequest(t testing.TB, req *http.Request) (url.Values, error) {
|
||||
t.Helper()
|
||||
|
||||
if f.hookAuthenticateClient != nil {
|
||||
return f.hookAuthenticateClient(t, req)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(req.Body)
|
||||
if !assert.NoError(t, err, "read token request body") {
|
||||
return nil, xerrors.Errorf("authenticate request, read body: %w", err)
|
||||
}
|
||||
values, err := url.ParseQuery(string(data))
|
||||
if !assert.NoError(t, err, "parse token request values") {
|
||||
return nil, xerrors.New("invalid token request")
|
||||
}
|
||||
|
||||
if !assert.Equal(t, f.clientID, values.Get("client_id"), "client_id mismatch") {
|
||||
return nil, xerrors.New("client_id mismatch")
|
||||
}
|
||||
|
||||
if !assert.Equal(t, f.clientSecret, values.Get("client_secret"), "client_secret mismatch") {
|
||||
return nil, xerrors.New("client_secret mismatch")
|
||||
}
|
||||
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// encodeClaims is a helper func to convert claims to a valid JWT.
|
||||
func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
}
|
||||
|
||||
if _, ok := claims["aud"]; !ok {
|
||||
claims["aud"] = f.clientID
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = f.issuer
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return signed
|
||||
}
|
||||
|
||||
// httpHandler is the IDP http server.
|
||||
func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
mux := chi.NewMux()
|
||||
// This endpoint is required to initialize the OIDC provider.
|
||||
// It is used to get the OIDC configuration.
|
||||
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http OIDC config", slog.F("url", r.URL.String()))
|
||||
|
||||
_ = json.NewEncoder(rw).Encode(f.provider)
|
||||
})
|
||||
|
||||
// Authorize is called when the user is redirected to the IDP to login.
|
||||
// This is the browser hitting the IDP and the user logging into Google or
|
||||
// w/e and clicking "Allow". They will be redirected back to the redirect
|
||||
// when this is done.
|
||||
mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http call authorize", slog.F("url", r.URL.String()))
|
||||
|
||||
clientID := r.URL.Query().Get("client_id")
|
||||
if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") {
|
||||
http.Error(rw, "invalid client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := r.URL.Query().Get("redirect_uri")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
scope := r.URL.Query().Get("scope")
|
||||
assert.NotEmpty(t, scope, "scope is empty")
|
||||
|
||||
responseType := r.URL.Query().Get("response_type")
|
||||
switch responseType {
|
||||
case "code":
|
||||
case "token":
|
||||
t.Errorf("response_type %q not supported", responseType)
|
||||
http.Error(rw, "invalid response_type", http.StatusBadRequest)
|
||||
return
|
||||
default:
|
||||
t.Errorf("unexpected response_type %q", responseType)
|
||||
http.Error(rw, "invalid response_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err := f.hookValidRedirectURL(redirectURI)
|
||||
if err != nil {
|
||||
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
|
||||
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ru, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error())
|
||||
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
q := ru.Query()
|
||||
q.Set("state", state)
|
||||
q.Set("code", f.newCode(state))
|
||||
ru.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect)
|
||||
}))
|
||||
|
||||
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
values, err := f.authenticateOIDCClientRequest(t, r)
|
||||
f.logger.Info(r.Context(), "http idp call token",
|
||||
slog.Error(err),
|
||||
slog.F("values", values.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
getEmail := func(claims jwt.MapClaims) string {
|
||||
email, ok := claims["email"]
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
emailStr, ok := email.(string)
|
||||
if !ok {
|
||||
return "wrong-type"
|
||||
}
|
||||
return emailStr
|
||||
}
|
||||
|
||||
var claims jwt.MapClaims
|
||||
switch values.Get("grant_type") {
|
||||
case "authorization_code":
|
||||
code := values.Get("code")
|
||||
if !assert.NotEmpty(t, code, "code is empty") {
|
||||
http.Error(rw, "invalid code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
stateStr, ok := f.codeToStateMap.Load(code)
|
||||
if !assert.True(t, ok, "invalid code") {
|
||||
http.Error(rw, "invalid code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// Always invalidate the code after it is used.
|
||||
f.codeToStateMap.Delete(code)
|
||||
|
||||
idTokenClaims, ok := f.stateToIDTokenClaims.Load(stateStr)
|
||||
if !ok {
|
||||
t.Errorf("missing id token claims")
|
||||
http.Error(rw, "missing id token claims", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
claims = idTokenClaims
|
||||
case "refresh_token":
|
||||
refreshToken := values.Get("refresh_token")
|
||||
if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") {
|
||||
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
_, ok := f.refreshTokens.Load(refreshToken)
|
||||
if !assert.True(t, ok, "invalid refresh_token") {
|
||||
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
idTokenClaims, ok := f.refreshIDTokenClaims.Load(refreshToken)
|
||||
if !ok {
|
||||
t.Errorf("missing id token claims in refresh")
|
||||
http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
claims = idTokenClaims
|
||||
err := f.hookOnRefresh(getEmail(claims))
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f.refreshTokensUsed.Store(refreshToken, true)
|
||||
// Always invalidate the refresh token after it is used.
|
||||
f.refreshTokens.Delete(refreshToken)
|
||||
default:
|
||||
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
|
||||
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
exp := time.Now().Add(time.Minute * 5)
|
||||
claims["exp"] = exp.UnixMilli()
|
||||
email := getEmail(claims)
|
||||
refreshToken := f.newRefreshTokens(email)
|
||||
token := map[string]interface{}{
|
||||
"access_token": f.newToken(email),
|
||||
"refresh_token": refreshToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int64((time.Minute * 5).Seconds()),
|
||||
"id_token": f.encodeClaims(t, claims),
|
||||
}
|
||||
// Store the claims for the next refresh
|
||||
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(rw).Encode(token)
|
||||
}))
|
||||
|
||||
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
token, err := f.authenticateBearerTokenRequest(t, r)
|
||||
f.logger.Info(r.Context(), "http call idp user info",
|
||||
slog.Error(err),
|
||||
slog.F("url", r.URL.String()),
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
email, ok := f.accessTokens.Load(token)
|
||||
if !ok {
|
||||
t.Errorf("access token user for user_info has no email to indicate which user")
|
||||
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
|
||||
}))
|
||||
|
||||
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http call idp /keys")
|
||||
set := jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: f.key.Public(),
|
||||
KeyID: "test-key",
|
||||
Algorithm: "RSA",
|
||||
},
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(rw).Encode(set)
|
||||
}))
|
||||
|
||||
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
|
||||
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
|
||||
})
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// HTTPClient does nothing if IsServing is used.
|
||||
//
|
||||
// If IsServing is not used, then it will return a client that will make requests
|
||||
// to the IDP all in memory. If a request is not to the IDP, then the passed in
|
||||
// client will be used. If no client is passed in, then any regular network
|
||||
// requests will fail.
|
||||
func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
|
||||
if f.serve {
|
||||
if rest == nil || rest.Transport == nil {
|
||||
return &http.Client{}
|
||||
}
|
||||
return rest
|
||||
}
|
||||
|
||||
var jar http.CookieJar
|
||||
if rest != nil {
|
||||
jar = rest.Jar
|
||||
}
|
||||
return &http.Client{
|
||||
Jar: jar,
|
||||
Transport: fakeRoundTripper{
|
||||
roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
u, _ := url.Parse(f.issuer)
|
||||
if req.URL.Host != u.Host {
|
||||
if f.fakeCoderd != nil {
|
||||
return f.fakeCoderd(req)
|
||||
}
|
||||
if rest == nil || rest.Transport == nil {
|
||||
return nil, fmt.Errorf("unexpected network request to %q", req.URL.Host)
|
||||
}
|
||||
return rest.Transport.RoundTrip(req)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
f.handler.ServeHTTP(resp, req)
|
||||
return resp.Result(), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshUsed returns if the refresh token has been used. All refresh tokens
|
||||
// can only be used once, then they are deleted.
|
||||
func (f *FakeIDP) RefreshUsed(refreshToken string) bool {
|
||||
used, _ := f.refreshTokensUsed.Load(refreshToken)
|
||||
return used
|
||||
}
|
||||
|
||||
// UpdateRefreshClaims allows the caller to change what claims are returned
|
||||
// for a given refresh token. By default, all refreshes use the same claims as
|
||||
// the original IDToken issuance.
|
||||
func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) {
|
||||
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||
}
|
||||
|
||||
// SetRedirect is required for the IDP to know where to redirect and call
|
||||
// Coderd.
|
||||
func (f *FakeIDP) SetRedirect(t testing.TB, u string) {
|
||||
t.Helper()
|
||||
|
||||
f.cfg.RedirectURL = u
|
||||
}
|
||||
|
||||
// SetCoderdCallback is optional and only works if not using the IsServing.
|
||||
// It will setup a fake "Coderd" for the IDP to call when the IDP redirects
|
||||
// back after authenticating.
|
||||
func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) {
|
||||
if f.serve {
|
||||
panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'")
|
||||
}
|
||||
f.fakeCoderd = callback
|
||||
}
|
||||
|
||||
func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) {
|
||||
f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) {
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
return resp.Result(), nil
|
||||
})
|
||||
}
|
||||
|
||||
// OIDCConfig returns the OIDC config to use for Coderd.
|
||||
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
t.Helper()
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{"openid", "email", "profile"}
|
||||
}
|
||||
|
||||
oauthCfg := &oauth2.Config{
|
||||
ClientID: f.clientID,
|
||||
ClientSecret: f.clientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: f.provider.AuthURL,
|
||||
TokenURL: f.provider.TokenURL,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
// If the user is using a real network request, they will need to do
|
||||
// 'fake.SetRedirect()'
|
||||
RedirectURL: "https://redirect.com",
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil))
|
||||
p, err := oidc.NewProvider(ctx, f.provider.Issuer)
|
||||
require.NoError(t, err, "failed to create OIDC provider")
|
||||
cfg := &coderd.OIDCConfig{
|
||||
OAuth2Config: oauthCfg,
|
||||
Provider: p,
|
||||
Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{f.key.Public()},
|
||||
}, &oidc.Config{
|
||||
ClientID: oauthCfg.ClientID,
|
||||
SupportedSigningAlgs: []string{
|
||||
"RS256",
|
||||
},
|
||||
// Todo: add support for Now()
|
||||
}),
|
||||
UsernameField: "preferred_username",
|
||||
EmailField: "email",
|
||||
AuthURLParams: map[string]string{"access_type": "offline"},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
f.cfg = oauthCfg
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct {
|
||||
roundTrip func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f.roundTrip(req)
|
||||
}
|
||||
|
||||
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
|
||||
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
|
||||
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
|
||||
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
|
||||
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
|
||||
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
|
||||
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
|
||||
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
|
||||
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
|
||||
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
|
||||
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
|
||||
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
|
||||
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
|
||||
-----END RSA PRIVATE KEY-----`
|
|
@ -0,0 +1,72 @@
|
|||
package oidctest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// TestFakeIDPBasicFlow tests the basic flow of the fake IDP.
|
||||
// It is done all in memory with no actual network requests.
|
||||
// nolint:bodyclose
|
||||
func TestFakeIDPBasicFlow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithLogging(t, nil),
|
||||
)
|
||||
|
||||
var handler http.Handler
|
||||
srv := httptest.NewServer(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler.ServeHTTP(w, r)
|
||||
})))
|
||||
defer srv.Close()
|
||||
|
||||
cfg := fake.OIDCConfig(t, nil)
|
||||
cli := fake.HTTPClient(nil)
|
||||
ctx := oidc.ClientContext(context.Background(), cli)
|
||||
|
||||
const expectedState = "random-state"
|
||||
var token *oauth2.Token
|
||||
// This is the Coder callback using an actual network request.
|
||||
fake.SetCoderdCallbackHandler(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Emulate OIDC flow
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
assert.Equal(t, expectedState, state, "state mismatch")
|
||||
|
||||
oauthToken, err := cfg.Exchange(ctx, code)
|
||||
if assert.NoError(t, err, "failed to exchange code") {
|
||||
assert.NotEmpty(t, oauthToken.AccessToken, "access token is empty")
|
||||
assert.NotEmpty(t, oauthToken.RefreshToken, "refresh token is empty")
|
||||
}
|
||||
token = oauthToken
|
||||
})
|
||||
|
||||
resp, err := fake.OIDCCallback(t, expectedState, jwt.MapClaims{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Test the user info
|
||||
_, err = cfg.Provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now test it can refresh
|
||||
refreshed, err := cfg.TokenSource(ctx, &oauth2.Token{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
Expiry: time.Now().Add(time.Minute * -1),
|
||||
}).Token()
|
||||
require.NoError(t, err, "failed to refresh token")
|
||||
require.NotEmpty(t, refreshed.AccessToken, "access token is empty on refresh")
|
||||
}
|
|
@ -215,7 +215,10 @@ func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
|
|||
}
|
||||
|
||||
var tokenRes struct {
|
||||
oauth2.Token
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
|
||||
// Extra fields returned by the refresh that are needed
|
||||
IDToken string `json:"id_token"`
|
||||
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
||||
|
|
|
@ -12,12 +12,15 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/oauthpki"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
@ -123,6 +126,58 @@ func TestAzureADPKIOIDC(t *testing.T) {
|
|||
require.Error(t, err, "error expected")
|
||||
}
|
||||
|
||||
// TestAzureAKPKIWithCoderd uses a fake IDP and a real Coderd to test PKI auth.
|
||||
// nolint:bodyclose
|
||||
func TestAzureAKPKIWithCoderd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scopes := []string{"openid", "email", "profile", "offline_access"}
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithIssuer("https://login.microsoftonline.com/fake_app"),
|
||||
oidctest.WithCustomClientAuth(func(t testing.TB, req *http.Request) (url.Values, error) {
|
||||
values := assertJWTAuth(t, req)
|
||||
if values == nil {
|
||||
return nil, xerrors.New("authorizatin failed in request")
|
||||
}
|
||||
return values, nil
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
})
|
||||
|
||||
oauthCfg := cfg.OAuth2Config.(*oauth2.Config)
|
||||
// Create the oauthpki config
|
||||
pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
|
||||
ClientID: oauthCfg.ClientID,
|
||||
TokenURL: oauthCfg.Endpoint.TokenURL,
|
||||
Scopes: scopes,
|
||||
PemEncodedKey: []byte(testClientKey),
|
||||
PemEncodedCert: []byte(testClientCert),
|
||||
Config: oauthCfg,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
cfg.OAuth2Config = pki
|
||||
|
||||
owner, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
OIDCConfig: cfg,
|
||||
})
|
||||
|
||||
// Create a user and login
|
||||
const email = "alice@coder.com"
|
||||
claims := jwt.MapClaims{
|
||||
"email": email,
|
||||
}
|
||||
helper := oidctest.NewLoginHelper(owner, fake)
|
||||
user, _ := helper.Login(t, claims)
|
||||
|
||||
// Try refreshing the token more than once.
|
||||
for i := 0; i < 2; i++ {
|
||||
helper.ForceRefresh(t, api.Database, user, claims)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
|
||||
// AD instance and saving them to replay, removing some details.
|
||||
// The reason this is done is that this is the only way to assert values
|
||||
|
@ -269,7 +324,7 @@ func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
|
||||
// assertJWTAuth will assert the basic JWT auth assertions. It will return the
|
||||
// url.Values from the request body for any additional assertions to be made.
|
||||
func assertJWTAuth(t *testing.T, r *http.Request) url.Values {
|
||||
func assertJWTAuth(t testing.TB, r *http.Request) url.Values {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if !assert.NoError(t, err) {
|
||||
return nil
|
||||
|
|
|
@ -4,28 +4,25 @@ import (
|
|||
"context"
|
||||
"crypto"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
@ -35,85 +32,42 @@ import (
|
|||
// This test specifically tests logging in with OIDC when an expired
|
||||
// OIDC session token exists.
|
||||
// The token refreshing should not happen since we are reauthenticating.
|
||||
// nolint:bodyclose
|
||||
func TestOIDCOauthLoginWithExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf := coderdtest.NewOIDCConfig(t, "",
|
||||
// Provide a refresh token so we use the refresh token flow
|
||||
coderdtest.WithRefreshToken("refresh_token"),
|
||||
// We need to set the expire in the future for the first api calls.
|
||||
coderdtest.WithTokenExpires(func() time.Time {
|
||||
return time.Now().Add(time.Hour).UTC()
|
||||
}),
|
||||
// No refresh should actually happen in this test.
|
||||
coderdtest.WithTokenSource(func() (*oauth2.Token, error) {
|
||||
return nil, xerrors.New("token should not require refresh")
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
auditor := audit.NewMock()
|
||||
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.IgnoreUserInfo = true
|
||||
})
|
||||
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
OIDCConfig: cfg,
|
||||
})
|
||||
|
||||
const username = "alice"
|
||||
claims := jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
"email_verified": true,
|
||||
"preferred_username": username,
|
||||
}
|
||||
config := conf.OIDCConfig(t, claims)
|
||||
|
||||
config.AllowSignups = true
|
||||
config.IgnoreUserInfo = true
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: config,
|
||||
Logger: &logger,
|
||||
})
|
||||
|
||||
helper := oidctest.NewLoginHelper(client, fake)
|
||||
// Signup alice
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, claims))
|
||||
// Set the client to use this OIDC context
|
||||
authCookie := authCookieValue(resp.Cookies())
|
||||
client.SetSessionToken(authCookie)
|
||||
_ = resp.Body.Close()
|
||||
userClient, _ := helper.Login(t, claims)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
// Verify the user and oauth link
|
||||
user, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, username, user.Username)
|
||||
// Expire the link. This will force the client to refresh the token.
|
||||
helper.ExpireOauthToken(t, api.Database, userClient)
|
||||
|
||||
// nolint:gocritic
|
||||
link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginType(user.LoginType),
|
||||
})
|
||||
require.NoError(t, err, "failed to get user link")
|
||||
|
||||
// Expire the link
|
||||
// nolint:gocritic
|
||||
_, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(),
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.NoError(t, err, "failed to update user link")
|
||||
|
||||
// Log in again with OIDC
|
||||
loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) {
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: authCookie,
|
||||
Path: "/",
|
||||
})
|
||||
})
|
||||
require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode)
|
||||
_ = loginAgain.Body.Close()
|
||||
|
||||
// Try to use new login
|
||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
||||
_, err = client.User(ctx, "me")
|
||||
require.NoError(t, err, "use new session")
|
||||
// Instead of refreshing, just log in again.
|
||||
helper.Login(t, claims)
|
||||
}
|
||||
|
||||
func TestUserLogin(t *testing.T) {
|
||||
|
@ -660,7 +614,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
"email": "kyle@kwc.io",
|
||||
},
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
Username: "kyle",
|
||||
}, {
|
||||
Name: "EmailNotVerified",
|
||||
|
@ -685,7 +639,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
"email_verified": false,
|
||||
},
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
Username: "kyle",
|
||||
IgnoreEmailVerified: true,
|
||||
}, {
|
||||
|
@ -709,7 +663,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
EmailDomain: []string{
|
||||
"kwc.io",
|
||||
},
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
Name: "EmptyClaims",
|
||||
IDTokenClaims: jwt.MapClaims{},
|
||||
|
@ -730,7 +684,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
},
|
||||
Username: "kyle",
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
Name: "UsernameFromClaims",
|
||||
IDTokenClaims: jwt.MapClaims{
|
||||
|
@ -740,7 +694,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
},
|
||||
Username: "hotdog",
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
// Services like Okta return the email as the username:
|
||||
// https://developer.okta.com/docs/reference/api/oidc/#base-claims-always-present
|
||||
|
@ -752,7 +706,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
},
|
||||
Username: "kyle",
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
// See: https://github.com/coder/coder/issues/4472
|
||||
Name: "UsernameIsEmail",
|
||||
|
@ -761,7 +715,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
},
|
||||
Username: "kyle",
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
Name: "WithPicture",
|
||||
IDTokenClaims: jwt.MapClaims{
|
||||
|
@ -773,7 +727,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
Username: "kyle",
|
||||
AllowSignups: true,
|
||||
AvatarURL: "/example.png",
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
Name: "WithUserInfoClaims",
|
||||
IDTokenClaims: jwt.MapClaims{
|
||||
|
@ -787,7 +741,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
Username: "potato",
|
||||
AllowSignups: true,
|
||||
AvatarURL: "/example.png",
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
Name: "GroupsDoesNothing",
|
||||
IDTokenClaims: jwt.MapClaims{
|
||||
|
@ -795,7 +749,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
"groups": []string{"pingpong"},
|
||||
},
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
Name: "UserInfoOverridesIDTokenClaims",
|
||||
IDTokenClaims: jwt.MapClaims{
|
||||
|
@ -810,7 +764,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
Username: "user",
|
||||
AllowSignups: true,
|
||||
IgnoreEmailVerified: false,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}, {
|
||||
Name: "InvalidUserInfo",
|
||||
IDTokenClaims: jwt.MapClaims{
|
||||
|
@ -837,36 +791,41 @@ func TestUserOIDC(t *testing.T) {
|
|||
Username: "user",
|
||||
IgnoreUserInfo: true,
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
StatusCode: http.StatusOK,
|
||||
}} {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
oidctest.WithStaticUserInfo(tc.UserInfoClaims),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = tc.AllowSignups
|
||||
cfg.EmailDomain = tc.EmailDomain
|
||||
cfg.IgnoreEmailVerified = tc.IgnoreEmailVerified
|
||||
cfg.IgnoreUserInfo = tc.IgnoreUserInfo
|
||||
})
|
||||
|
||||
auditor := audit.NewMock()
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
config := conf.OIDCConfig(t, tc.UserInfoClaims)
|
||||
config.AllowSignups = tc.AllowSignups
|
||||
config.EmailDomain = tc.EmailDomain
|
||||
config.IgnoreEmailVerified = tc.IgnoreEmailVerified
|
||||
config.IgnoreUserInfo = tc.IgnoreUserInfo
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
owner := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: config,
|
||||
OIDCConfig: cfg,
|
||||
Logger: &logger,
|
||||
})
|
||||
numLogs := len(auditor.AuditLogs())
|
||||
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.IDTokenClaims))
|
||||
client, resp := fake.AttemptLogin(t, owner, tc.IDTokenClaims)
|
||||
numLogs++ // add an audit log for login
|
||||
assert.Equal(t, tc.StatusCode, resp.StatusCode)
|
||||
require.Equal(t, tc.StatusCode, resp.StatusCode)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
if tc.Username != "" {
|
||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
||||
user, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.Username, user.Username)
|
||||
|
@ -877,7 +836,6 @@ func TestUserOIDC(t *testing.T) {
|
|||
}
|
||||
|
||||
if tc.AvatarURL != "" {
|
||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
||||
user, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.AvatarURL, user.AvatarURL)
|
||||
|
@ -890,26 +848,29 @@ func TestUserOIDC(t *testing.T) {
|
|||
|
||||
t.Run("OIDCConvert", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := audit.NewMock()
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
config := conf.OIDCConfig(t, nil)
|
||||
config.AllowSignups = true
|
||||
|
||||
cfg := coderdtest.DeploymentValues(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: config,
|
||||
DeploymentValues: cfg,
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: cfg,
|
||||
})
|
||||
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
code := conf.EncodeClaims(t, jwt.MapClaims{
|
||||
claims := jwt.MapClaims{
|
||||
"email": userData.Email,
|
||||
})
|
||||
|
||||
}
|
||||
var err error
|
||||
user.HTTPClient.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err)
|
||||
|
@ -921,52 +882,58 @@ func TestUserOIDC(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
fake.LoginWithClient(t, user, claims, func(r *http.Request) {
|
||||
r.URL.RawQuery = url.Values{
|
||||
"oidc_merge_state": {convertResponse.StateString},
|
||||
}.Encode()
|
||||
r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken())
|
||||
cookies := user.HTTPClient.Jar.Cookies(r.URL)
|
||||
for _, cookie := range cookies {
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("AlternateUsername", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
auditor := audit.NewMock()
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
config := conf.OIDCConfig(t, nil)
|
||||
config.AllowSignups = true
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
})
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: config,
|
||||
OIDCConfig: cfg,
|
||||
})
|
||||
numLogs := len(auditor.AuditLogs())
|
||||
|
||||
code := conf.EncodeClaims(t, jwt.MapClaims{
|
||||
numLogs := len(auditor.AuditLogs())
|
||||
claims := jwt.MapClaims{
|
||||
"email": "jon@coder.com",
|
||||
})
|
||||
resp := oidcCallback(t, client, code)
|
||||
}
|
||||
|
||||
userClient, _ := fake.Login(t, client, claims)
|
||||
numLogs++ // add an audit log for login
|
||||
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
||||
user, err := client.User(ctx, "me")
|
||||
user, err := userClient.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "jon", user.Username)
|
||||
|
||||
// Pass a different subject field so that we prompt creating a
|
||||
// new user.
|
||||
code = conf.EncodeClaims(t, jwt.MapClaims{
|
||||
// new user
|
||||
userClient, _ = fake.Login(t, client, jwt.MapClaims{
|
||||
"email": "jon@example2.com",
|
||||
"sub": "diff",
|
||||
})
|
||||
resp = oidcCallback(t, client, code)
|
||||
numLogs++ // add an audit log for login
|
||||
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
||||
user, err = client.User(ctx, "me")
|
||||
user, err = userClient.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")
|
||||
|
||||
|
@ -977,45 +944,62 @@ func TestUserOIDC(t *testing.T) {
|
|||
t.Run("Disabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
resp := oidcCallback(t, client, "asdf")
|
||||
oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback")
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NoIDToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: &coderd.OIDCConfig{
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
},
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
})
|
||||
|
||||
resp := oidcCallback(t, client, "asdf")
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: cfg,
|
||||
})
|
||||
|
||||
_, resp := fake.AttemptLogin(t, client, jwt.MapClaims{})
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("BadVerify", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
verifier := oidc.NewVerifier("", &oidc.StaticKeySet{
|
||||
badVerifier := oidc.NewVerifier("", &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{},
|
||||
}, &oidc.Config{})
|
||||
provider := &oidc.Provider{}
|
||||
badProvider := &oidc.Provider{}
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: &coderd.OIDCConfig{
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": "invalid",
|
||||
}),
|
||||
},
|
||||
Provider: provider,
|
||||
Verifier: verifier,
|
||||
},
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.Provider = badProvider
|
||||
cfg.Verifier = badVerifier
|
||||
})
|
||||
|
||||
resp := oidcCallback(t, client, "asdf")
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: cfg,
|
||||
})
|
||||
|
||||
_, resp := fake.AttemptLogin(t, client, jwt.MapClaims{})
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
@ -1146,36 +1130,6 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
|||
return res
|
||||
}
|
||||
|
||||
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
|
||||
return oidcCallbackWithState(t, client, code, "somestate", nil)
|
||||
}
|
||||
|
||||
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=%s", code, state))
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: codersdk.OAuth2StateCookie,
|
||||
Value: state,
|
||||
})
|
||||
if modify != nil {
|
||||
modify(req)
|
||||
}
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
data, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
t.Log(string(data))
|
||||
return res
|
||||
}
|
||||
|
||||
func i64ptr(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
|
|
|
@ -8,7 +8,10 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -403,6 +406,7 @@ func TestPostLogout(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// nolint:bodyclose
|
||||
func TestPostUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NoAuth", func(t *testing.T) {
|
||||
|
@ -593,15 +597,15 @@ func TestPostUsers(t *testing.T) {
|
|||
t.Run("CreateOIDCLoginType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
email := "another@user.org"
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{
|
||||
"email": email,
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
})
|
||||
config.AllowSignups = false
|
||||
config.IgnoreUserInfo = true
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
OIDCConfig: cfg,
|
||||
})
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
|
@ -618,15 +622,9 @@ func TestPostUsers(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Try to log in with OIDC.
|
||||
userClient := codersdk.New(client.URL)
|
||||
resp := oidcCallback(t, userClient, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
userClient, _ := fake.Login(t, client, jwt.MapClaims{
|
||||
"email": email,
|
||||
}))
|
||||
require.Equal(t, resp.StatusCode, http.StatusTemporaryRedirect)
|
||||
// Set the client to use this OIDC context
|
||||
authCookie := authCookieValue(resp.Cookies())
|
||||
userClient.SetSessionToken(authCookie)
|
||||
_ = resp.Body.Close()
|
||||
})
|
||||
|
||||
found, err := userClient.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
package syncmap
|
||||
|
||||
import "sync"
|
||||
|
||||
// Map is a type safe sync.Map
|
||||
type Map[K, V any] struct {
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func New[K, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{
|
||||
m: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Store(k K, v V) {
|
||||
m.m.Store(k, v)
|
||||
}
|
||||
|
||||
//nolint:forcetypeassert
|
||||
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
|
||||
v, ok := m.m.Load(key)
|
||||
if !ok {
|
||||
var empty V
|
||||
return empty, false
|
||||
}
|
||||
return v.(V), ok
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Delete(key K) {
|
||||
m.m.Delete(key)
|
||||
}
|
||||
|
||||
//nolint:forcetypeassert
|
||||
func (m *Map[K, V]) LoadAndDelete(key K) (actual V, loaded bool) {
|
||||
act, loaded := m.m.LoadAndDelete(key)
|
||||
if !loaded {
|
||||
var empty V
|
||||
return empty, loaded
|
||||
}
|
||||
return act.(V), loaded
|
||||
}
|
||||
|
||||
//nolint:forcetypeassert
|
||||
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
||||
act, loaded := m.m.LoadOrStore(key, value)
|
||||
if !loaded {
|
||||
var empty V
|
||||
return empty, loaded
|
||||
}
|
||||
return act.(V), loaded
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) CompareAndSwap(key K, old V, new V) bool {
|
||||
return m.m.CompareAndSwap(key, old, new)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
|
||||
return m.m.CompareAndDelete(key, old)
|
||||
}
|
||||
|
||||
//nolint:forcetypeassert
|
||||
func (m *Map[K, V]) Swap(key K, value V) (previous any, loaded bool) {
|
||||
previous, loaded = m.m.Swap(key, value)
|
||||
if !loaded {
|
||||
var empty V
|
||||
return empty, loaded
|
||||
}
|
||||
return previous.(V), loaded
|
||||
}
|
||||
|
||||
//nolint:forcetypeassert
|
||||
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
|
||||
m.m.Range(func(key, value interface{}) bool {
|
||||
return f(key.(K), value.(V))
|
||||
})
|
||||
}
|
|
@ -1,25 +1,22 @@
|
|||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
coderden "github.com/coder/coder/v2/enterprise/coderd"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
|
@ -31,128 +28,123 @@ func TestUserOIDC(t *testing.T) {
|
|||
t.Run("RoleSync", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// NoRoles is the "control group". It has claims with 0 roles
|
||||
// assigned, and asserts that the user has no roles.
|
||||
t.Run("NoRoles", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
oidcRoleName := "TemplateAuthor"
|
||||
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}
|
||||
})
|
||||
config.AllowSignups = true
|
||||
config.UserRoleField = "roles"
|
||||
|
||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.UserRoleField = "roles"
|
||||
},
|
||||
})
|
||||
|
||||
admin, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, admin.OrganizationIDs, 1)
|
||||
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
claims := jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
}))
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
user, err := client.User(ctx, "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, user.Roles, 0)
|
||||
roleNames := []string{}
|
||||
require.ElementsMatch(t, roleNames, []string{})
|
||||
}
|
||||
// Login a new client that signs up
|
||||
client, resp := runner.Login(t, claims)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
// User should be in 0 groups.
|
||||
runner.AssertRoles(t, "alice", []string{})
|
||||
// Force a refresh, and assert nothing has changes
|
||||
runner.ForceRefresh(t, client, claims)
|
||||
runner.AssertRoles(t, "alice", []string{})
|
||||
})
|
||||
|
||||
t.Run("NewUserAndRemoveRoles", func(t *testing.T) {
|
||||
// A user has some roles, then on an oauth refresh will lose said
|
||||
// roles from an updated claim.
|
||||
t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) {
|
||||
// TODO: Implement new feature to update roles/groups on OIDC
|
||||
// refresh tokens. https://github.com/coder/coder/issues/9312
|
||||
t.Skip("Refreshing tokens does not update roles :(")
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
oidcRoleName := "TemplateAuthor"
|
||||
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}
|
||||
})
|
||||
config.AllowSignups = true
|
||||
config.UserRoleField = "roles"
|
||||
|
||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
|
||||
const oidcRoleName = "TemplateAuthor"
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}},
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.UserRoleField = "roles"
|
||||
cfg.UserRoleMapping = map[string][]string{
|
||||
oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()},
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
admin, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, admin.OrganizationIDs, 1)
|
||||
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
// User starts with the owner role
|
||||
client, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
"roles": []string{"random", oidcRoleName, rbac.RoleOwner()},
|
||||
}))
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
_ = resp.Body.Close()
|
||||
user, err := client.User(ctx, "alice")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
|
||||
|
||||
require.Len(t, user.Roles, 3)
|
||||
roleNames := []string{user.Roles[0].Name, user.Roles[1].Name, user.Roles[2].Name}
|
||||
require.ElementsMatch(t, roleNames, []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
|
||||
|
||||
// Now remove the roles with a new oidc login
|
||||
resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
// Now refresh the oauth, and check the roles are removed.
|
||||
// Force a refresh, and assert nothing has changes
|
||||
runner.ForceRefresh(t, client, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
"roles": []string{"random"},
|
||||
}))
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
_ = resp.Body.Close()
|
||||
user, err = client.User(ctx, "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, user.Roles, 0)
|
||||
})
|
||||
runner.AssertRoles(t, "alice", []string{})
|
||||
})
|
||||
|
||||
// A user has some roles, then on another oauth login will lose said
|
||||
// roles from an updated claim.
|
||||
t.Run("NewUserAndRemoveRolesOnReAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const oidcRoleName = "TemplateAuthor"
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}},
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.UserRoleField = "roles"
|
||||
cfg.UserRoleMapping = map[string][]string{
|
||||
oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()},
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
// User starts with the owner role
|
||||
_, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
"roles": []string{"random", oidcRoleName, rbac.RoleOwner()},
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
|
||||
|
||||
// Now login with oauth again, and check the roles are removed.
|
||||
_, resp = runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
"roles": []string{"random"},
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
runner.AssertRoles(t, "alice", []string{})
|
||||
})
|
||||
|
||||
// All manual role updates should fail when role sync is enabled.
|
||||
t.Run("BlockAssignRoles", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{})
|
||||
config.AllowSignups = true
|
||||
config.UserRoleField = "roles"
|
||||
|
||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.UserRoleField = "roles"
|
||||
},
|
||||
})
|
||||
|
||||
admin, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, admin.OrganizationIDs, 1)
|
||||
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
_, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
"roles": []string{},
|
||||
}))
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
// Try to manually update user roles, even though controlled by oidc
|
||||
// role sync.
|
||||
_, err = client.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_, err := runner.AdminClient.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{
|
||||
Roles: []string{
|
||||
rbac.RoleTemplateAdmin(),
|
||||
},
|
||||
|
@ -164,199 +156,211 @@ func TestUserOIDC(t *testing.T) {
|
|||
|
||||
t.Run("Groups", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Assigns does a simple test of assigning a user to a group based
|
||||
// on the oidc claims.
|
||||
t.Run("Assigns", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
const groupClaim = "custom-groups"
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.GroupField = groupClaim
|
||||
})
|
||||
config.AllowSignups = true
|
||||
|
||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
||||
const groupName = "bingbong"
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.GroupField = groupClaim
|
||||
},
|
||||
})
|
||||
|
||||
admin, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, admin.OrganizationIDs, 1)
|
||||
|
||||
groupName := "bingbong"
|
||||
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||
Name: groupName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 0)
|
||||
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
"email": "colin@coder.com",
|
||||
_, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
groupClaim: []string{groupName},
|
||||
}))
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
group, err = client.Group(ctx, group.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 1)
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertGroups(t, "alice", []string{groupName})
|
||||
})
|
||||
|
||||
// Tests the group mapping feature.
|
||||
t.Run("AssignsMapped", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
const groupClaim = "custom-groups"
|
||||
|
||||
oidcGroupName := "pingpong"
|
||||
coderGroupName := "bingbong"
|
||||
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName}
|
||||
})
|
||||
config.AllowSignups = true
|
||||
|
||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
||||
const oidcGroupName = "pingpong"
|
||||
const coderGroupName = "bingbong"
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.GroupField = groupClaim
|
||||
cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName}
|
||||
},
|
||||
})
|
||||
|
||||
admin, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, admin.OrganizationIDs, 1)
|
||||
|
||||
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||
Name: coderGroupName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 0)
|
||||
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
"email": "colin@coder.com",
|
||||
"groups": []string{oidcGroupName},
|
||||
}))
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
group, err = client.Group(ctx, group.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 1)
|
||||
_, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
groupClaim: []string{oidcGroupName},
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertGroups(t, "alice", []string{coderGroupName})
|
||||
})
|
||||
|
||||
t.Run("AddThenRemove", func(t *testing.T) {
|
||||
// User is in a group, then on an oauth refresh will lose said
|
||||
// group.
|
||||
t.Run("AddThenRemoveOnRefresh", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
// TODO: Implement new feature to update roles/groups on OIDC
|
||||
// refresh tokens. https://github.com/coder/coder/issues/9312
|
||||
t.Skip("Refreshing tokens does not update groups :(")
|
||||
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{})
|
||||
config.AllowSignups = true
|
||||
|
||||
client, firstUser := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
||||
const groupClaim = "custom-groups"
|
||||
const groupName = "bingbong"
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.GroupField = groupClaim
|
||||
},
|
||||
})
|
||||
|
||||
// Add some extra users/groups that should be asserted after.
|
||||
// Adding this user as there was a bug that removing 1 user removed
|
||||
// all users from the group.
|
||||
_, extra := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
groupName := "bingbong"
|
||||
group, err := client.CreateGroup(ctx, firstUser.OrganizationID, codersdk.CreateGroupRequest{
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||
Name: groupName,
|
||||
})
|
||||
require.NoError(t, err, "create group")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 0)
|
||||
|
||||
group, err = client.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{
|
||||
AddUsers: []string{
|
||||
firstUser.UserID.String(),
|
||||
extra.ID.String(),
|
||||
},
|
||||
client, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
groupClaim: []string{groupName},
|
||||
})
|
||||
require.NoError(t, err, "patch group")
|
||||
require.Len(t, group.Members, 2, "expect both members")
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertGroups(t, "alice", []string{groupName})
|
||||
|
||||
// Now add OIDC user into the group
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
"email": "colin@coder.com",
|
||||
"groups": []string{groupName},
|
||||
}))
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
group, err = client.Group(ctx, group.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 3)
|
||||
|
||||
// Login to remove the OIDC user from the group
|
||||
resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
"email": "colin@coder.com",
|
||||
"groups": []string{},
|
||||
}))
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
group, err = client.Group(ctx, group.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 2)
|
||||
var expected []uuid.UUID
|
||||
for _, mem := range group.Members {
|
||||
expected = append(expected, mem.ID)
|
||||
}
|
||||
require.ElementsMatchf(t, expected, []uuid.UUID{firstUser.UserID, extra.ID}, "expected members")
|
||||
// Refresh without the group claim
|
||||
runner.ForceRefresh(t, client, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
})
|
||||
runner.AssertGroups(t, "alice", []string{})
|
||||
})
|
||||
|
||||
t.Run("AddThenRemoveOnReAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const groupClaim = "custom-groups"
|
||||
const groupName = "bingbong"
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.GroupField = groupClaim
|
||||
},
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||
Name: groupName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 0)
|
||||
|
||||
_, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
groupClaim: []string{groupName},
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertGroups(t, "alice", []string{groupName})
|
||||
|
||||
// Refresh without the group claim
|
||||
_, resp = runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertGroups(t, "alice", []string{})
|
||||
})
|
||||
|
||||
// Updating groups where the claimed group does not exist.
|
||||
t.Run("NoneMatch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{})
|
||||
config.AllowSignups = true
|
||||
|
||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
||||
const groupClaim = "custom-groups"
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.GroupField = groupClaim
|
||||
},
|
||||
})
|
||||
|
||||
admin, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, admin.OrganizationIDs, 1)
|
||||
|
||||
groupName := "bingbong"
|
||||
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||
Name: groupName,
|
||||
_, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
groupClaim: []string{"not-exists"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 0)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertGroups(t, "alice", []string{})
|
||||
})
|
||||
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
||||
"email": "colin@coder.com",
|
||||
"groups": []string{"coolin"},
|
||||
}))
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
// Updating groups where the claimed group does not exist creates
|
||||
// the group.
|
||||
t.Run("AutoCreate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
group, err = client.Group(ctx, group.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, group.Members, 0)
|
||||
const groupClaim = "custom-groups"
|
||||
const groupName = "make-me"
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.GroupField = groupClaim
|
||||
cfg.CreateMissingGroups = true
|
||||
},
|
||||
})
|
||||
|
||||
_, resp := runner.Login(t, jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
groupClaim: []string{groupName},
|
||||
})
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
runner.AssertGroups(t, "alice", []string{groupName})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Refresh", func(t *testing.T) {
|
||||
t.Run("RefreshTokensMultiple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
cfg.UserRoleField = "roles"
|
||||
},
|
||||
})
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
}
|
||||
// Login a new client that signs up
|
||||
client, resp := runner.Login(t, claims)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Refresh multiple times.
|
||||
for i := 0; i < 3; i++ {
|
||||
runner.ForceRefresh(t, client, claims)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// nolint:bodyclose
|
||||
func TestGroupSync(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -470,28 +474,20 @@ func TestGroupSync(t *testing.T) {
|
|||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, tc.modCfg)
|
||||
|
||||
client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
Config: func(cfg *coderd.OIDCConfig) {
|
||||
cfg.GroupField = "groups"
|
||||
tc.modCfg(cfg)
|
||||
},
|
||||
})
|
||||
|
||||
admin, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, admin.OrganizationIDs, 1)
|
||||
|
||||
// Setup
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
org := runner.AdminUser.OrganizationIDs[0]
|
||||
|
||||
initialGroups := make(map[string]codersdk.Group)
|
||||
for _, group := range tc.initialOrgGroups {
|
||||
newGroup, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||
newGroup, err := runner.AdminClient.CreateGroup(ctx, org, codersdk.CreateGroupRequest{
|
||||
Name: group,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -500,16 +496,16 @@ func TestGroupSync(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create the user and add them to their initial groups
|
||||
_, user := coderdtest.CreateAnotherUser(t, client, admin.OrganizationIDs[0])
|
||||
_, user := coderdtest.CreateAnotherUser(t, runner.AdminClient, org)
|
||||
for _, group := range tc.initialUserGroups {
|
||||
_, err := client.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{
|
||||
_, err := runner.AdminClient.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{
|
||||
AddUsers: []string{user.ID.String()},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// nolint:gocritic
|
||||
_, err = api.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{
|
||||
_, err := runner.API.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{
|
||||
NewLoginType: database.LoginTypeOIDC,
|
||||
UserID: user.ID,
|
||||
})
|
||||
|
@ -517,11 +513,11 @@ func TestGroupSync(t *testing.T) {
|
|||
|
||||
// Log in the new user
|
||||
tc.claims["email"] = user.Email
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.claims))
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
_ = resp.Body.Close()
|
||||
_, resp := runner.Login(t, tc.claims)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
orgGroups, err := client.GroupsByOrganization(ctx, admin.OrganizationIDs[0])
|
||||
// Check group sources
|
||||
orgGroups, err := runner.AdminClient.GroupsByOrganization(ctx, org)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, group := range orgGroups {
|
||||
|
@ -567,24 +563,107 @@ func TestGroupSync(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
|
||||
t.Helper()
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code))
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: codersdk.OAuth2StateCookie,
|
||||
Value: "somestate",
|
||||
})
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
data, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
t.Log(string(data))
|
||||
return res
|
||||
// oidcTestRunner is just a helper to setup and run oidc tests.
|
||||
// An actual Coderd instance is used to run the tests.
|
||||
type oidcTestRunner struct {
|
||||
AdminClient *codersdk.Client
|
||||
AdminUser codersdk.User
|
||||
API *coderden.API
|
||||
|
||||
// Login will call the OIDC flow with an unauthenticated client.
|
||||
// The IDP will return the idToken claims.
|
||||
Login func(t *testing.T, idToken jwt.MapClaims) (*codersdk.Client, *http.Response)
|
||||
// ForceRefresh will use an authenticated codersdk.Client, and force their
|
||||
// OIDC token to be expired and require a refresh. The refresh will use the claims provided.
|
||||
// It just calls the /users/me endpoint to trigger the refresh.
|
||||
ForceRefresh func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims)
|
||||
}
|
||||
|
||||
type oidcTestConfig struct {
|
||||
Userinfo jwt.MapClaims
|
||||
|
||||
// Config allows modifying the Coderd OIDC configuration.
|
||||
Config func(cfg *coderd.OIDCConfig)
|
||||
}
|
||||
|
||||
func (r *oidcTestRunner) AssertRoles(t *testing.T, userIdent string, roles []string) {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
user, err := r.AdminClient.User(ctx, userIdent)
|
||||
require.NoError(t, err)
|
||||
|
||||
roleNames := []string{}
|
||||
for _, role := range user.Roles {
|
||||
roleNames = append(roleNames, role.Name)
|
||||
}
|
||||
require.ElementsMatch(t, roles, roleNames, "expected roles")
|
||||
}
|
||||
|
||||
func (r *oidcTestRunner) AssertGroups(t *testing.T, userIdent string, groups []string) {
|
||||
t.Helper()
|
||||
|
||||
if !slice.Contains(groups, database.EveryoneGroup) {
|
||||
var cpy []string
|
||||
cpy = append(cpy, groups...)
|
||||
// always include everyone group
|
||||
cpy = append(cpy, database.EveryoneGroup)
|
||||
groups = cpy
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
user, err := r.AdminClient.User(ctx, userIdent)
|
||||
require.NoError(t, err)
|
||||
|
||||
allGroups, err := r.AdminClient.GroupsByOrganization(ctx, user.OrganizationIDs[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
userInGroups := []string{}
|
||||
for _, g := range allGroups {
|
||||
for _, mem := range g.Members {
|
||||
if mem.ID == user.ID {
|
||||
userInGroups = append(userInGroups, g.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, groups, userInGroups, "expected groups")
|
||||
}
|
||||
|
||||
func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner {
|
||||
t.Helper()
|
||||
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithStaticUserInfo(settings.Userinfo),
|
||||
oidctest.WithLogging(t, nil),
|
||||
// Run fake IDP on a real webserver
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
cfg := fake.OIDCConfig(t, nil, settings.Config)
|
||||
owner, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
OIDCConfig: cfg,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureUserRoleManagement: 1,
|
||||
codersdk.FeatureTemplateRBAC: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
admin, err := owner.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
|
||||
helper := oidctest.NewLoginHelper(owner, fake)
|
||||
|
||||
return &oidcTestRunner{
|
||||
AdminClient: owner,
|
||||
AdminUser: admin,
|
||||
API: api,
|
||||
Login: helper.Login,
|
||||
ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) {
|
||||
helper.ForceRefresh(t, api.Database, client, idToken)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue