mirror of https://github.com/coder/coder.git
fix: refresh all oauth links on external auth page (#11646)
* fix: refresh all oauth links on external auth page
This commit is contained in:
parent
d583acad00
commit
08b4eb3124
|
@ -362,7 +362,6 @@ func (api *API) listUserExternalAuths(rw http.ResponseWriter, r *http.Request) {
|
|||
if err == nil && valid {
|
||||
links[i] = newLink
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@ import (
|
|||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
|
@ -198,6 +200,66 @@ func TestExternalAuthManagement(t *testing.T) {
|
|||
require.Len(t, list.Providers, 2)
|
||||
require.Len(t, list.Links, 0)
|
||||
})
|
||||
t.Run("RefreshAllProviders", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
const githubID = "fake-github"
|
||||
const gitlabID = "fake-gitlab"
|
||||
|
||||
githubCalled := false
|
||||
githubApp := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error {
|
||||
githubCalled = true
|
||||
return nil
|
||||
}))
|
||||
gitlabCalled := false
|
||||
gitlab := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error {
|
||||
gitlabCalled = true
|
||||
return nil
|
||||
}))
|
||||
|
||||
owner, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
ExternalAuthConfigs: []*externalauth.Config{
|
||||
githubApp.ExternalAuthConfig(t, githubID, nil, func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
}),
|
||||
gitlab.ExternalAuthConfig(t, gitlabID, nil, func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitLab.String()
|
||||
}),
|
||||
},
|
||||
})
|
||||
ownerUser := coderdtest.CreateFirstUser(t, owner)
|
||||
// Just a regular user
|
||||
client, user := coderdtest.CreateAnotherUser(t, owner, ownerUser.OrganizationID)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Log into github & gitlab
|
||||
githubApp.ExternalLogin(t, client)
|
||||
gitlab.ExternalLogin(t, client)
|
||||
|
||||
links, err := db.GetExternalAuthLinksByUserID(
|
||||
dbauthz.As(ctx, coderdtest.AuthzUserSubject(user, ownerUser.OrganizationID)), user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, links, 2)
|
||||
|
||||
// Expire the links
|
||||
for _, l := range links {
|
||||
_, err := db.UpdateExternalAuthLink(dbauthz.As(ctx, coderdtest.AuthzUserSubject(user, ownerUser.OrganizationID)), database.UpdateExternalAuthLinkParams{
|
||||
ProviderID: l.ProviderID,
|
||||
UserID: l.UserID,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
OAuthAccessToken: l.OAuthAccessToken,
|
||||
OAuthRefreshToken: l.OAuthRefreshToken,
|
||||
OAuthExpiry: time.Now().Add(time.Hour * -1),
|
||||
OAuthExtra: l.OAuthExtra,
|
||||
})
|
||||
require.NoErrorf(t, err, "expire key for %s", l.ProviderID)
|
||||
}
|
||||
|
||||
list, err := client.ListExternalAuths(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, list.Links, 2)
|
||||
require.True(t, githubCalled, "github should be refreshed")
|
||||
require.True(t, gitlabCalled, "gitlab should be refreshed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestExternalAuthDevice(t *testing.T) {
|
||||
|
|
|
@ -336,6 +336,10 @@ func ExpectJSONMime(res *http.Response) error {
|
|||
|
||||
// ReadBodyAsError reads the response as a codersdk.Response, and
|
||||
// wraps it in a codersdk.Error type for easy marshaling.
|
||||
//
|
||||
// This will always return an error, so only call it if the response failed
|
||||
// your expectations. Usually via status code checking.
|
||||
// nolint:staticcheck
|
||||
func ReadBodyAsError(res *http.Response) error {
|
||||
if res == nil {
|
||||
return xerrors.Errorf("no body returned")
|
||||
|
|
|
@ -283,6 +283,17 @@ func Test_readBodyAsError(t *testing.T) {
|
|||
assert.Equal(t, unexpectedJSON, sdkErr.Response.Detail)
|
||||
},
|
||||
},
|
||||
{
|
||||
// Even status code 200 should be considered an error if this function
|
||||
// is called. There are parts of the code that require this function
|
||||
// to always return an error.
|
||||
name: "OKResp",
|
||||
req: nil,
|
||||
res: newResponse(http.StatusOK, jsonCT, marshal(map[string]any{})),
|
||||
assert: func(t *testing.T, err error) {
|
||||
require.Error(t, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range tests {
|
||||
|
|
|
@ -321,19 +321,17 @@ func (p *ProxyHealth) runOnce(ctx context.Context, now time.Time) (map[uuid.UUID
|
|||
// readable.
|
||||
builder.WriteString(fmt.Sprintf("unexpected status code %d. ", resp.StatusCode))
|
||||
builder.WriteString(fmt.Sprintf("\nEncountered error, send a request to %q from the Coderd environment to debug this issue.", reqURL))
|
||||
// err will always be non-nil
|
||||
err := codersdk.ReadBodyAsError(resp)
|
||||
if err != nil {
|
||||
var apiErr *codersdk.Error
|
||||
if xerrors.As(err, &apiErr) {
|
||||
builder.WriteString(fmt.Sprintf("\nError Message: %s\nError Detail: %s", apiErr.Message, apiErr.Detail))
|
||||
for _, v := range apiErr.Validations {
|
||||
// Pretty sure this is not possible from the called endpoint, but just in case.
|
||||
builder.WriteString(fmt.Sprintf("\n\tValidation: %s=%s", v.Field, v.Detail))
|
||||
}
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf("\nError: %s", err.Error()))
|
||||
var apiErr *codersdk.Error
|
||||
if xerrors.As(err, &apiErr) {
|
||||
builder.WriteString(fmt.Sprintf("\nError Message: %s\nError Detail: %s", apiErr.Message, apiErr.Detail))
|
||||
for _, v := range apiErr.Validations {
|
||||
// Pretty sure this is not possible from the called endpoint, but just in case.
|
||||
builder.WriteString(fmt.Sprintf("\n\tValidation: %s=%s", v.Field, v.Detail))
|
||||
}
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("\nError: %s", err.Error()))
|
||||
|
||||
status.Report.Errors = []string{builder.String()}
|
||||
case err != nil:
|
||||
|
|
Loading…
Reference in New Issue