From 08b4eb3124228f8f06e4c3bc4a98dd441ca2eab7 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Tue, 16 Jan 2024 11:03:55 -0600 Subject: [PATCH] fix: refresh all oauth links on external auth page (#11646) * fix: refresh all oauth links on external auth page --- coderd/externalauth.go | 1 - coderd/externalauth_test.go | 62 ++++++++++++++++++++ codersdk/client.go | 4 ++ codersdk/client_internal_test.go | 11 ++++ enterprise/coderd/proxyhealth/proxyhealth.go | 18 +++--- 5 files changed, 85 insertions(+), 11 deletions(-) diff --git a/coderd/externalauth.go b/coderd/externalauth.go index b9d7e665b1..001592e04e 100644 --- a/coderd/externalauth.go +++ b/coderd/externalauth.go @@ -362,7 +362,6 @@ func (api *API) listUserExternalAuths(rw http.ResponseWriter, r *http.Request) { if err == nil && valid { links[i] = newLink } - break } } } diff --git a/coderd/externalauth_test.go b/coderd/externalauth_test.go index 4892ad6598..17adfac69d 100644 --- a/coderd/externalauth_test.go +++ b/coderd/externalauth_test.go @@ -18,6 +18,8 @@ import ( "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/httpapi" @@ -198,6 +200,66 @@ func TestExternalAuthManagement(t *testing.T) { require.Len(t, list.Providers, 2) require.Len(t, list.Links, 0) }) + t.Run("RefreshAllProviders", func(t *testing.T) { + t.Parallel() + const githubID = "fake-github" + const gitlabID = "fake-gitlab" + + githubCalled := false + githubApp := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error { + githubCalled = true + return nil + })) + gitlabCalled := false + gitlab := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithRefresh(func(email string) error { + gitlabCalled = true + return nil + })) + + owner, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ + ExternalAuthConfigs: []*externalauth.Config{ + githubApp.ExternalAuthConfig(t, githubID, nil, func(cfg *externalauth.Config) { + cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() + }), + gitlab.ExternalAuthConfig(t, gitlabID, nil, func(cfg *externalauth.Config) { + cfg.Type = codersdk.EnhancedExternalAuthProviderGitLab.String() + }), + }, + }) + ownerUser := coderdtest.CreateFirstUser(t, owner) + // Just a regular user + client, user := coderdtest.CreateAnotherUser(t, owner, ownerUser.OrganizationID) + ctx := testutil.Context(t, testutil.WaitLong) + + // Log into github & gitlab + githubApp.ExternalLogin(t, client) + gitlab.ExternalLogin(t, client) + + links, err := db.GetExternalAuthLinksByUserID( + dbauthz.As(ctx, coderdtest.AuthzUserSubject(user, ownerUser.OrganizationID)), user.ID) + require.NoError(t, err) + require.Len(t, links, 2) + + // Expire the links + for _, l := range links { + _, err := db.UpdateExternalAuthLink(dbauthz.As(ctx, coderdtest.AuthzUserSubject(user, ownerUser.OrganizationID)), database.UpdateExternalAuthLinkParams{ + ProviderID: l.ProviderID, + UserID: l.UserID, + UpdatedAt: dbtime.Now(), + OAuthAccessToken: l.OAuthAccessToken, + OAuthRefreshToken: l.OAuthRefreshToken, + OAuthExpiry: time.Now().Add(time.Hour * -1), + OAuthExtra: l.OAuthExtra, + }) + require.NoErrorf(t, err, "expire key for %s", l.ProviderID) + } + + list, err := client.ListExternalAuths(ctx) + require.NoError(t, err) + require.Len(t, list.Links, 2) + require.True(t, githubCalled, "github should be refreshed") + require.True(t, gitlabCalled, "gitlab should be refreshed") + }) } func TestExternalAuthDevice(t *testing.T) { diff --git a/codersdk/client.go b/codersdk/client.go index d7ca661adf..b6a1b1dc11 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -336,6 +336,10 @@ func ExpectJSONMime(res *http.Response) error { // ReadBodyAsError reads the response as a codersdk.Response, and // wraps it in a codersdk.Error type for easy marshaling. +// +// This will always return an error, so only call it if the response failed +// your expectations. Usually via status code checking. +// nolint:staticcheck func ReadBodyAsError(res *http.Response) error { if res == nil { return xerrors.Errorf("no body returned") diff --git a/codersdk/client_internal_test.go b/codersdk/client_internal_test.go index ae86ce81ef..9093c27778 100644 --- a/codersdk/client_internal_test.go +++ b/codersdk/client_internal_test.go @@ -283,6 +283,17 @@ func Test_readBodyAsError(t *testing.T) { assert.Equal(t, unexpectedJSON, sdkErr.Response.Detail) }, }, + { + // Even status code 200 should be considered an error if this function + // is called. There are parts of the code that require this function + // to always return an error. + name: "OKResp", + req: nil, + res: newResponse(http.StatusOK, jsonCT, marshal(map[string]any{})), + assert: func(t *testing.T, err error) { + require.Error(t, err) + }, + }, } for _, c := range tests { diff --git a/enterprise/coderd/proxyhealth/proxyhealth.go b/enterprise/coderd/proxyhealth/proxyhealth.go index 56a2fe4e1f..4d77f02c51 100644 --- a/enterprise/coderd/proxyhealth/proxyhealth.go +++ b/enterprise/coderd/proxyhealth/proxyhealth.go @@ -321,19 +321,17 @@ func (p *ProxyHealth) runOnce(ctx context.Context, now time.Time) (map[uuid.UUID // readable. builder.WriteString(fmt.Sprintf("unexpected status code %d. ", resp.StatusCode)) builder.WriteString(fmt.Sprintf("\nEncountered error, send a request to %q from the Coderd environment to debug this issue.", reqURL)) + // err will always be non-nil err := codersdk.ReadBodyAsError(resp) - if err != nil { - var apiErr *codersdk.Error - if xerrors.As(err, &apiErr) { - builder.WriteString(fmt.Sprintf("\nError Message: %s\nError Detail: %s", apiErr.Message, apiErr.Detail)) - for _, v := range apiErr.Validations { - // Pretty sure this is not possible from the called endpoint, but just in case. - builder.WriteString(fmt.Sprintf("\n\tValidation: %s=%s", v.Field, v.Detail)) - } - } else { - builder.WriteString(fmt.Sprintf("\nError: %s", err.Error())) + var apiErr *codersdk.Error + if xerrors.As(err, &apiErr) { + builder.WriteString(fmt.Sprintf("\nError Message: %s\nError Detail: %s", apiErr.Message, apiErr.Detail)) + for _, v := range apiErr.Validations { + // Pretty sure this is not possible from the called endpoint, but just in case. + builder.WriteString(fmt.Sprintf("\n\tValidation: %s=%s", v.Field, v.Detail)) } } + builder.WriteString(fmt.Sprintf("\nError: %s", err.Error())) status.Report.Errors = []string{builder.String()} case err != nil: