mirror of https://github.com/coder/coder.git
chore: remove duplicate validate calls on same oauth token (#11598)
* chore: remove duplicate validate calls on same oauth token
This commit is contained in:
parent
8181c9f349
commit
03ee63931c
|
@ -184,6 +184,9 @@ type Options struct {
|
|||
// under the enterprise license, and can't be imported into AGPL.
|
||||
ParseLicenseClaims func(rawJWT string) (email string, trial bool, err error)
|
||||
AllowWorkspaceRenames bool
|
||||
|
||||
// NewTicker is used for unit tests to replace "time.NewTicker".
|
||||
NewTicker func(duration time.Duration) (tick <-chan time.Time, done func())
|
||||
}
|
||||
|
||||
// @title Coder API
|
||||
|
@ -208,6 +211,12 @@ func New(options *Options) *API {
|
|||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
if options.NewTicker == nil {
|
||||
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
|
||||
ticker := time.NewTicker(duration)
|
||||
return ticker.C, ticker.Stop
|
||||
}
|
||||
}
|
||||
|
||||
// Safety check: if we're not running a unit test, we *must* have a Prometheus registry.
|
||||
if options.PrometheusRegistry == nil && flag.Lookup("test.v") == nil {
|
||||
|
|
|
@ -145,6 +145,7 @@ type Options struct {
|
|||
|
||||
WorkspaceAppsStatsCollectorOptions workspaceapps.StatsCollectorOptions
|
||||
AllowWorkspaceRenames bool
|
||||
NewTicker func(duration time.Duration) (<-chan time.Time, func())
|
||||
}
|
||||
|
||||
// New constructs a codersdk client connected to an in-memory API instance.
|
||||
|
@ -451,6 +452,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
|
|||
StatsBatcher: options.StatsBatcher,
|
||||
WorkspaceAppsStatsCollectorOptions: options.WorkspaceAppsStatsCollectorOptions,
|
||||
AllowWorkspaceRenames: options.AllowWorkspaceRenames,
|
||||
NewTicker: options.NewTicker,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2051,13 +2051,14 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
|||
if listen {
|
||||
// Since we're ticking frequently and this sign-in operation is rare,
|
||||
// we are OK with polling to avoid the complexity of pubsub.
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
ticker, done := api.NewTicker(time.Second)
|
||||
defer done()
|
||||
var previousToken database.ExternalAuthLink
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
case <-ticker:
|
||||
}
|
||||
externalAuthLink, err := api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
|
||||
ProviderID: externalAuthConfig.ID,
|
||||
|
@ -2081,6 +2082,15 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
|||
if externalAuthLink.OAuthExpiry.Before(dbtime.Now()) && !externalAuthLink.OAuthExpiry.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only attempt to revalidate an oauth token if it has actually changed.
|
||||
// No point in trying to validate the same token over and over again.
|
||||
if previousToken.OAuthAccessToken == externalAuthLink.OAuthAccessToken &&
|
||||
previousToken.OAuthRefreshToken == externalAuthLink.OAuthRefreshToken &&
|
||||
previousToken.OAuthExpiry == externalAuthLink.OAuthExpiry {
|
||||
continue
|
||||
}
|
||||
|
||||
valid, _, err := externalAuthConfig.ValidateToken(ctx, externalAuthLink.OAuthAccessToken)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to validate external auth token",
|
||||
|
@ -2089,6 +2099,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
|||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
previousToken = externalAuthLink
|
||||
if !valid {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -25,12 +25,15 @@ import (
|
|||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmem"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
|
@ -1536,3 +1539,94 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) {
|
|||
require.True(t, ok)
|
||||
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// ValidateURLSpam acts as a workspace calling GIT_ASK_PASS which
|
||||
// will wait until the external auth token is valid. The issue is we spam
|
||||
// the validate endpoint with requests until the token is valid. We do this
|
||||
// even if the token has not changed. We are calling validate with the
|
||||
// same inputs expecting a different result (insanity?). To reduce our
|
||||
// api rate limit usage, we should do nothing if the inputs have not
|
||||
// changed.
|
||||
//
|
||||
// Note that an expired oauth token is already skipped, so this really
|
||||
// only covers the case of a revoked token.
|
||||
t.Run("ValidateURLSpam", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const providerID = "fake-idp"
|
||||
|
||||
// Count all the times we call validate
|
||||
validateCalls := 0
|
||||
fake := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithMiddlewares(func(handler http.Handler) http.Handler {
|
||||
return http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Count all the validate calls
|
||||
if strings.Contains(r.URL.Path, "/external-auth-validate/") {
|
||||
validateCalls++
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
}))
|
||||
|
||||
ticks := make(chan time.Time)
|
||||
// setup
|
||||
ownerClient, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
NewTicker: func(duration time.Duration) (<-chan time.Time, func()) {
|
||||
return ticks, func() {}
|
||||
},
|
||||
ExternalAuthConfigs: []*externalauth.Config{
|
||||
fake.ExternalAuthConfig(t, providerID, nil, func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
}),
|
||||
},
|
||||
})
|
||||
first := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
tmpDir := t.TempDir()
|
||||
client, user := coderdtest.CreateAnotherUser(t, ownerClient, first.OrganizationID)
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, db, database.Workspace{
|
||||
OrganizationID: first.OrganizationID,
|
||||
OwnerID: user.ID,
|
||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
||||
agents[0].Directory = tmpDir
|
||||
return agents
|
||||
}).Do()
|
||||
|
||||
agentClient := agentsdk.New(client.URL)
|
||||
agentClient.SetSessionToken(r.AgentToken)
|
||||
|
||||
// We need to include an invalid oauth token that is not expired.
|
||||
dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
|
||||
ProviderID: providerID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
OAuthAccessToken: "invalid",
|
||||
OAuthRefreshToken: "bad",
|
||||
OAuthExpiry: dbtime.Now().Add(time.Hour),
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
go func() {
|
||||
// The request that will block and fire off validate calls.
|
||||
_, err := agentClient.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
|
||||
ID: providerID,
|
||||
Match: "",
|
||||
Listen: true,
|
||||
})
|
||||
assert.Error(t, err, "this should fail")
|
||||
}()
|
||||
|
||||
// Send off 10 ticks to cause 10 validate calls
|
||||
for i := 0; i < 10; i++ {
|
||||
ticks <- time.Now()
|
||||
}
|
||||
cancel()
|
||||
// We expect only 1
|
||||
// In a failed test, you will likely see 9, as the last one
|
||||
// gets cancelled.
|
||||
require.Equal(t, 1, validateCalls, "validate calls duplicated on same token")
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue