mirror of https://github.com/coder/coder.git
test: add full OIDC fake IDP (#9317)
* test: implement fake OIDC provider with full functionality * Refactor existing tests
This commit is contained in:
parent
0a213a6ac3
commit
d9d4d74f99
|
@ -31,15 +31,13 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cloud.google.com/go/compute/metadata"
|
"cloud.google.com/go/compute/metadata"
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
|
||||||
"github.com/fullsailor/pkcs7"
|
"github.com/fullsailor/pkcs7"
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/moby/moby/pkg/namesgenerator"
|
"github.com/moby/moby/pkg/namesgenerator"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
"google.golang.org/api/idtoken"
|
"google.golang.org/api/idtoken"
|
||||||
"google.golang.org/api/option"
|
"google.golang.org/api/option"
|
||||||
|
@ -1020,152 +1018,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCConfig struct {
|
|
||||||
key *rsa.PrivateKey
|
|
||||||
issuer string
|
|
||||||
// These are optional
|
|
||||||
refreshToken string
|
|
||||||
oidcTokenExpires func() time.Time
|
|
||||||
tokenSource func() (*oauth2.Token, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
|
|
||||||
return func(cfg *OIDCConfig) {
|
|
||||||
cfg.refreshToken = token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
|
|
||||||
return func(cfg *OIDCConfig) {
|
|
||||||
cfg.oidcTokenExpires = expFunc
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
|
|
||||||
return func(cfg *OIDCConfig) {
|
|
||||||
cfg.tokenSource = src
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
|
||||||
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
if issuer == "" {
|
|
||||||
issuer = "https://coder.com"
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := &OIDCConfig{
|
|
||||||
key: pkey,
|
|
||||||
issuer: issuer,
|
|
||||||
}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(cfg)
|
|
||||||
}
|
|
||||||
return cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
|
||||||
return "/?state=" + url.QueryEscape(state)
|
|
||||||
}
|
|
||||||
|
|
||||||
type tokenSource struct {
|
|
||||||
src func() (*oauth2.Token, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s tokenSource) Token() (*oauth2.Token, error) {
|
|
||||||
return s.src()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
|
||||||
if cfg.tokenSource == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return tokenSource{
|
|
||||||
src: cfg.tokenSource,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
|
||||||
token, err := base64.StdEncoding.DecodeString(code)
|
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("decode code: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var exp time.Time
|
|
||||||
if cfg.oidcTokenExpires != nil {
|
|
||||||
exp = cfg.oidcTokenExpires()
|
|
||||||
}
|
|
||||||
|
|
||||||
return (&oauth2.Token{
|
|
||||||
AccessToken: "token",
|
|
||||||
RefreshToken: cfg.refreshToken,
|
|
||||||
Expiry: exp,
|
|
||||||
}).WithExtra(map[string]interface{}{
|
|
||||||
"id_token": string(token),
|
|
||||||
}), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if _, ok := claims["exp"]; !ok {
|
|
||||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := claims["iss"]; !ok {
|
|
||||||
claims["iss"] = cfg.issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := claims["sub"]; !ok {
|
|
||||||
claims["sub"] = "testme"
|
|
||||||
}
|
|
||||||
|
|
||||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return base64.StdEncoding.EncodeToString([]byte(signed))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
|
||||||
// By default, the provider can be empty.
|
|
||||||
// This means it won't support any endpoints!
|
|
||||||
provider := &oidc.Provider{}
|
|
||||||
if userInfoClaims != nil {
|
|
||||||
resp, err := json.Marshal(userInfoClaims)
|
|
||||||
require.NoError(t, err)
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write(resp)
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
cfg := &oidc.ProviderConfig{
|
|
||||||
UserInfoURL: srv.URL,
|
|
||||||
}
|
|
||||||
provider = cfg.NewProvider(context.Background())
|
|
||||||
}
|
|
||||||
newCFG := &coderd.OIDCConfig{
|
|
||||||
OAuth2Config: cfg,
|
|
||||||
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
|
|
||||||
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
|
|
||||||
}, &oidc.Config{
|
|
||||||
SkipClientIDCheck: true,
|
|
||||||
}),
|
|
||||||
Provider: provider,
|
|
||||||
UsernameField: "preferred_username",
|
|
||||||
EmailField: "email",
|
|
||||||
AuthURLParams: map[string]string{"access_type": "offline"},
|
|
||||||
GroupField: "groups",
|
|
||||||
}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(newCFG)
|
|
||||||
}
|
|
||||||
return newCFG
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
|
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
|
||||||
// instance authentication for Azure.
|
// instance authentication for Azure.
|
||||||
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
|
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
|
||||||
|
@ -1254,22 +1106,6 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
|
||||||
return cerr
|
return cerr
|
||||||
}
|
}
|
||||||
|
|
||||||
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
|
||||||
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
|
|
||||||
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
|
|
||||||
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
|
|
||||||
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
|
|
||||||
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
|
|
||||||
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
|
|
||||||
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
|
|
||||||
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
|
|
||||||
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
|
|
||||||
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
|
|
||||||
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
|
|
||||||
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
|
|
||||||
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
|
|
||||||
-----END RSA PRIVATE KEY-----`
|
|
||||||
|
|
||||||
func DeploymentValues(t testing.TB) *codersdk.DeploymentValues {
|
func DeploymentValues(t testing.TB) *codersdk.DeploymentValues {
|
||||||
var cfg codersdk.DeploymentValues
|
var cfg codersdk.DeploymentValues
|
||||||
opts := cfg.Options()
|
opts := cfg.Options()
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
package oidctest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||||
|
"github.com/coder/coder/v2/coderd/httpmw"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoginHelper helps with logging in a user and refreshing their oauth tokens.
|
||||||
|
// It is mainly because refreshing oauth tokens is a bit tricky and requires
|
||||||
|
// some database manipulation.
|
||||||
|
type LoginHelper struct {
|
||||||
|
fake *FakeIDP
|
||||||
|
client *codersdk.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper {
|
||||||
|
if client == nil {
|
||||||
|
panic("client must not be nil")
|
||||||
|
}
|
||||||
|
if fake == nil {
|
||||||
|
panic("fake must not be nil")
|
||||||
|
}
|
||||||
|
return &LoginHelper{
|
||||||
|
fake: fake,
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login just helps by making an unauthenticated client and logging in with
|
||||||
|
// the given claims. All Logins should be unauthenticated, so this is a
|
||||||
|
// convenience method.
|
||||||
|
func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) {
|
||||||
|
t.Helper()
|
||||||
|
unauthenticatedClient := codersdk.New(h.client.URL)
|
||||||
|
|
||||||
|
return h.fake.Login(t, unauthenticatedClient, idTokenClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpireOauthToken expires the oauth token for the given user.
|
||||||
|
func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
//nolint:gocritic // Testing
|
||||||
|
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||||
|
|
||||||
|
id, _, err := httpmw.SplitAPIToken(user.SessionToken())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// We need to get the OIDC link and update it in the database to force
|
||||||
|
// it to be expired.
|
||||||
|
key, err := db.GetAPIKeyByID(ctx, id)
|
||||||
|
require.NoError(t, err, "get api key")
|
||||||
|
|
||||||
|
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||||
|
UserID: key.UserID,
|
||||||
|
LoginType: database.LoginTypeOIDC,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "get user link")
|
||||||
|
|
||||||
|
// Expire the oauth link for the given user.
|
||||||
|
updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
|
||||||
|
OAuthAccessToken: link.OAuthAccessToken,
|
||||||
|
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||||
|
OAuthExpiry: time.Now().Add(time.Hour * -1),
|
||||||
|
UserID: link.UserID,
|
||||||
|
LoginType: link.LoginType,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "expire user link")
|
||||||
|
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForceRefresh forces the client to refresh its oauth token. It does this by
|
||||||
|
// expiring the oauth token, then doing an authenticated call. This will force
|
||||||
|
// the API Key middleware to refresh the oauth token.
|
||||||
|
//
|
||||||
|
// A unit test assertion makes sure the refresh token is used.
|
||||||
|
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
link := h.ExpireOauthToken(t, db, user)
|
||||||
|
// Updates the claims that the IDP will return. By default, it always
|
||||||
|
// uses the original claims for the original oauth token.
|
||||||
|
h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Do any authenticated call to force the refresh
|
||||||
|
_, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
|
||||||
|
require.NoError(t, err, "user must be able to be fetched")
|
||||||
|
}
|
|
@ -0,0 +1,793 @@
|
||||||
|
package oidctest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/util/syncmap"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/v2/coderd"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FakeIDP is a functional OIDC provider.
|
||||||
|
// It only supports 1 OIDC client.
|
||||||
|
type FakeIDP struct {
|
||||||
|
issuer string
|
||||||
|
key *rsa.PrivateKey
|
||||||
|
provider providerJSON
|
||||||
|
handler http.Handler
|
||||||
|
cfg *oauth2.Config
|
||||||
|
|
||||||
|
// clientID to be used by coderd
|
||||||
|
clientID string
|
||||||
|
clientSecret string
|
||||||
|
logger slog.Logger
|
||||||
|
|
||||||
|
// These maps are used to control the state of the IDP.
|
||||||
|
// That is the various access tokens, refresh tokens, states, etc.
|
||||||
|
codeToStateMap *syncmap.Map[string, string]
|
||||||
|
// Token -> Email
|
||||||
|
accessTokens *syncmap.Map[string, string]
|
||||||
|
// Refresh Token -> Email
|
||||||
|
refreshTokensUsed *syncmap.Map[string, bool]
|
||||||
|
refreshTokens *syncmap.Map[string, string]
|
||||||
|
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||||
|
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||||
|
|
||||||
|
// hooks
|
||||||
|
// hookValidRedirectURL can be used to reject a redirect url from the
|
||||||
|
// IDP -> Application. Almost all IDPs have the concept of
|
||||||
|
// "Authorized Redirect URLs". This can be used to emulate that.
|
||||||
|
hookValidRedirectURL func(redirectURL string) error
|
||||||
|
hookUserInfo func(email string) jwt.MapClaims
|
||||||
|
fakeCoderd func(req *http.Request) (*http.Response, error)
|
||||||
|
hookOnRefresh func(email string) error
|
||||||
|
// Custom authentication for the client. This is useful if you want
|
||||||
|
// to test something like PKI auth vs a client_secret.
|
||||||
|
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
|
||||||
|
serve bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type FakeIDPOpt func(idp *FakeIDP)
|
||||||
|
|
||||||
|
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.hookValidRedirectURL = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRefreshHook is called when a refresh token is used. The email is
|
||||||
|
// the email of the user that is being refreshed assuming the claims are correct.
|
||||||
|
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.hookOnRefresh = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.hookAuthenticateClient = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogging is optional, but will log some HTTP calls made to the IDP.
|
||||||
|
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.logger = slogtest.Make(t, options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStaticUserInfo is optional, but will return the same user info for
|
||||||
|
// every user on the /userinfo endpoint.
|
||||||
|
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.hookUserInfo = func(_ string) jwt.MapClaims {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.hookUserInfo = userInfoFunc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithServing makes the IDP run an actual http server.
|
||||||
|
func WithServing() func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.serve = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithIssuer(issuer string) func(*FakeIDP) {
|
||||||
|
return func(f *FakeIDP) {
|
||||||
|
f.issuer = issuer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// nolint:gosec // It thinks this is a secret lol
|
||||||
|
tokenPath = "/oauth2/token"
|
||||||
|
authorizePath = "/oauth2/authorize"
|
||||||
|
keysPath = "/oauth2/keys"
|
||||||
|
userInfoPath = "/oauth2/userinfo"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||||
|
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
idp := &FakeIDP{
|
||||||
|
key: pkey,
|
||||||
|
clientID: uuid.NewString(),
|
||||||
|
clientSecret: uuid.NewString(),
|
||||||
|
logger: slog.Make(),
|
||||||
|
codeToStateMap: syncmap.New[string, string](),
|
||||||
|
accessTokens: syncmap.New[string, string](),
|
||||||
|
refreshTokens: syncmap.New[string, string](),
|
||||||
|
refreshTokensUsed: syncmap.New[string, bool](),
|
||||||
|
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||||
|
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||||
|
hookOnRefresh: func(_ string) error { return nil },
|
||||||
|
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
|
||||||
|
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(idp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if idp.issuer == "" {
|
||||||
|
idp.issuer = "https://coder.com"
|
||||||
|
}
|
||||||
|
|
||||||
|
idp.handler = idp.httpHandler(t)
|
||||||
|
idp.updateIssuerURL(t, idp.issuer)
|
||||||
|
if idp.serve {
|
||||||
|
idp.realServer(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
return idp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
u, err := url.Parse(issuer)
|
||||||
|
require.NoError(t, err, "invalid issuer URL")
|
||||||
|
|
||||||
|
f.issuer = issuer
|
||||||
|
// providerJSON is the JSON representation of the OpenID Connect provider
|
||||||
|
// These are all the urls that the IDP will respond to.
|
||||||
|
f.provider = providerJSON{
|
||||||
|
Issuer: issuer,
|
||||||
|
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||||
|
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||||
|
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
|
||||||
|
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
|
||||||
|
Algorithms: []string{
|
||||||
|
"RS256",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// realServer turns the FakeIDP into a real http server.
|
||||||
|
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
srv := httptest.NewUnstartedServer(f.handler)
|
||||||
|
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
srv.Start()
|
||||||
|
t.Cleanup(srv.CloseClientConnections)
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
f.updateIssuerURL(t, srv.URL)
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login does the full OIDC flow starting at the "LoginButton".
|
||||||
|
// The client argument is just to get the URL of the Coder instance.
|
||||||
|
//
|
||||||
|
// The client passed in is just to get the url of the Coder instance.
|
||||||
|
// The actual client that is used is 100% unauthenticated and fresh.
|
||||||
|
func (f *FakeIDP) Login(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
client, resp := f.AttemptLogin(t, client, idTokenClaims, opts...)
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode, "client failed to login")
|
||||||
|
return client, resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||||
|
t.Helper()
|
||||||
|
var err error
|
||||||
|
|
||||||
|
cli := f.HTTPClient(client.HTTPClient)
|
||||||
|
shallowCpyCli := *cli
|
||||||
|
|
||||||
|
if shallowCpyCli.Jar == nil {
|
||||||
|
shallowCpyCli.Jar, err = cookiejar.New(nil)
|
||||||
|
require.NoError(t, err, "failed to create cookie jar")
|
||||||
|
}
|
||||||
|
|
||||||
|
unauthenticated := codersdk.New(client.URL)
|
||||||
|
unauthenticated.HTTPClient = &shallowCpyCli
|
||||||
|
|
||||||
|
return f.LoginWithClient(t, unauthenticated, idTokenClaims, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithClient reuses the context of the passed in client. This means the same
|
||||||
|
// cookies will be used. This should be an unauthenticated client in most cases.
|
||||||
|
//
|
||||||
|
// This is a niche case, but it is needed for testing ConvertLoginType.
|
||||||
|
func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
coderOauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback")
|
||||||
|
require.NoError(t, err)
|
||||||
|
f.SetRedirect(t, coderOauthURL.String())
|
||||||
|
|
||||||
|
cli := f.HTTPClient(client.HTTPClient)
|
||||||
|
cli.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||||
|
// Store the idTokenClaims to the specific state request. This ties
|
||||||
|
// the claims 1:1 with a given authentication flow.
|
||||||
|
state := req.URL.Query().Get("state")
|
||||||
|
f.stateToIDTokenClaims.Store(state, idTokenClaims)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if cli.Jar == nil {
|
||||||
|
cli.Jar, err = cookiejar.New(nil)
|
||||||
|
require.NoError(t, err, "failed to create cookie jar")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := cli.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// If the coder session token exists, return the new authed client!
|
||||||
|
var user *codersdk.Client
|
||||||
|
cookies := cli.Jar.Cookies(client.URL)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
if cookie.Name == codersdk.SessionTokenCookie {
|
||||||
|
user = codersdk.New(client.URL)
|
||||||
|
user.SetSessionToken(cookie.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if res.Body != nil {
|
||||||
|
_ = res.Body.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return user, res
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCCallback will emulate the IDP redirecting back to the Coder callback.
|
||||||
|
// This is helpful if no Coderd exists because the IDP needs to redirect to
|
||||||
|
// something.
|
||||||
|
// Essentially this is used to fake the Coderd side of the exchange.
|
||||||
|
// The flow starts at the user hitting the OIDC login page.
|
||||||
|
func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) {
|
||||||
|
t.Helper()
|
||||||
|
if f.serve {
|
||||||
|
panic("cannot use OIDCCallback with WithServing. This is only for the in memory usage")
|
||||||
|
}
|
||||||
|
|
||||||
|
f.stateToIDTokenClaims.Store(state, idTokenClaims)
|
||||||
|
|
||||||
|
cli := f.HTTPClient(nil)
|
||||||
|
u := f.cfg.AuthCodeURL(state)
|
||||||
|
req, err := http.NewRequest("GET", u, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, err := cli.Do(req.WithContext(context.Background()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if resp.Body != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type providerJSON struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
AuthURL string `json:"authorization_endpoint"`
|
||||||
|
TokenURL string `json:"token_endpoint"`
|
||||||
|
JWKSURL string `json:"jwks_uri"`
|
||||||
|
UserInfoURL string `json:"userinfo_endpoint"`
|
||||||
|
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCode enforces the code exchanged is actually a valid code
|
||||||
|
// created by the IDP.
|
||||||
|
func (f *FakeIDP) newCode(state string) string {
|
||||||
|
code := uuid.NewString()
|
||||||
|
f.codeToStateMap.Store(code, state)
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
|
||||||
|
// newToken enforces the access token exchanged is actually a valid access token
|
||||||
|
// created by the IDP.
|
||||||
|
func (f *FakeIDP) newToken(email string) string {
|
||||||
|
accessToken := uuid.NewString()
|
||||||
|
f.accessTokens.Store(accessToken, email)
|
||||||
|
return accessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeIDP) newRefreshTokens(email string) string {
|
||||||
|
refreshToken := uuid.NewString()
|
||||||
|
f.refreshTokens.Store(refreshToken, email)
|
||||||
|
return refreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticateBearerTokenRequest enforces the access token is valid.
|
||||||
|
func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
auth := req.Header.Get("Authorization")
|
||||||
|
token := strings.TrimPrefix(auth, "Bearer ")
|
||||||
|
_, ok := f.accessTokens.Load(token)
|
||||||
|
if !ok {
|
||||||
|
return "", xerrors.New("invalid access token")
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authenticateOIDCClientRequest enforces the client_id and client_secret are valid.
|
||||||
|
func (f *FakeIDP) authenticateOIDCClientRequest(t testing.TB, req *http.Request) (url.Values, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if f.hookAuthenticateClient != nil {
|
||||||
|
return f.hookAuthenticateClient(t, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(req.Body)
|
||||||
|
if !assert.NoError(t, err, "read token request body") {
|
||||||
|
return nil, xerrors.Errorf("authenticate request, read body: %w", err)
|
||||||
|
}
|
||||||
|
values, err := url.ParseQuery(string(data))
|
||||||
|
if !assert.NoError(t, err, "parse token request values") {
|
||||||
|
return nil, xerrors.New("invalid token request")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.Equal(t, f.clientID, values.Get("client_id"), "client_id mismatch") {
|
||||||
|
return nil, xerrors.New("client_id mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.Equal(t, f.clientSecret, values.Get("client_secret"), "client_secret mismatch") {
|
||||||
|
return nil, xerrors.New("client_secret mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
return values, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeClaims is a helper func to convert claims to a valid JWT.
|
||||||
|
func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if _, ok := claims["exp"]; !ok {
|
||||||
|
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := claims["aud"]; !ok {
|
||||||
|
claims["aud"] = f.clientID
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := claims["iss"]; !ok {
|
||||||
|
claims["iss"] = f.issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return signed
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpHandler is the IDP http server.
|
||||||
|
func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
mux := chi.NewMux()
|
||||||
|
// This endpoint is required to initialize the OIDC provider.
|
||||||
|
// It is used to get the OIDC configuration.
|
||||||
|
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
f.logger.Info(r.Context(), "http OIDC config", slog.F("url", r.URL.String()))
|
||||||
|
|
||||||
|
_ = json.NewEncoder(rw).Encode(f.provider)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Authorize is called when the user is redirected to the IDP to login.
|
||||||
|
// This is the browser hitting the IDP and the user logging into Google or
|
||||||
|
// w/e and clicking "Allow". They will be redirected back to the redirect
|
||||||
|
// when this is done.
|
||||||
|
mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
f.logger.Info(r.Context(), "http call authorize", slog.F("url", r.URL.String()))
|
||||||
|
|
||||||
|
clientID := r.URL.Query().Get("client_id")
|
||||||
|
if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") {
|
||||||
|
http.Error(rw, "invalid client_id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := r.URL.Query().Get("redirect_uri")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
|
scope := r.URL.Query().Get("scope")
|
||||||
|
assert.NotEmpty(t, scope, "scope is empty")
|
||||||
|
|
||||||
|
responseType := r.URL.Query().Get("response_type")
|
||||||
|
switch responseType {
|
||||||
|
case "code":
|
||||||
|
case "token":
|
||||||
|
t.Errorf("response_type %q not supported", responseType)
|
||||||
|
http.Error(rw, "invalid response_type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected response_type %q", responseType)
|
||||||
|
http.Error(rw, "invalid response_type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := f.hookValidRedirectURL(redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
|
||||||
|
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ru, err := url.Parse(redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error())
|
||||||
|
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
q := ru.Query()
|
||||||
|
q.Set("state", state)
|
||||||
|
q.Set("code", f.newCode(state))
|
||||||
|
ru.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect)
|
||||||
|
}))
|
||||||
|
|
||||||
|
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
values, err := f.authenticateOIDCClientRequest(t, r)
|
||||||
|
f.logger.Info(r.Context(), "http idp call token",
|
||||||
|
slog.Error(err),
|
||||||
|
slog.F("values", values.Encode()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
getEmail := func(claims jwt.MapClaims) string {
|
||||||
|
email, ok := claims["email"]
|
||||||
|
if !ok {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
emailStr, ok := email.(string)
|
||||||
|
if !ok {
|
||||||
|
return "wrong-type"
|
||||||
|
}
|
||||||
|
return emailStr
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims jwt.MapClaims
|
||||||
|
switch values.Get("grant_type") {
|
||||||
|
case "authorization_code":
|
||||||
|
code := values.Get("code")
|
||||||
|
if !assert.NotEmpty(t, code, "code is empty") {
|
||||||
|
http.Error(rw, "invalid code", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stateStr, ok := f.codeToStateMap.Load(code)
|
||||||
|
if !assert.True(t, ok, "invalid code") {
|
||||||
|
http.Error(rw, "invalid code", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Always invalidate the code after it is used.
|
||||||
|
f.codeToStateMap.Delete(code)
|
||||||
|
|
||||||
|
idTokenClaims, ok := f.stateToIDTokenClaims.Load(stateStr)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("missing id token claims")
|
||||||
|
http.Error(rw, "missing id token claims", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
claims = idTokenClaims
|
||||||
|
case "refresh_token":
|
||||||
|
refreshToken := values.Get("refresh_token")
|
||||||
|
if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") {
|
||||||
|
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := f.refreshTokens.Load(refreshToken)
|
||||||
|
if !assert.True(t, ok, "invalid refresh_token") {
|
||||||
|
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idTokenClaims, ok := f.refreshIDTokenClaims.Load(refreshToken)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("missing id token claims in refresh")
|
||||||
|
http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims = idTokenClaims
|
||||||
|
err := f.hookOnRefresh(getEmail(claims))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.refreshTokensUsed.Store(refreshToken, true)
|
||||||
|
// Always invalidate the refresh token after it is used.
|
||||||
|
f.refreshTokens.Delete(refreshToken)
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
|
||||||
|
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exp := time.Now().Add(time.Minute * 5)
|
||||||
|
claims["exp"] = exp.UnixMilli()
|
||||||
|
email := getEmail(claims)
|
||||||
|
refreshToken := f.newRefreshTokens(email)
|
||||||
|
token := map[string]interface{}{
|
||||||
|
"access_token": f.newToken(email),
|
||||||
|
"refresh_token": refreshToken,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": int64((time.Minute * 5).Seconds()),
|
||||||
|
"id_token": f.encodeClaims(t, claims),
|
||||||
|
}
|
||||||
|
// Store the claims for the next refresh
|
||||||
|
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||||
|
|
||||||
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(rw).Encode(token)
|
||||||
|
}))
|
||||||
|
|
||||||
|
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
token, err := f.authenticateBearerTokenRequest(t, r)
|
||||||
|
f.logger.Info(r.Context(), "http call idp user info",
|
||||||
|
slog.Error(err),
|
||||||
|
slog.F("url", r.URL.String()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
email, ok := f.accessTokens.Load(token)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("access token user for user_info has no email to indicate which user")
|
||||||
|
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
|
||||||
|
}))
|
||||||
|
|
||||||
|
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
f.logger.Info(r.Context(), "http call idp /keys")
|
||||||
|
set := jose.JSONWebKeySet{
|
||||||
|
Keys: []jose.JSONWebKey{
|
||||||
|
{
|
||||||
|
Key: f.key.Public(),
|
||||||
|
KeyID: "test-key",
|
||||||
|
Algorithm: "RSA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(rw).Encode(set)
|
||||||
|
}))
|
||||||
|
|
||||||
|
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
|
||||||
|
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
|
||||||
|
})
|
||||||
|
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient does nothing if IsServing is used.
|
||||||
|
//
|
||||||
|
// If IsServing is not used, then it will return a client that will make requests
|
||||||
|
// to the IDP all in memory. If a request is not to the IDP, then the passed in
|
||||||
|
// client will be used. If no client is passed in, then any regular network
|
||||||
|
// requests will fail.
|
||||||
|
func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
|
||||||
|
if f.serve {
|
||||||
|
if rest == nil || rest.Transport == nil {
|
||||||
|
return &http.Client{}
|
||||||
|
}
|
||||||
|
return rest
|
||||||
|
}
|
||||||
|
|
||||||
|
var jar http.CookieJar
|
||||||
|
if rest != nil {
|
||||||
|
jar = rest.Jar
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Jar: jar,
|
||||||
|
Transport: fakeRoundTripper{
|
||||||
|
roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||||
|
u, _ := url.Parse(f.issuer)
|
||||||
|
if req.URL.Host != u.Host {
|
||||||
|
if f.fakeCoderd != nil {
|
||||||
|
return f.fakeCoderd(req)
|
||||||
|
}
|
||||||
|
if rest == nil || rest.Transport == nil {
|
||||||
|
return nil, fmt.Errorf("unexpected network request to %q", req.URL.Host)
|
||||||
|
}
|
||||||
|
return rest.Transport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
f.handler.ServeHTTP(resp, req)
|
||||||
|
return resp.Result(), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshUsed returns if the refresh token has been used. All refresh tokens
|
||||||
|
// can only be used once, then they are deleted.
|
||||||
|
func (f *FakeIDP) RefreshUsed(refreshToken string) bool {
|
||||||
|
used, _ := f.refreshTokensUsed.Load(refreshToken)
|
||||||
|
return used
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRefreshClaims allows the caller to change what claims are returned
|
||||||
|
// for a given refresh token. By default, all refreshes use the same claims as
|
||||||
|
// the original IDToken issuance.
|
||||||
|
func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) {
|
||||||
|
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRedirect is required for the IDP to know where to redirect and call
|
||||||
|
// Coderd.
|
||||||
|
func (f *FakeIDP) SetRedirect(t testing.TB, u string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f.cfg.RedirectURL = u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCoderdCallback is optional and only works if not using the IsServing.
|
||||||
|
// It will setup a fake "Coderd" for the IDP to call when the IDP redirects
|
||||||
|
// back after authenticating.
|
||||||
|
func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) {
|
||||||
|
if f.serve {
|
||||||
|
panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'")
|
||||||
|
}
|
||||||
|
f.fakeCoderd = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) {
|
||||||
|
f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(resp, req)
|
||||||
|
return resp.Result(), nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCConfig returns the OIDC config to use for Coderd.
|
||||||
|
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||||
|
t.Helper()
|
||||||
|
if len(scopes) == 0 {
|
||||||
|
scopes = []string{"openid", "email", "profile"}
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthCfg := &oauth2.Config{
|
||||||
|
ClientID: f.clientID,
|
||||||
|
ClientSecret: f.clientSecret,
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: f.provider.AuthURL,
|
||||||
|
TokenURL: f.provider.TokenURL,
|
||||||
|
AuthStyle: oauth2.AuthStyleInParams,
|
||||||
|
},
|
||||||
|
// If the user is using a real network request, they will need to do
|
||||||
|
// 'fake.SetRedirect()'
|
||||||
|
RedirectURL: "https://redirect.com",
|
||||||
|
Scopes: scopes,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil))
|
||||||
|
p, err := oidc.NewProvider(ctx, f.provider.Issuer)
|
||||||
|
require.NoError(t, err, "failed to create OIDC provider")
|
||||||
|
cfg := &coderd.OIDCConfig{
|
||||||
|
OAuth2Config: oauthCfg,
|
||||||
|
Provider: p,
|
||||||
|
Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{
|
||||||
|
PublicKeys: []crypto.PublicKey{f.key.Public()},
|
||||||
|
}, &oidc.Config{
|
||||||
|
ClientID: oauthCfg.ClientID,
|
||||||
|
SupportedSigningAlgs: []string{
|
||||||
|
"RS256",
|
||||||
|
},
|
||||||
|
// Todo: add support for Now()
|
||||||
|
}),
|
||||||
|
UsernameField: "preferred_username",
|
||||||
|
EmailField: "email",
|
||||||
|
AuthURLParams: map[string]string{"access_type": "offline"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.cfg = oauthCfg
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeRoundTripper struct {
|
||||||
|
roundTrip func(req *http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f.roundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||||
|
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
|
||||||
|
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
|
||||||
|
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
|
||||||
|
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
|
||||||
|
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
|
||||||
|
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
|
||||||
|
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
|
||||||
|
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
|
||||||
|
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
|
||||||
|
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
|
||||||
|
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
|
||||||
|
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
|
||||||
|
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
|
||||||
|
-----END RSA PRIVATE KEY-----`
|
|
@ -0,0 +1,72 @@
|
||||||
|
package oidctest_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestFakeIDPBasicFlow tests the basic flow of the fake IDP.
|
||||||
|
// It is done all in memory with no actual network requests.
|
||||||
|
// nolint:bodyclose
|
||||||
|
func TestFakeIDPBasicFlow(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
fake := oidctest.NewFakeIDP(t,
|
||||||
|
oidctest.WithLogging(t, nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
var handler http.Handler
|
||||||
|
srv := httptest.NewServer(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
})))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
cfg := fake.OIDCConfig(t, nil)
|
||||||
|
cli := fake.HTTPClient(nil)
|
||||||
|
ctx := oidc.ClientContext(context.Background(), cli)
|
||||||
|
|
||||||
|
const expectedState = "random-state"
|
||||||
|
var token *oauth2.Token
|
||||||
|
// This is the Coder callback using an actual network request.
|
||||||
|
fake.SetCoderdCallbackHandler(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Emulate OIDC flow
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
assert.Equal(t, expectedState, state, "state mismatch")
|
||||||
|
|
||||||
|
oauthToken, err := cfg.Exchange(ctx, code)
|
||||||
|
if assert.NoError(t, err, "failed to exchange code") {
|
||||||
|
assert.NotEmpty(t, oauthToken.AccessToken, "access token is empty")
|
||||||
|
assert.NotEmpty(t, oauthToken.RefreshToken, "refresh token is empty")
|
||||||
|
}
|
||||||
|
token = oauthToken
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := fake.OIDCCallback(t, expectedState, jwt.MapClaims{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// Test the user info
|
||||||
|
_, err = cfg.Provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Now test it can refresh
|
||||||
|
refreshed, err := cfg.TokenSource(ctx, &oauth2.Token{
|
||||||
|
AccessToken: token.AccessToken,
|
||||||
|
RefreshToken: token.RefreshToken,
|
||||||
|
Expiry: time.Now().Add(time.Minute * -1),
|
||||||
|
}).Token()
|
||||||
|
require.NoError(t, err, "failed to refresh token")
|
||||||
|
require.NotEmpty(t, refreshed.AccessToken, "access token is empty on refresh")
|
||||||
|
}
|
|
@ -215,7 +215,10 @@ func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenRes struct {
|
var tokenRes struct {
|
||||||
oauth2.Token
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type,omitempty"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
|
||||||
// Extra fields returned by the refresh that are needed
|
// Extra fields returned by the refresh that are needed
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
||||||
|
|
|
@ -12,12 +12,15 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd"
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||||
"github.com/coder/coder/v2/coderd/oauthpki"
|
"github.com/coder/coder/v2/coderd/oauthpki"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
@ -123,6 +126,58 @@ func TestAzureADPKIOIDC(t *testing.T) {
|
||||||
require.Error(t, err, "error expected")
|
require.Error(t, err, "error expected")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAzureAKPKIWithCoderd uses a fake IDP and a real Coderd to test PKI auth.
|
||||||
|
// nolint:bodyclose
|
||||||
|
func TestAzureAKPKIWithCoderd(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
scopes := []string{"openid", "email", "profile", "offline_access"}
|
||||||
|
fake := oidctest.NewFakeIDP(t,
|
||||||
|
oidctest.WithIssuer("https://login.microsoftonline.com/fake_app"),
|
||||||
|
oidctest.WithCustomClientAuth(func(t testing.TB, req *http.Request) (url.Values, error) {
|
||||||
|
values := assertJWTAuth(t, req)
|
||||||
|
if values == nil {
|
||||||
|
return nil, xerrors.New("authorizatin failed in request")
|
||||||
|
}
|
||||||
|
return values, nil
|
||||||
|
}),
|
||||||
|
oidctest.WithServing(),
|
||||||
|
)
|
||||||
|
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
|
})
|
||||||
|
|
||||||
|
oauthCfg := cfg.OAuth2Config.(*oauth2.Config)
|
||||||
|
// Create the oauthpki config
|
||||||
|
pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
|
||||||
|
ClientID: oauthCfg.ClientID,
|
||||||
|
TokenURL: oauthCfg.Endpoint.TokenURL,
|
||||||
|
Scopes: scopes,
|
||||||
|
PemEncodedKey: []byte(testClientKey),
|
||||||
|
PemEncodedCert: []byte(testClientCert),
|
||||||
|
Config: oauthCfg,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
cfg.OAuth2Config = pki
|
||||||
|
|
||||||
|
owner, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||||
|
OIDCConfig: cfg,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a user and login
|
||||||
|
const email = "alice@coder.com"
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"email": email,
|
||||||
|
}
|
||||||
|
helper := oidctest.NewLoginHelper(owner, fake)
|
||||||
|
user, _ := helper.Login(t, claims)
|
||||||
|
|
||||||
|
// Try refreshing the token more than once.
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
helper.ForceRefresh(t, api.Database, user, claims)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
|
// TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
|
||||||
// AD instance and saving them to replay, removing some details.
|
// AD instance and saving them to replay, removing some details.
|
||||||
// The reason this is done is that this is the only way to assert values
|
// The reason this is done is that this is the only way to assert values
|
||||||
|
@ -269,7 +324,7 @@ func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
|
||||||
// assertJWTAuth will assert the basic JWT auth assertions. It will return the
|
// assertJWTAuth will assert the basic JWT auth assertions. It will return the
|
||||||
// url.Values from the request body for any additional assertions to be made.
|
// url.Values from the request body for any additional assertions to be made.
|
||||||
func assertJWTAuth(t *testing.T, r *http.Request) url.Values {
|
func assertJWTAuth(t testing.TB, r *http.Request) url.Values {
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if !assert.NoError(t, err) {
|
if !assert.NoError(t, err) {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -4,28 +4,25 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/google/go-github/v43/github"
|
"github.com/google/go-github/v43/github"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
"github.com/coder/coder/v2/coderd"
|
"github.com/coder/coder/v2/coderd"
|
||||||
"github.com/coder/coder/v2/coderd/audit"
|
"github.com/coder/coder/v2/coderd/audit"
|
||||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
||||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
@ -35,85 +32,42 @@ import (
|
||||||
// This test specifically tests logging in with OIDC when an expired
|
// This test specifically tests logging in with OIDC when an expired
|
||||||
// OIDC session token exists.
|
// OIDC session token exists.
|
||||||
// The token refreshing should not happen since we are reauthenticating.
|
// The token refreshing should not happen since we are reauthenticating.
|
||||||
|
// nolint:bodyclose
|
||||||
func TestOIDCOauthLoginWithExisting(t *testing.T) {
|
func TestOIDCOauthLoginWithExisting(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conf := coderdtest.NewOIDCConfig(t, "",
|
fake := oidctest.NewFakeIDP(t,
|
||||||
// Provide a refresh token so we use the refresh token flow
|
oidctest.WithRefreshHook(func(_ string) error {
|
||||||
coderdtest.WithRefreshToken("refresh_token"),
|
return xerrors.New("refreshing token should never occur")
|
||||||
// We need to set the expire in the future for the first api calls.
|
|
||||||
coderdtest.WithTokenExpires(func() time.Time {
|
|
||||||
return time.Now().Add(time.Hour).UTC()
|
|
||||||
}),
|
|
||||||
// No refresh should actually happen in this test.
|
|
||||||
coderdtest.WithTokenSource(func() (*oauth2.Token, error) {
|
|
||||||
return nil, xerrors.New("token should not require refresh")
|
|
||||||
}),
|
}),
|
||||||
|
oidctest.WithServing(),
|
||||||
)
|
)
|
||||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
||||||
auditor := audit.NewMock()
|
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
|
cfg.IgnoreUserInfo = true
|
||||||
|
})
|
||||||
|
|
||||||
|
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||||
|
OIDCConfig: cfg,
|
||||||
|
})
|
||||||
|
|
||||||
const username = "alice"
|
const username = "alice"
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
"email": "alice@coder.com",
|
"email": "alice@coder.com",
|
||||||
"email_verified": true,
|
"email_verified": true,
|
||||||
"preferred_username": username,
|
"preferred_username": username,
|
||||||
}
|
}
|
||||||
config := conf.OIDCConfig(t, claims)
|
|
||||||
|
|
||||||
config.AllowSignups = true
|
|
||||||
config.IgnoreUserInfo = true
|
|
||||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
|
||||||
Auditor: auditor,
|
|
||||||
OIDCConfig: config,
|
|
||||||
Logger: &logger,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
helper := oidctest.NewLoginHelper(client, fake)
|
||||||
// Signup alice
|
// Signup alice
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, claims))
|
userClient, _ := helper.Login(t, claims)
|
||||||
// Set the client to use this OIDC context
|
|
||||||
authCookie := authCookieValue(resp.Cookies())
|
|
||||||
client.SetSessionToken(authCookie)
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitLong)
|
// Expire the link. This will force the client to refresh the token.
|
||||||
// Verify the user and oauth link
|
helper.ExpireOauthToken(t, api.Database, userClient)
|
||||||
user, err := client.User(ctx, "me")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, username, user.Username)
|
|
||||||
|
|
||||||
// nolint:gocritic
|
// Instead of refreshing, just log in again.
|
||||||
link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
helper.Login(t, claims)
|
||||||
UserID: user.ID,
|
|
||||||
LoginType: database.LoginType(user.LoginType),
|
|
||||||
})
|
|
||||||
require.NoError(t, err, "failed to get user link")
|
|
||||||
|
|
||||||
// Expire the link
|
|
||||||
// nolint:gocritic
|
|
||||||
_, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
|
||||||
OAuthAccessToken: link.OAuthAccessToken,
|
|
||||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
|
||||||
OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(),
|
|
||||||
UserID: link.UserID,
|
|
||||||
LoginType: link.LoginType,
|
|
||||||
})
|
|
||||||
require.NoError(t, err, "failed to update user link")
|
|
||||||
|
|
||||||
// Log in again with OIDC
|
|
||||||
loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) {
|
|
||||||
req.AddCookie(&http.Cookie{
|
|
||||||
Name: codersdk.SessionTokenCookie,
|
|
||||||
Value: authCookie,
|
|
||||||
Path: "/",
|
|
||||||
})
|
|
||||||
})
|
|
||||||
require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode)
|
|
||||||
_ = loginAgain.Body.Close()
|
|
||||||
|
|
||||||
// Try to use new login
|
|
||||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
|
||||||
_, err = client.User(ctx, "me")
|
|
||||||
require.NoError(t, err, "use new session")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserLogin(t *testing.T) {
|
func TestUserLogin(t *testing.T) {
|
||||||
|
@ -660,7 +614,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
"email": "kyle@kwc.io",
|
"email": "kyle@kwc.io",
|
||||||
},
|
},
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
Username: "kyle",
|
Username: "kyle",
|
||||||
}, {
|
}, {
|
||||||
Name: "EmailNotVerified",
|
Name: "EmailNotVerified",
|
||||||
|
@ -685,7 +639,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
"email_verified": false,
|
"email_verified": false,
|
||||||
},
|
},
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
Username: "kyle",
|
Username: "kyle",
|
||||||
IgnoreEmailVerified: true,
|
IgnoreEmailVerified: true,
|
||||||
}, {
|
}, {
|
||||||
|
@ -709,7 +663,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
EmailDomain: []string{
|
EmailDomain: []string{
|
||||||
"kwc.io",
|
"kwc.io",
|
||||||
},
|
},
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
Name: "EmptyClaims",
|
Name: "EmptyClaims",
|
||||||
IDTokenClaims: jwt.MapClaims{},
|
IDTokenClaims: jwt.MapClaims{},
|
||||||
|
@ -730,7 +684,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
},
|
},
|
||||||
Username: "kyle",
|
Username: "kyle",
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
Name: "UsernameFromClaims",
|
Name: "UsernameFromClaims",
|
||||||
IDTokenClaims: jwt.MapClaims{
|
IDTokenClaims: jwt.MapClaims{
|
||||||
|
@ -740,7 +694,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
},
|
},
|
||||||
Username: "hotdog",
|
Username: "hotdog",
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
// Services like Okta return the email as the username:
|
// Services like Okta return the email as the username:
|
||||||
// https://developer.okta.com/docs/reference/api/oidc/#base-claims-always-present
|
// https://developer.okta.com/docs/reference/api/oidc/#base-claims-always-present
|
||||||
|
@ -752,7 +706,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
},
|
},
|
||||||
Username: "kyle",
|
Username: "kyle",
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
// See: https://github.com/coder/coder/issues/4472
|
// See: https://github.com/coder/coder/issues/4472
|
||||||
Name: "UsernameIsEmail",
|
Name: "UsernameIsEmail",
|
||||||
|
@ -761,7 +715,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
},
|
},
|
||||||
Username: "kyle",
|
Username: "kyle",
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
Name: "WithPicture",
|
Name: "WithPicture",
|
||||||
IDTokenClaims: jwt.MapClaims{
|
IDTokenClaims: jwt.MapClaims{
|
||||||
|
@ -773,7 +727,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
Username: "kyle",
|
Username: "kyle",
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
AvatarURL: "/example.png",
|
AvatarURL: "/example.png",
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
Name: "WithUserInfoClaims",
|
Name: "WithUserInfoClaims",
|
||||||
IDTokenClaims: jwt.MapClaims{
|
IDTokenClaims: jwt.MapClaims{
|
||||||
|
@ -787,7 +741,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
Username: "potato",
|
Username: "potato",
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
AvatarURL: "/example.png",
|
AvatarURL: "/example.png",
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
Name: "GroupsDoesNothing",
|
Name: "GroupsDoesNothing",
|
||||||
IDTokenClaims: jwt.MapClaims{
|
IDTokenClaims: jwt.MapClaims{
|
||||||
|
@ -795,7 +749,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
"groups": []string{"pingpong"},
|
"groups": []string{"pingpong"},
|
||||||
},
|
},
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
Name: "UserInfoOverridesIDTokenClaims",
|
Name: "UserInfoOverridesIDTokenClaims",
|
||||||
IDTokenClaims: jwt.MapClaims{
|
IDTokenClaims: jwt.MapClaims{
|
||||||
|
@ -810,7 +764,7 @@ func TestUserOIDC(t *testing.T) {
|
||||||
Username: "user",
|
Username: "user",
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
IgnoreEmailVerified: false,
|
IgnoreEmailVerified: false,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}, {
|
}, {
|
||||||
Name: "InvalidUserInfo",
|
Name: "InvalidUserInfo",
|
||||||
IDTokenClaims: jwt.MapClaims{
|
IDTokenClaims: jwt.MapClaims{
|
||||||
|
@ -837,36 +791,41 @@ func TestUserOIDC(t *testing.T) {
|
||||||
Username: "user",
|
Username: "user",
|
||||||
IgnoreUserInfo: true,
|
IgnoreUserInfo: true,
|
||||||
AllowSignups: true,
|
AllowSignups: true,
|
||||||
StatusCode: http.StatusTemporaryRedirect,
|
StatusCode: http.StatusOK,
|
||||||
}} {
|
}} {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
fake := oidctest.NewFakeIDP(t,
|
||||||
|
oidctest.WithRefreshHook(func(_ string) error {
|
||||||
|
return xerrors.New("refreshing token should never occur")
|
||||||
|
}),
|
||||||
|
oidctest.WithServing(),
|
||||||
|
oidctest.WithStaticUserInfo(tc.UserInfoClaims),
|
||||||
|
)
|
||||||
|
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = tc.AllowSignups
|
||||||
|
cfg.EmailDomain = tc.EmailDomain
|
||||||
|
cfg.IgnoreEmailVerified = tc.IgnoreEmailVerified
|
||||||
|
cfg.IgnoreUserInfo = tc.IgnoreUserInfo
|
||||||
|
})
|
||||||
|
|
||||||
auditor := audit.NewMock()
|
auditor := audit.NewMock()
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
|
||||||
|
|
||||||
config := conf.OIDCConfig(t, tc.UserInfoClaims)
|
|
||||||
config.AllowSignups = tc.AllowSignups
|
|
||||||
config.EmailDomain = tc.EmailDomain
|
|
||||||
config.IgnoreEmailVerified = tc.IgnoreEmailVerified
|
|
||||||
config.IgnoreUserInfo = tc.IgnoreUserInfo
|
|
||||||
|
|
||||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
owner := coderdtest.New(t, &coderdtest.Options{
|
||||||
Auditor: auditor,
|
Auditor: auditor,
|
||||||
OIDCConfig: config,
|
OIDCConfig: cfg,
|
||||||
Logger: &logger,
|
Logger: &logger,
|
||||||
})
|
})
|
||||||
numLogs := len(auditor.AuditLogs())
|
numLogs := len(auditor.AuditLogs())
|
||||||
|
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.IDTokenClaims))
|
client, resp := fake.AttemptLogin(t, owner, tc.IDTokenClaims)
|
||||||
numLogs++ // add an audit log for login
|
numLogs++ // add an audit log for login
|
||||||
assert.Equal(t, tc.StatusCode, resp.StatusCode)
|
require.Equal(t, tc.StatusCode, resp.StatusCode)
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitLong)
|
ctx := testutil.Context(t, testutil.WaitLong)
|
||||||
|
|
||||||
if tc.Username != "" {
|
if tc.Username != "" {
|
||||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
|
||||||
user, err := client.User(ctx, "me")
|
user, err := client.User(ctx, "me")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tc.Username, user.Username)
|
require.Equal(t, tc.Username, user.Username)
|
||||||
|
@ -877,7 +836,6 @@ func TestUserOIDC(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if tc.AvatarURL != "" {
|
if tc.AvatarURL != "" {
|
||||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
|
||||||
user, err := client.User(ctx, "me")
|
user, err := client.User(ctx, "me")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tc.AvatarURL, user.AvatarURL)
|
require.Equal(t, tc.AvatarURL, user.AvatarURL)
|
||||||
|
@ -890,26 +848,29 @@ func TestUserOIDC(t *testing.T) {
|
||||||
|
|
||||||
t.Run("OIDCConvert", func(t *testing.T) {
|
t.Run("OIDCConvert", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
auditor := audit.NewMock()
|
auditor := audit.NewMock()
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
fake := oidctest.NewFakeIDP(t,
|
||||||
|
oidctest.WithRefreshHook(func(_ string) error {
|
||||||
config := conf.OIDCConfig(t, nil)
|
return xerrors.New("refreshing token should never occur")
|
||||||
config.AllowSignups = true
|
}),
|
||||||
|
oidctest.WithServing(),
|
||||||
cfg := coderdtest.DeploymentValues(t)
|
)
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||||
Auditor: auditor,
|
cfg.AllowSignups = true
|
||||||
OIDCConfig: config,
|
|
||||||
DeploymentValues: cfg,
|
|
||||||
})
|
})
|
||||||
owner := coderdtest.CreateFirstUser(t, client)
|
|
||||||
|
|
||||||
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
|
Auditor: auditor,
|
||||||
|
OIDCConfig: cfg,
|
||||||
|
})
|
||||||
|
|
||||||
|
owner := coderdtest.CreateFirstUser(t, client)
|
||||||
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||||
|
|
||||||
code := conf.EncodeClaims(t, jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
"email": userData.Email,
|
"email": userData.Email,
|
||||||
})
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
user.HTTPClient.Jar, err = cookiejar.New(nil)
|
user.HTTPClient.Jar, err = cookiejar.New(nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -921,52 +882,58 @@ func TestUserOIDC(t *testing.T) {
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil)
|
fake.LoginWithClient(t, user, claims, func(r *http.Request) {
|
||||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
r.URL.RawQuery = url.Values{
|
||||||
|
"oidc_merge_state": {convertResponse.StateString},
|
||||||
|
}.Encode()
|
||||||
|
r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken())
|
||||||
|
cookies := user.HTTPClient.Jar.Cookies(r.URL)
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
r.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AlternateUsername", func(t *testing.T) {
|
t.Run("AlternateUsername", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
auditor := audit.NewMock()
|
auditor := audit.NewMock()
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
fake := oidctest.NewFakeIDP(t,
|
||||||
|
oidctest.WithRefreshHook(func(_ string) error {
|
||||||
config := conf.OIDCConfig(t, nil)
|
return xerrors.New("refreshing token should never occur")
|
||||||
config.AllowSignups = true
|
}),
|
||||||
|
oidctest.WithServing(),
|
||||||
|
)
|
||||||
|
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
|
})
|
||||||
|
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
Auditor: auditor,
|
Auditor: auditor,
|
||||||
OIDCConfig: config,
|
OIDCConfig: cfg,
|
||||||
})
|
})
|
||||||
numLogs := len(auditor.AuditLogs())
|
|
||||||
|
|
||||||
code := conf.EncodeClaims(t, jwt.MapClaims{
|
numLogs := len(auditor.AuditLogs())
|
||||||
|
claims := jwt.MapClaims{
|
||||||
"email": "jon@coder.com",
|
"email": "jon@coder.com",
|
||||||
})
|
}
|
||||||
resp := oidcCallback(t, client, code)
|
|
||||||
|
userClient, _ := fake.Login(t, client, claims)
|
||||||
numLogs++ // add an audit log for login
|
numLogs++ // add an audit log for login
|
||||||
|
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitLong)
|
ctx := testutil.Context(t, testutil.WaitLong)
|
||||||
|
user, err := userClient.User(ctx, "me")
|
||||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
|
||||||
user, err := client.User(ctx, "me")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "jon", user.Username)
|
require.Equal(t, "jon", user.Username)
|
||||||
|
|
||||||
// Pass a different subject field so that we prompt creating a
|
// Pass a different subject field so that we prompt creating a
|
||||||
// new user.
|
// new user
|
||||||
code = conf.EncodeClaims(t, jwt.MapClaims{
|
userClient, _ = fake.Login(t, client, jwt.MapClaims{
|
||||||
"email": "jon@example2.com",
|
"email": "jon@example2.com",
|
||||||
"sub": "diff",
|
"sub": "diff",
|
||||||
})
|
})
|
||||||
resp = oidcCallback(t, client, code)
|
|
||||||
numLogs++ // add an audit log for login
|
numLogs++ // add an audit log for login
|
||||||
|
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
user, err = userClient.User(ctx, "me")
|
||||||
|
|
||||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
|
||||||
user, err = client.User(ctx, "me")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")
|
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")
|
||||||
|
|
||||||
|
@ -977,45 +944,62 @@ func TestUserOIDC(t *testing.T) {
|
||||||
t.Run("Disabled", func(t *testing.T) {
|
t.Run("Disabled", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
client := coderdtest.New(t, nil)
|
client := coderdtest.New(t, nil)
|
||||||
resp := oidcCallback(t, client, "asdf")
|
oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp, err := client.HTTPClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("NoIDToken", func(t *testing.T) {
|
t.Run("NoIDToken", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
fake := oidctest.NewFakeIDP(t,
|
||||||
OIDCConfig: &coderd.OIDCConfig{
|
oidctest.WithRefreshHook(func(_ string) error {
|
||||||
OAuth2Config: &testutil.OAuth2Config{},
|
return xerrors.New("refreshing token should never occur")
|
||||||
},
|
}),
|
||||||
|
oidctest.WithServing(),
|
||||||
|
)
|
||||||
|
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := oidcCallback(t, client, "asdf")
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
|
OIDCConfig: cfg,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, resp := fake.AttemptLogin(t, client, jwt.MapClaims{})
|
||||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("BadVerify", func(t *testing.T) {
|
t.Run("BadVerify", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
verifier := oidc.NewVerifier("", &oidc.StaticKeySet{
|
badVerifier := oidc.NewVerifier("", &oidc.StaticKeySet{
|
||||||
PublicKeys: []crypto.PublicKey{},
|
PublicKeys: []crypto.PublicKey{},
|
||||||
}, &oidc.Config{})
|
}, &oidc.Config{})
|
||||||
provider := &oidc.Provider{}
|
badProvider := &oidc.Provider{}
|
||||||
|
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
fake := oidctest.NewFakeIDP(t,
|
||||||
OIDCConfig: &coderd.OIDCConfig{
|
oidctest.WithRefreshHook(func(_ string) error {
|
||||||
OAuth2Config: &testutil.OAuth2Config{
|
return xerrors.New("refreshing token should never occur")
|
||||||
Token: (&oauth2.Token{
|
}),
|
||||||
AccessToken: "token",
|
oidctest.WithServing(),
|
||||||
}).WithExtra(map[string]interface{}{
|
)
|
||||||
"id_token": "invalid",
|
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||||
}),
|
cfg.AllowSignups = true
|
||||||
},
|
cfg.Provider = badProvider
|
||||||
Provider: provider,
|
cfg.Verifier = badVerifier
|
||||||
Verifier: verifier,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := oidcCallback(t, client, "asdf")
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
|
OIDCConfig: cfg,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, resp := fake.AttemptLogin(t, client, jwt.MapClaims{})
|
||||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1146,36 +1130,6 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
|
|
||||||
return oidcCallbackWithState(t, client, code, "somestate", nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse
|
|
||||||
}
|
|
||||||
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=%s", code, state))
|
|
||||||
require.NoError(t, err)
|
|
||||||
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
req.AddCookie(&http.Cookie{
|
|
||||||
Name: codersdk.OAuth2StateCookie,
|
|
||||||
Value: state,
|
|
||||||
})
|
|
||||||
if modify != nil {
|
|
||||||
modify(req)
|
|
||||||
}
|
|
||||||
res, err := client.HTTPClient.Do(req)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer res.Body.Close()
|
|
||||||
data, err := io.ReadAll(res.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Log(string(data))
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func i64ptr(i int64) *int64 {
|
func i64ptr(i int64) *int64 {
|
||||||
return &i
|
return &i
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/coder/coder/v2/coderd"
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -403,6 +406,7 @@ func TestPostLogout(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:bodyclose
|
||||||
func TestPostUsers(t *testing.T) {
|
func TestPostUsers(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("NoAuth", func(t *testing.T) {
|
t.Run("NoAuth", func(t *testing.T) {
|
||||||
|
@ -593,15 +597,15 @@ func TestPostUsers(t *testing.T) {
|
||||||
t.Run("CreateOIDCLoginType", func(t *testing.T) {
|
t.Run("CreateOIDCLoginType", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
email := "another@user.org"
|
email := "another@user.org"
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
fake := oidctest.NewFakeIDP(t,
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{
|
oidctest.WithServing(),
|
||||||
"email": email,
|
)
|
||||||
|
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
})
|
})
|
||||||
config.AllowSignups = false
|
|
||||||
config.IgnoreUserInfo = true
|
|
||||||
|
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
OIDCConfig: config,
|
OIDCConfig: cfg,
|
||||||
})
|
})
|
||||||
first := coderdtest.CreateFirstUser(t, client)
|
first := coderdtest.CreateFirstUser(t, client)
|
||||||
|
|
||||||
|
@ -618,15 +622,9 @@ func TestPostUsers(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Try to log in with OIDC.
|
// Try to log in with OIDC.
|
||||||
userClient := codersdk.New(client.URL)
|
userClient, _ := fake.Login(t, client, jwt.MapClaims{
|
||||||
resp := oidcCallback(t, userClient, conf.EncodeClaims(t, jwt.MapClaims{
|
|
||||||
"email": email,
|
"email": email,
|
||||||
}))
|
})
|
||||||
require.Equal(t, resp.StatusCode, http.StatusTemporaryRedirect)
|
|
||||||
// Set the client to use this OIDC context
|
|
||||||
authCookie := authCookieValue(resp.Cookies())
|
|
||||||
userClient.SetSessionToken(authCookie)
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
|
|
||||||
found, err := userClient.User(ctx, "me")
|
found, err := userClient.User(ctx, "me")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
package syncmap
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// Map is a type safe sync.Map
|
||||||
|
type Map[K, V any] struct {
|
||||||
|
m sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func New[K, V any]() *Map[K, V] {
|
||||||
|
return &Map[K, V]{
|
||||||
|
m: sync.Map{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) Store(k K, v V) {
|
||||||
|
m.m.Store(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:forcetypeassert
|
||||||
|
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
|
||||||
|
v, ok := m.m.Load(key)
|
||||||
|
if !ok {
|
||||||
|
var empty V
|
||||||
|
return empty, false
|
||||||
|
}
|
||||||
|
return v.(V), ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) Delete(key K) {
|
||||||
|
m.m.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:forcetypeassert
|
||||||
|
func (m *Map[K, V]) LoadAndDelete(key K) (actual V, loaded bool) {
|
||||||
|
act, loaded := m.m.LoadAndDelete(key)
|
||||||
|
if !loaded {
|
||||||
|
var empty V
|
||||||
|
return empty, loaded
|
||||||
|
}
|
||||||
|
return act.(V), loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:forcetypeassert
|
||||||
|
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
||||||
|
act, loaded := m.m.LoadOrStore(key, value)
|
||||||
|
if !loaded {
|
||||||
|
var empty V
|
||||||
|
return empty, loaded
|
||||||
|
}
|
||||||
|
return act.(V), loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) CompareAndSwap(key K, old V, new V) bool {
|
||||||
|
return m.m.CompareAndSwap(key, old, new)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
|
||||||
|
return m.m.CompareAndDelete(key, old)
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:forcetypeassert
|
||||||
|
func (m *Map[K, V]) Swap(key K, value V) (previous any, loaded bool) {
|
||||||
|
previous, loaded = m.m.Swap(key, value)
|
||||||
|
if !loaded {
|
||||||
|
var empty V
|
||||||
|
return empty, loaded
|
||||||
|
}
|
||||||
|
return previous.(V), loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:forcetypeassert
|
||||||
|
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
|
||||||
|
m.m.Range(func(key, value interface{}) bool {
|
||||||
|
return f(key.(K), value.(V))
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,25 +1,22 @@
|
||||||
package coderd_test
|
package coderd_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/coderd"
|
"github.com/coder/coder/v2/coderd"
|
||||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||||
"github.com/coder/coder/v2/coderd/rbac"
|
"github.com/coder/coder/v2/coderd/rbac"
|
||||||
"github.com/coder/coder/v2/coderd/util/slice"
|
"github.com/coder/coder/v2/coderd/util/slice"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
coderden "github.com/coder/coder/v2/enterprise/coderd"
|
||||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
@ -31,128 +28,123 @@ func TestUserOIDC(t *testing.T) {
|
||||||
t.Run("RoleSync", func(t *testing.T) {
|
t.Run("RoleSync", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
// NoRoles is the "control group". It has claims with 0 roles
|
||||||
|
// assigned, and asserts that the user has no roles.
|
||||||
t.Run("NoRoles", func(t *testing.T) {
|
t.Run("NoRoles", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
oidcRoleName := "TemplateAuthor"
|
cfg.UserRoleField = "roles"
|
||||||
|
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
|
|
||||||
cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}
|
|
||||||
})
|
|
||||||
config.AllowSignups = true
|
|
||||||
config.UserRoleField = "roles"
|
|
||||||
|
|
||||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
|
||||||
Options: &coderdtest.Options{
|
|
||||||
OIDCConfig: config,
|
|
||||||
},
|
|
||||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
||||||
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
admin, err := client.User(ctx, "me")
|
claims := jwt.MapClaims{
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, admin.OrganizationIDs, 1)
|
|
||||||
|
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
|
||||||
"email": "alice@coder.com",
|
"email": "alice@coder.com",
|
||||||
}))
|
}
|
||||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
// Login a new client that signs up
|
||||||
user, err := client.User(ctx, "alice")
|
client, resp := runner.Login(t, claims)
|
||||||
require.NoError(t, err)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
// User should be in 0 groups.
|
||||||
require.Len(t, user.Roles, 0)
|
runner.AssertRoles(t, "alice", []string{})
|
||||||
roleNames := []string{}
|
// Force a refresh, and assert nothing has changes
|
||||||
require.ElementsMatch(t, roleNames, []string{})
|
runner.ForceRefresh(t, client, claims)
|
||||||
|
runner.AssertRoles(t, "alice", []string{})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("NewUserAndRemoveRoles", func(t *testing.T) {
|
// A user has some roles, then on an oauth refresh will lose said
|
||||||
|
// roles from an updated claim.
|
||||||
|
t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) {
|
||||||
|
// TODO: Implement new feature to update roles/groups on OIDC
|
||||||
|
// refresh tokens. https://github.com/coder/coder/issues/9312
|
||||||
|
t.Skip("Refreshing tokens does not update roles :(")
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
const oidcRoleName = "TemplateAuthor"
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
|
Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}},
|
||||||
oidcRoleName := "TemplateAuthor"
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
|
cfg.UserRoleField = "roles"
|
||||||
cfg.UserRoleMapping = map[string][]string{oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}}
|
cfg.UserRoleMapping = map[string][]string{
|
||||||
})
|
oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()},
|
||||||
config.AllowSignups = true
|
}
|
||||||
config.UserRoleField = "roles"
|
|
||||||
|
|
||||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
|
||||||
Options: &coderdtest.Options{
|
|
||||||
OIDCConfig: config,
|
|
||||||
},
|
|
||||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
||||||
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
admin, err := client.User(ctx, "me")
|
// User starts with the owner role
|
||||||
require.NoError(t, err)
|
client, resp := runner.Login(t, jwt.MapClaims{
|
||||||
require.Len(t, admin.OrganizationIDs, 1)
|
|
||||||
|
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
|
||||||
"email": "alice@coder.com",
|
"email": "alice@coder.com",
|
||||||
"roles": []string{"random", oidcRoleName, rbac.RoleOwner()},
|
"roles": []string{"random", oidcRoleName, rbac.RoleOwner()},
|
||||||
}))
|
})
|
||||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
_ = resp.Body.Close()
|
runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
|
||||||
user, err := client.User(ctx, "alice")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Len(t, user.Roles, 3)
|
// Now refresh the oauth, and check the roles are removed.
|
||||||
roleNames := []string{user.Roles[0].Name, user.Roles[1].Name, user.Roles[2].Name}
|
// Force a refresh, and assert nothing has changes
|
||||||
require.ElementsMatch(t, roleNames, []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
|
runner.ForceRefresh(t, client, jwt.MapClaims{
|
||||||
|
|
||||||
// Now remove the roles with a new oidc login
|
|
||||||
resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
|
||||||
"email": "alice@coder.com",
|
"email": "alice@coder.com",
|
||||||
"roles": []string{"random"},
|
"roles": []string{"random"},
|
||||||
}))
|
})
|
||||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
runner.AssertRoles(t, "alice", []string{})
|
||||||
_ = resp.Body.Close()
|
|
||||||
user, err = client.User(ctx, "alice")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Len(t, user.Roles, 0)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// A user has some roles, then on another oauth login will lose said
|
||||||
|
// roles from an updated claim.
|
||||||
|
t.Run("NewUserAndRemoveRolesOnReAuth", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const oidcRoleName = "TemplateAuthor"
|
||||||
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
|
Userinfo: jwt.MapClaims{oidcRoleName: []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()}},
|
||||||
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
|
cfg.UserRoleField = "roles"
|
||||||
|
cfg.UserRoleMapping = map[string][]string{
|
||||||
|
oidcRoleName: {rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// User starts with the owner role
|
||||||
|
_, resp := runner.Login(t, jwt.MapClaims{
|
||||||
|
"email": "alice@coder.com",
|
||||||
|
"roles": []string{"random", oidcRoleName, rbac.RoleOwner()},
|
||||||
|
})
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin(), rbac.RoleOwner()})
|
||||||
|
|
||||||
|
// Now login with oauth again, and check the roles are removed.
|
||||||
|
_, resp = runner.Login(t, jwt.MapClaims{
|
||||||
|
"email": "alice@coder.com",
|
||||||
|
"roles": []string{"random"},
|
||||||
|
})
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
runner.AssertRoles(t, "alice", []string{})
|
||||||
|
})
|
||||||
|
|
||||||
|
// All manual role updates should fail when role sync is enabled.
|
||||||
t.Run("BlockAssignRoles", func(t *testing.T) {
|
t.Run("BlockAssignRoles", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{})
|
cfg.UserRoleField = "roles"
|
||||||
config.AllowSignups = true
|
|
||||||
config.UserRoleField = "roles"
|
|
||||||
|
|
||||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
|
||||||
Options: &coderdtest.Options{
|
|
||||||
OIDCConfig: config,
|
|
||||||
},
|
|
||||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
||||||
Features: license.Features{codersdk.FeatureUserRoleManagement: 1},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
admin, err := client.User(ctx, "me")
|
_, resp := runner.Login(t, jwt.MapClaims{
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, admin.OrganizationIDs, 1)
|
|
||||||
|
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
|
||||||
"email": "alice@coder.com",
|
"email": "alice@coder.com",
|
||||||
"roles": []string{},
|
"roles": []string{},
|
||||||
}))
|
})
|
||||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
// Try to manually update user roles, even though controlled by oidc
|
// Try to manually update user roles, even though controlled by oidc
|
||||||
// role sync.
|
// role sync.
|
||||||
_, err = client.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
_, err := runner.AdminClient.UpdateUserRoles(ctx, "alice", codersdk.UpdateRoles{
|
||||||
Roles: []string{
|
Roles: []string{
|
||||||
rbac.RoleTemplateAdmin(),
|
rbac.RoleTemplateAdmin(),
|
||||||
},
|
},
|
||||||
|
@ -164,199 +156,211 @@ func TestUserOIDC(t *testing.T) {
|
||||||
|
|
||||||
t.Run("Groups", func(t *testing.T) {
|
t.Run("Groups", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
// Assigns does a simple test of assigning a user to a group based
|
||||||
|
// on the oidc claims.
|
||||||
t.Run("Assigns", func(t *testing.T) {
|
t.Run("Assigns", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitLong)
|
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
|
||||||
|
|
||||||
const groupClaim = "custom-groups"
|
const groupClaim = "custom-groups"
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
|
const groupName = "bingbong"
|
||||||
cfg.GroupField = groupClaim
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
})
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
config.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
|
cfg.GroupField = groupClaim
|
||||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
|
||||||
Options: &coderdtest.Options{
|
|
||||||
OIDCConfig: config,
|
|
||||||
},
|
|
||||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
||||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
admin, err := client.User(ctx, "me")
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
require.NoError(t, err)
|
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||||
require.Len(t, admin.OrganizationIDs, 1)
|
|
||||||
|
|
||||||
groupName := "bingbong"
|
|
||||||
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
|
||||||
Name: groupName,
|
Name: groupName,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, group.Members, 0)
|
require.Len(t, group.Members, 0)
|
||||||
|
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
_, resp := runner.Login(t, jwt.MapClaims{
|
||||||
"email": "colin@coder.com",
|
"email": "alice@coder.com",
|
||||||
groupClaim: []string{groupName},
|
groupClaim: []string{groupName},
|
||||||
}))
|
})
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
runner.AssertGroups(t, "alice", []string{groupName})
|
||||||
group, err = client.Group(ctx, group.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, group.Members, 1)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Tests the group mapping feature.
|
||||||
t.Run("AssignsMapped", func(t *testing.T) {
|
t.Run("AssignsMapped", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
const groupClaim = "custom-groups"
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
|
||||||
|
|
||||||
oidcGroupName := "pingpong"
|
const oidcGroupName = "pingpong"
|
||||||
coderGroupName := "bingbong"
|
const coderGroupName = "bingbong"
|
||||||
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName}
|
cfg.AllowSignups = true
|
||||||
})
|
cfg.GroupField = groupClaim
|
||||||
config.AllowSignups = true
|
cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName}
|
||||||
|
|
||||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
|
||||||
Options: &coderdtest.Options{
|
|
||||||
OIDCConfig: config,
|
|
||||||
},
|
|
||||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
||||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
admin, err := client.User(ctx, "me")
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
require.NoError(t, err)
|
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||||
require.Len(t, admin.OrganizationIDs, 1)
|
|
||||||
|
|
||||||
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
|
||||||
Name: coderGroupName,
|
Name: coderGroupName,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, group.Members, 0)
|
require.Len(t, group.Members, 0)
|
||||||
|
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
_, resp := runner.Login(t, jwt.MapClaims{
|
||||||
"email": "colin@coder.com",
|
"email": "alice@coder.com",
|
||||||
"groups": []string{oidcGroupName},
|
groupClaim: []string{oidcGroupName},
|
||||||
}))
|
})
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
runner.AssertGroups(t, "alice", []string{coderGroupName})
|
||||||
group, err = client.Group(ctx, group.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, group.Members, 1)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AddThenRemove", func(t *testing.T) {
|
// User is in a group, then on an oauth refresh will lose said
|
||||||
|
// group.
|
||||||
|
t.Run("AddThenRemoveOnRefresh", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitLong)
|
// TODO: Implement new feature to update roles/groups on OIDC
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
// refresh tokens. https://github.com/coder/coder/issues/9312
|
||||||
|
t.Skip("Refreshing tokens does not update groups :(")
|
||||||
|
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{})
|
const groupClaim = "custom-groups"
|
||||||
config.AllowSignups = true
|
const groupName = "bingbong"
|
||||||
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
client, firstUser := coderdenttest.New(t, &coderdenttest.Options{
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
Options: &coderdtest.Options{
|
cfg.AllowSignups = true
|
||||||
OIDCConfig: config,
|
cfg.GroupField = groupClaim
|
||||||
},
|
|
||||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
||||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
// Add some extra users/groups that should be asserted after.
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
// Adding this user as there was a bug that removing 1 user removed
|
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||||
// all users from the group.
|
|
||||||
_, extra := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
|
||||||
groupName := "bingbong"
|
|
||||||
group, err := client.CreateGroup(ctx, firstUser.OrganizationID, codersdk.CreateGroupRequest{
|
|
||||||
Name: groupName,
|
Name: groupName,
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "create group")
|
require.NoError(t, err)
|
||||||
|
require.Len(t, group.Members, 0)
|
||||||
|
|
||||||
group, err = client.PatchGroup(ctx, group.ID, codersdk.PatchGroupRequest{
|
client, resp := runner.Login(t, jwt.MapClaims{
|
||||||
AddUsers: []string{
|
"email": "alice@coder.com",
|
||||||
firstUser.UserID.String(),
|
groupClaim: []string{groupName},
|
||||||
extra.ID.String(),
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "patch group")
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
require.Len(t, group.Members, 2, "expect both members")
|
runner.AssertGroups(t, "alice", []string{groupName})
|
||||||
|
|
||||||
// Now add OIDC user into the group
|
// Refresh without the group claim
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
runner.ForceRefresh(t, client, jwt.MapClaims{
|
||||||
"email": "colin@coder.com",
|
"email": "alice@coder.com",
|
||||||
"groups": []string{groupName},
|
})
|
||||||
}))
|
runner.AssertGroups(t, "alice", []string{})
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
||||||
|
|
||||||
group, err = client.Group(ctx, group.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, group.Members, 3)
|
|
||||||
|
|
||||||
// Login to remove the OIDC user from the group
|
|
||||||
resp = oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
|
||||||
"email": "colin@coder.com",
|
|
||||||
"groups": []string{},
|
|
||||||
}))
|
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
||||||
|
|
||||||
group, err = client.Group(ctx, group.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, group.Members, 2)
|
|
||||||
var expected []uuid.UUID
|
|
||||||
for _, mem := range group.Members {
|
|
||||||
expected = append(expected, mem.ID)
|
|
||||||
}
|
|
||||||
require.ElementsMatchf(t, expected, []uuid.UUID{firstUser.UserID, extra.ID}, "expected members")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("AddThenRemoveOnReAuth", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const groupClaim = "custom-groups"
|
||||||
|
const groupName = "bingbong"
|
||||||
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
|
cfg.GroupField = groupClaim
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
group, err := runner.AdminClient.CreateGroup(ctx, runner.AdminUser.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
||||||
|
Name: groupName,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, group.Members, 0)
|
||||||
|
|
||||||
|
_, resp := runner.Login(t, jwt.MapClaims{
|
||||||
|
"email": "alice@coder.com",
|
||||||
|
groupClaim: []string{groupName},
|
||||||
|
})
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
runner.AssertGroups(t, "alice", []string{groupName})
|
||||||
|
|
||||||
|
// Refresh without the group claim
|
||||||
|
_, resp = runner.Login(t, jwt.MapClaims{
|
||||||
|
"email": "alice@coder.com",
|
||||||
|
})
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
runner.AssertGroups(t, "alice", []string{})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating groups where the claimed group does not exist.
|
||||||
t.Run("NoneMatch", func(t *testing.T) {
|
t.Run("NoneMatch", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitLong)
|
const groupClaim = "custom-groups"
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{})
|
cfg.AllowSignups = true
|
||||||
config.AllowSignups = true
|
cfg.GroupField = groupClaim
|
||||||
|
|
||||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{
|
|
||||||
Options: &coderdtest.Options{
|
|
||||||
OIDCConfig: config,
|
|
||||||
},
|
|
||||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
||||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
admin, err := client.User(ctx, "me")
|
_, resp := runner.Login(t, jwt.MapClaims{
|
||||||
require.NoError(t, err)
|
"email": "alice@coder.com",
|
||||||
require.Len(t, admin.OrganizationIDs, 1)
|
groupClaim: []string{"not-exists"},
|
||||||
|
|
||||||
groupName := "bingbong"
|
|
||||||
group, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
|
||||||
Name: groupName,
|
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
require.Len(t, group.Members, 0)
|
runner.AssertGroups(t, "alice", []string{})
|
||||||
|
})
|
||||||
|
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, jwt.MapClaims{
|
// Updating groups where the claimed group does not exist creates
|
||||||
"email": "colin@coder.com",
|
// the group.
|
||||||
"groups": []string{"coolin"},
|
t.Run("AutoCreate", func(t *testing.T) {
|
||||||
}))
|
t.Parallel()
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
||||||
|
|
||||||
group, err = client.Group(ctx, group.ID)
|
const groupClaim = "custom-groups"
|
||||||
require.NoError(t, err)
|
const groupName = "make-me"
|
||||||
require.Len(t, group.Members, 0)
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
|
cfg.GroupField = groupClaim
|
||||||
|
cfg.CreateMissingGroups = true
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, resp := runner.Login(t, jwt.MapClaims{
|
||||||
|
"email": "alice@coder.com",
|
||||||
|
groupClaim: []string{groupName},
|
||||||
|
})
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
runner.AssertGroups(t, "alice", []string{groupName})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Refresh", func(t *testing.T) {
|
||||||
|
t.Run("RefreshTokensMultiple", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.AllowSignups = true
|
||||||
|
cfg.UserRoleField = "roles"
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"email": "alice@coder.com",
|
||||||
|
}
|
||||||
|
// Login a new client that signs up
|
||||||
|
client, resp := runner.Login(t, claims)
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// Refresh multiple times.
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
runner.ForceRefresh(t, client, claims)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:bodyclose
|
||||||
func TestGroupSync(t *testing.T) {
|
func TestGroupSync(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -470,28 +474,20 @@ func TestGroupSync(t *testing.T) {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx := testutil.Context(t, testutil.WaitLong)
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
conf := coderdtest.NewOIDCConfig(t, "")
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
|
cfg.GroupField = "groups"
|
||||||
config := conf.OIDCConfig(t, jwt.MapClaims{}, tc.modCfg)
|
tc.modCfg(cfg)
|
||||||
|
|
||||||
client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
|
||||||
Options: &coderdtest.Options{
|
|
||||||
OIDCConfig: config,
|
|
||||||
},
|
|
||||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
||||||
Features: license.Features{codersdk.FeatureTemplateRBAC: 1},
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
admin, err := client.User(ctx, "me")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, admin.OrganizationIDs, 1)
|
|
||||||
|
|
||||||
// Setup
|
// Setup
|
||||||
|
ctx := testutil.Context(t, testutil.WaitLong)
|
||||||
|
org := runner.AdminUser.OrganizationIDs[0]
|
||||||
|
|
||||||
initialGroups := make(map[string]codersdk.Group)
|
initialGroups := make(map[string]codersdk.Group)
|
||||||
for _, group := range tc.initialOrgGroups {
|
for _, group := range tc.initialOrgGroups {
|
||||||
newGroup, err := client.CreateGroup(ctx, admin.OrganizationIDs[0], codersdk.CreateGroupRequest{
|
newGroup, err := runner.AdminClient.CreateGroup(ctx, org, codersdk.CreateGroupRequest{
|
||||||
Name: group,
|
Name: group,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -500,16 +496,16 @@ func TestGroupSync(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the user and add them to their initial groups
|
// Create the user and add them to their initial groups
|
||||||
_, user := coderdtest.CreateAnotherUser(t, client, admin.OrganizationIDs[0])
|
_, user := coderdtest.CreateAnotherUser(t, runner.AdminClient, org)
|
||||||
for _, group := range tc.initialUserGroups {
|
for _, group := range tc.initialUserGroups {
|
||||||
_, err := client.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{
|
_, err := runner.AdminClient.PatchGroup(ctx, initialGroups[group].ID, codersdk.PatchGroupRequest{
|
||||||
AddUsers: []string{user.ID.String()},
|
AddUsers: []string{user.ID.String()},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
_, err = api.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{
|
_, err := runner.API.Database.UpdateUserLoginType(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLoginTypeParams{
|
||||||
NewLoginType: database.LoginTypeOIDC,
|
NewLoginType: database.LoginTypeOIDC,
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
})
|
})
|
||||||
|
@ -517,11 +513,11 @@ func TestGroupSync(t *testing.T) {
|
||||||
|
|
||||||
// Log in the new user
|
// Log in the new user
|
||||||
tc.claims["email"] = user.Email
|
tc.claims["email"] = user.Email
|
||||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.claims))
|
_, resp := runner.Login(t, tc.claims)
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
_ = resp.Body.Close()
|
|
||||||
|
|
||||||
orgGroups, err := client.GroupsByOrganization(ctx, admin.OrganizationIDs[0])
|
// Check group sources
|
||||||
|
orgGroups, err := runner.AdminClient.GroupsByOrganization(ctx, org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, group := range orgGroups {
|
for _, group := range orgGroups {
|
||||||
|
@ -567,24 +563,107 @@ func TestGroupSync(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
|
// oidcTestRunner is just a helper to setup and run oidc tests.
|
||||||
t.Helper()
|
// An actual Coderd instance is used to run the tests.
|
||||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
type oidcTestRunner struct {
|
||||||
return http.ErrUseLastResponse
|
AdminClient *codersdk.Client
|
||||||
}
|
AdminUser codersdk.User
|
||||||
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code))
|
API *coderden.API
|
||||||
require.NoError(t, err)
|
|
||||||
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
|
// Login will call the OIDC flow with an unauthenticated client.
|
||||||
require.NoError(t, err)
|
// The IDP will return the idToken claims.
|
||||||
req.AddCookie(&http.Cookie{
|
Login func(t *testing.T, idToken jwt.MapClaims) (*codersdk.Client, *http.Response)
|
||||||
Name: codersdk.OAuth2StateCookie,
|
// ForceRefresh will use an authenticated codersdk.Client, and force their
|
||||||
Value: "somestate",
|
// OIDC token to be expired and require a refresh. The refresh will use the claims provided.
|
||||||
})
|
// It just calls the /users/me endpoint to trigger the refresh.
|
||||||
res, err := client.HTTPClient.Do(req)
|
ForceRefresh func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims)
|
||||||
require.NoError(t, err)
|
}
|
||||||
defer res.Body.Close()
|
|
||||||
data, err := io.ReadAll(res.Body)
|
type oidcTestConfig struct {
|
||||||
require.NoError(t, err)
|
Userinfo jwt.MapClaims
|
||||||
t.Log(string(data))
|
|
||||||
return res
|
// Config allows modifying the Coderd OIDC configuration.
|
||||||
|
Config func(cfg *coderd.OIDCConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oidcTestRunner) AssertRoles(t *testing.T, userIdent string, roles []string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
user, err := r.AdminClient.User(ctx, userIdent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
roleNames := []string{}
|
||||||
|
for _, role := range user.Roles {
|
||||||
|
roleNames = append(roleNames, role.Name)
|
||||||
|
}
|
||||||
|
require.ElementsMatch(t, roles, roleNames, "expected roles")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *oidcTestRunner) AssertGroups(t *testing.T, userIdent string, groups []string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if !slice.Contains(groups, database.EveryoneGroup) {
|
||||||
|
var cpy []string
|
||||||
|
cpy = append(cpy, groups...)
|
||||||
|
// always include everyone group
|
||||||
|
cpy = append(cpy, database.EveryoneGroup)
|
||||||
|
groups = cpy
|
||||||
|
}
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
user, err := r.AdminClient.User(ctx, userIdent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
allGroups, err := r.AdminClient.GroupsByOrganization(ctx, user.OrganizationIDs[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userInGroups := []string{}
|
||||||
|
for _, g := range allGroups {
|
||||||
|
for _, mem := range g.Members {
|
||||||
|
if mem.ID == user.ID {
|
||||||
|
userInGroups = append(userInGroups, g.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.ElementsMatch(t, groups, userInGroups, "expected groups")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupOIDCTest(t *testing.T, settings oidcTestConfig) *oidcTestRunner {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
fake := oidctest.NewFakeIDP(t,
|
||||||
|
oidctest.WithStaticUserInfo(settings.Userinfo),
|
||||||
|
oidctest.WithLogging(t, nil),
|
||||||
|
// Run fake IDP on a real webserver
|
||||||
|
oidctest.WithServing(),
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
cfg := fake.OIDCConfig(t, nil, settings.Config)
|
||||||
|
owner, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||||
|
Options: &coderdtest.Options{
|
||||||
|
OIDCConfig: cfg,
|
||||||
|
},
|
||||||
|
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||||
|
Features: license.Features{
|
||||||
|
codersdk.FeatureUserRoleManagement: 1,
|
||||||
|
codersdk.FeatureTemplateRBAC: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
admin, err := owner.User(ctx, "me")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
helper := oidctest.NewLoginHelper(owner, fake)
|
||||||
|
|
||||||
|
return &oidcTestRunner{
|
||||||
|
AdminClient: owner,
|
||||||
|
AdminUser: admin,
|
||||||
|
API: api,
|
||||||
|
Login: helper.Login,
|
||||||
|
ForceRefresh: func(t *testing.T, client *codersdk.Client, idToken jwt.MapClaims) {
|
||||||
|
helper.ForceRefresh(t, api.Database, client, idToken)
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue