mirror of https://github.com/coder/coder.git
fix: make 'NoRefresh' honor unlimited tokens in gitauth (#9472)
* chore: fix NoRefresh to honor unlimited tokens * improve testing coverage of gitauth * refactor rest of gitauth tests
This commit is contained in:
parent
da0ef92f77
commit
58f7071569
|
@ -7,6 +7,7 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -41,7 +42,7 @@ import (
|
|||
type FakeIDP struct {
|
||||
issuer string
|
||||
key *rsa.PrivateKey
|
||||
provider providerJSON
|
||||
provider ProviderJSON
|
||||
handler http.Handler
|
||||
cfg *oauth2.Config
|
||||
|
||||
|
@ -66,7 +67,7 @@ type FakeIDP struct {
|
|||
// IDP -> Application. Almost all IDPs have the concept of
|
||||
// "Authorized Redirect URLs". This can be used to emulate that.
|
||||
hookValidRedirectURL func(redirectURL string) error
|
||||
hookUserInfo func(email string) jwt.MapClaims
|
||||
hookUserInfo func(email string) (jwt.MapClaims, error)
|
||||
fakeCoderd func(req *http.Request) (*http.Response, error)
|
||||
hookOnRefresh func(email string) error
|
||||
// Custom authentication for the client. This is useful if you want
|
||||
|
@ -75,6 +76,26 @@ type FakeIDP struct {
|
|||
serve bool
|
||||
}
|
||||
|
||||
func StatusError(code int, err error) error {
|
||||
return statusHookError{
|
||||
Err: err,
|
||||
HTTPStatusCode: code,
|
||||
}
|
||||
}
|
||||
|
||||
// statusHookError allows a hook to change the returned http status code.
|
||||
type statusHookError struct {
|
||||
Err error
|
||||
HTTPStatusCode int
|
||||
}
|
||||
|
||||
func (s statusHookError) Error() string {
|
||||
if s.Err == nil {
|
||||
return ""
|
||||
}
|
||||
return s.Err.Error()
|
||||
}
|
||||
|
||||
type FakeIDPOpt func(idp *FakeIDP)
|
||||
|
||||
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
|
||||
|
@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
|
|||
}
|
||||
}
|
||||
|
||||
// WithRefreshHook is called when a refresh token is used. The email is
|
||||
// WithRefresh is called when a refresh token is used. The email is
|
||||
// the email of the user that is being refreshed assuming the claims are correct.
|
||||
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
|
||||
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookOnRefresh = hook
|
||||
}
|
||||
|
@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
|
|||
// every user on the /userinfo endpoint.
|
||||
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookUserInfo = func(_ string) jwt.MapClaims {
|
||||
return info
|
||||
f.hookUserInfo = func(_ string) (jwt.MapClaims, error) {
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
|
||||
func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookUserInfo = userInfoFunc
|
||||
}
|
||||
|
@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
|||
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
hookOnRefresh: func(_ string) error { return nil },
|
||||
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
|
||||
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
|
||||
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||
}
|
||||
|
||||
|
@ -181,6 +202,10 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
|||
return idp
|
||||
}
|
||||
|
||||
func (f *FakeIDP) WellknownConfig() ProviderJSON {
|
||||
return f.provider
|
||||
}
|
||||
|
||||
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
t.Helper()
|
||||
|
||||
|
@ -188,9 +213,9 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
|||
require.NoError(t, err, "invalid issuer URL")
|
||||
|
||||
f.issuer = issuer
|
||||
// providerJSON is the JSON representation of the OpenID Connect provider
|
||||
// ProviderJSON is the JSON representation of the OpenID Connect provider
|
||||
// These are all the urls that the IDP will respond to.
|
||||
f.provider = providerJSON{
|
||||
f.provider = ProviderJSON{
|
||||
Issuer: issuer,
|
||||
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||
|
@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
|||
return srv
|
||||
}
|
||||
|
||||
// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
|
||||
// valid token for some given claims.
|
||||
func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) {
|
||||
state := uuid.NewString()
|
||||
f.stateToIDTokenClaims.Store(state, claims)
|
||||
code := f.newCode(state)
|
||||
return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
|
||||
}
|
||||
|
||||
// Login does the full OIDC flow starting at the "LoginButton".
|
||||
// The client argument is just to get the URL of the Coder instance.
|
||||
//
|
||||
|
@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
type providerJSON struct {
|
||||
// ProviderJSON is the .well-known/configuration JSON
|
||||
type ProviderJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
|
@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
|||
err := f.hookValidRedirectURL(redirectURI)
|
||||
if err != nil {
|
||||
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
|
||||
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
|
||||
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
|||
slog.F("values", values.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
|
||||
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
getEmail := func(claims jwt.MapClaims) string {
|
||||
|
@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
|||
claims = idTokenClaims
|
||||
err := f.hookOnRefresh(getEmail(claims))
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
|
||||
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
|||
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
|
||||
claims, err := f.hookUserInfo(email)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(rw).Encode(claims)
|
||||
}))
|
||||
|
||||
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
|
|||
return cfg
|
||||
}
|
||||
|
||||
func httpErrorCode(defaultCode int, err error) int {
|
||||
var stautsErr statusHookError
|
||||
status := defaultCode
|
||||
if errors.As(err, &stautsErr) {
|
||||
status = stautsErr.HTTPStatusCode
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct {
|
||||
roundTrip func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
|
|
@ -60,17 +60,30 @@ type Config struct {
|
|||
}
|
||||
|
||||
// RefreshToken automatically refreshes the token if expired and permitted.
|
||||
// It returns the token and a bool indicating if the token was refreshed.
|
||||
// It returns the token and a bool indicating if the token is valid.
|
||||
func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
|
||||
// If the token is expired and refresh is disabled, we prompt
|
||||
// the user to authenticate again.
|
||||
if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(dbtime.Now()) {
|
||||
if c.NoRefresh &&
|
||||
// If the time is set to 0, then it should never expire.
|
||||
// This is true for github, which has no expiry.
|
||||
!gitAuthLink.OAuthExpiry.IsZero() &&
|
||||
gitAuthLink.OAuthExpiry.Before(dbtime.Now()) {
|
||||
return gitAuthLink, false, nil
|
||||
}
|
||||
|
||||
// This is additional defensive programming. Because TokenSource is an interface,
|
||||
// we cannot be sure that the implementation will treat an 'IsZero' time
|
||||
// as "not-expired". The default implementation does, but a custom implementation
|
||||
// might not. Removing the refreshToken will guarantee a refresh will fail.
|
||||
refreshToken := gitAuthLink.OAuthRefreshToken
|
||||
if c.NoRefresh {
|
||||
refreshToken = ""
|
||||
}
|
||||
|
||||
token, err := c.TokenSource(ctx, &oauth2.Token{
|
||||
AccessToken: gitAuthLink.OAuthAccessToken,
|
||||
RefreshToken: gitAuthLink.OAuthRefreshToken,
|
||||
RefreshToken: refreshToken,
|
||||
Expiry: gitAuthLink.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
|
@ -130,8 +143,13 @@ func (c *Config) ValidateToken(ctx context.Context, token string) (bool, *coders
|
|||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
cli := http.DefaultClient
|
||||
if v, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
||||
cli = v
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
res, err := cli.Do(req)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
|
|
@ -3,18 +3,22 @@ package gitauth_test
|
|||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/gitauth"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
|
@ -22,17 +26,70 @@ import (
|
|||
|
||||
func TestRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("FalseIfNoRefresh", func(t *testing.T) {
|
||||
const providerID = "test-idp"
|
||||
expired := time.Now().Add(time.Hour * -1)
|
||||
|
||||
t.Run("NoRefreshExpired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
NoRefresh: true,
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
|
||||
OAuthExpiry: time.Time{},
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
t.Error("refresh on the IDP was called, but NoRefresh was set")
|
||||
return xerrors.New("should not be called")
|
||||
}),
|
||||
// The IDP should not be contacted since the token is expired. An expired
|
||||
// token with 'NoRefresh' should early abort.
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
t.Error("token was validated, but it was expired and this should never have happened.")
|
||||
return nil, xerrors.New("should not be called")
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.NoRefresh = true
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
// Expire the link
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, refreshed, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.False(t, refreshed)
|
||||
})
|
||||
|
||||
// NoRefreshNoExpiry tests that an oauth token without an expiry is always valid.
|
||||
// The "validate url" should be hit, but the refresh endpoint should not.
|
||||
t.Run("NoRefreshNoExpiry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validated := false
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
t.Error("refresh on the IDP was called, but NoRefresh was set")
|
||||
return xerrors.New("should not be called")
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validated = true
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.NoRefresh = true
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
|
||||
// Zero time used
|
||||
link.OAuthExpiry = time.Time{}
|
||||
_, refreshed, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, refreshed, "token without expiry is always valid")
|
||||
require.True(t, validated, "token should have been validated")
|
||||
})
|
||||
|
||||
t.Run("FalseIfTokenSourceFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
|
@ -42,111 +99,167 @@ func TestRefreshToken(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
|
||||
OAuthExpiry: expired,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, refreshed)
|
||||
})
|
||||
|
||||
t.Run("ValidateServerError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("Failure"))
|
||||
}))
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
require.ErrorContains(t, err, "Failure")
|
||||
|
||||
const staticError = "static error"
|
||||
validated := false
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validated = true
|
||||
return jwt.MapClaims{}, xerrors.New(staticError)
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, _, err := config.RefreshToken(ctx, nil, link)
|
||||
require.ErrorContains(t, err, staticError)
|
||||
require.True(t, validated, "token should have been attempted to be validated")
|
||||
})
|
||||
|
||||
// ValidateFailure tests if the token is no longer valid with a 401 response.
|
||||
t.Run("ValidateFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Not permitted"))
|
||||
}))
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
require.NoError(t, err)
|
||||
|
||||
const staticError = "static error"
|
||||
validated := false
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validated = true
|
||||
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, refreshed, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err, staticError)
|
||||
require.False(t, refreshed)
|
||||
require.True(t, validated, "token should have been attempted to be validated")
|
||||
})
|
||||
|
||||
t.Run("ValidateRetryGitHub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hit := false
|
||||
// We need to ensure that the exponential backoff kicks in properly.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !hit {
|
||||
hit = true
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Not permitted"))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
config := &gitauth.Config{
|
||||
ID: "test",
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "updated",
|
||||
},
|
||||
|
||||
const staticError = "static error"
|
||||
validateCalls := 0
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
t.Error("refresh on the IDP was called, but the token is not expired")
|
||||
return xerrors.New("should not be called")
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validateCalls++
|
||||
// Make the first call return a 401, subsequent calls should return a 200.
|
||||
if validateCalls > 1 {
|
||||
return jwt.MapClaims{}, nil
|
||||
}
|
||||
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.Type = codersdk.GitProviderGitHub
|
||||
},
|
||||
ValidateURL: srv.URL,
|
||||
Type: codersdk.GitProviderGitHub,
|
||||
}
|
||||
db := dbfake.New()
|
||||
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
ProviderID: config.ID,
|
||||
OAuthAccessToken: "initial",
|
||||
})
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), db, link)
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
// Unlimited lifetime, this is what GitHub returns tokens as
|
||||
link.OAuthExpiry = time.Time{}
|
||||
|
||||
_, ok, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, refreshed)
|
||||
require.True(t, hit)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 2, validateCalls, "token should have been attempted to be validated more than once")
|
||||
})
|
||||
|
||||
t.Run("ValidateNoUpdate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
validated := make(chan struct{})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
close(validated)
|
||||
}))
|
||||
accessToken := "testing"
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: accessToken,
|
||||
},
|
||||
|
||||
validateCalls := 0
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
t.Error("refresh on the IDP was called, but the token is not expired")
|
||||
return xerrors.New("should not be called")
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validateCalls++
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.Type = codersdk.GitProviderGitHub
|
||||
},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
|
||||
OAuthAccessToken: accessToken,
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
|
||||
_, ok, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, valid)
|
||||
<-validated
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 1, validateCalls, "token is validated")
|
||||
})
|
||||
|
||||
// A token update comes from a refresh.
|
||||
t.Run("Updates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
ID: "test",
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "updated",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db := dbfake.New()
|
||||
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
ProviderID: config.ID,
|
||||
OAuthAccessToken: "initial",
|
||||
validateCalls := 0
|
||||
refreshCalls := 0
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
refreshCalls++
|
||||
return nil
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validateCalls++
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.Type = codersdk.GitProviderGitHub
|
||||
},
|
||||
DB: db,
|
||||
})
|
||||
_, valid, err := config.RefreshToken(context.Background(), db, link)
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
// Force a refresh
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
updated, ok, err := config.RefreshToken(ctx, db, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, valid)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 1, validateCalls, "token is validated")
|
||||
require.Equal(t, 1, refreshCalls, "token is refreshed")
|
||||
require.NotEqualf(t, link.OAuthAccessToken, updated.OAuthAccessToken, "token is updated")
|
||||
//nolint:gocritic // testing
|
||||
dbLink, err := db.GetGitAuthLink(dbauthz.AsSystemRestricted(context.Background()), database.GetGitAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB")
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -232,3 +345,65 @@ func TestConvertYAML(t *testing.T) {
|
|||
require.Equal(t, "https://auth.com?client_id=id&redirect_uri=%2Fgitauth%2Fgitlab%2Fcallback&response_type=code&scope=read", config[0].AuthCodeURL(""))
|
||||
})
|
||||
}
|
||||
|
||||
type testConfig struct {
|
||||
FakeIDPOpts []oidctest.FakeIDPOpt
|
||||
CoderOIDCConfigOpts []func(cfg *coderd.OIDCConfig)
|
||||
GitConfigOpt func(cfg *gitauth.Config)
|
||||
// If DB is passed in, the link will be inserted into the DB.
|
||||
DB database.Store
|
||||
}
|
||||
|
||||
// setupTest will configure a fake IDP and a gitauth.Config for testing.
|
||||
// The Fake's userinfo endpoint is used for validating tokens.
|
||||
// No http servers are started so use the fake IDP's HTTPClient to make requests.
|
||||
// The returned token is a fully valid token for the IDP. Feel free to manipulate it
|
||||
// to test different scenarios.
|
||||
func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *gitauth.Config, database.GitAuthLink) {
|
||||
t.Helper()
|
||||
|
||||
const providerID = "test-idp"
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
append([]oidctest.FakeIDPOpt{}, settings.FakeIDPOpts...)...,
|
||||
)
|
||||
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: fake.OIDCConfig(t, nil, settings.CoderOIDCConfigOpts...),
|
||||
ID: providerID,
|
||||
ValidateURL: fake.WellknownConfig().UserInfoURL,
|
||||
}
|
||||
settings.GitConfigOpt(config)
|
||||
|
||||
oauthToken, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
|
||||
"email": "test@coder.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
now := time.Now()
|
||||
link := database.GitAuthLink{
|
||||
ProviderID: providerID,
|
||||
UserID: uuid.New(),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
OAuthAccessToken: oauthToken.AccessToken,
|
||||
OAuthRefreshToken: oauthToken.RefreshToken,
|
||||
// The caller can manually expire this if they want.
|
||||
OAuthExpiry: now.Add(time.Hour),
|
||||
}
|
||||
|
||||
if settings.DB != nil {
|
||||
// Feel free to insert additional things like the user, etc if required.
|
||||
link, err = settings.DB.InsertGitAuthLink(context.Background(), database.InsertGitAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
CreatedAt: link.CreatedAt,
|
||||
UpdatedAt: link.UpdatedAt,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
})
|
||||
require.NoError(t, err, "failed to insert link into DB")
|
||||
}
|
||||
|
||||
return fake, config, link
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ func TestOIDCOauthLoginWithExisting(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
|
@ -797,7 +797,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
|
@ -851,7 +851,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
|
||||
auditor := audit.NewMock()
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
|
@ -898,7 +898,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
t.Parallel()
|
||||
auditor := audit.NewMock()
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
|
@ -959,7 +959,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
t.Run("NoIDToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
|
@ -984,7 +984,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
badProvider := &oidc.Provider{}
|
||||
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
|
|
|
@ -365,7 +365,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
|
||||
runner := setupOIDCTest(t, oidcTestConfig{
|
||||
FakeOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefreshHook(func(_ string) error {
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
// Always "expired" refresh token.
|
||||
return xerrors.New("refresh token is expired")
|
||||
}),
|
||||
|
|
Loading…
Reference in New Issue