From 4d39da294e444d6304d00f3ff5104abab9244068 Mon Sep 17 00:00:00 2001 From: Asher Date: Tue, 20 Feb 2024 15:58:43 -0800 Subject: [PATCH] feat: add oauth2 token exchange (#12196) Co-authored-by: Steven Masley --- coderd/apidoc/docs.go | 169 ++++ coderd/apidoc/swagger.json | 156 +++ coderd/coderdtest/oidctest/helper.go | 52 + coderd/coderdtest/oidctest/idp.go | 37 +- coderd/database/dbauthz/dbauthz.go | 83 ++ coderd/database/dbauthz/dbauthz_test.go | 139 +++ coderd/database/dbgen/dbgen.go | 29 + coderd/database/dbmem/dbmem.go | 289 +++++- coderd/database/dbmetrics/dbmetrics.go | 70 ++ coderd/database/dbmock/dbmock.go | 147 +++ coderd/database/dump.sql | 71 +- coderd/database/foreign_key_constraint.go | 4 + .../000195_oauth2_provider_codes.down.sql | 18 + .../000195_oauth2_provider_codes.up.sql | 65 ++ .../000195_oauth2_provider_codes.up.sql | 23 + coderd/database/modelmethods.go | 16 + coderd/database/models.go | 38 +- coderd/database/querier.go | 10 + coderd/database/queries.sql.go | 298 +++++- coderd/database/queries/oauth2.sql | 87 +- coderd/database/sqlc.yaml | 4 + coderd/database/unique_constraint.go | 6 +- coderd/httpapi/queryparams.go | 41 +- coderd/httpapi/queryparams_test.go | 7 +- coderd/httpmw/oauth2.go | 43 +- coderd/insights.go | 12 +- coderd/rbac/object.go | 12 +- coderd/rbac/object_gen.go | 1 + coderd/rbac/roles.go | 2 + coderd/users.go | 2 +- coderd/workspaceapps/proxy.go | 2 +- codersdk/oauth2.go | 63 +- docs/api/enterprise.md | 127 +++ docs/api/schemas.md | 21 + enterprise/coderd/coderd.go | 22 + .../coderd/identityprovider/authorize.go | 140 +++ .../coderd/identityprovider/middleware.go | 149 +++ enterprise/coderd/identityprovider/revoke.go | 44 + enterprise/coderd/identityprovider/secrets.go | 77 ++ enterprise/coderd/identityprovider/tokens.go | 378 ++++++++ enterprise/coderd/jfrog.go | 4 +- enterprise/coderd/oauth2.go | 96 +- enterprise/coderd/oauth2_test.go | 897 ++++++++++++++++-- site/site.go | 38 + site/src/api/typesGenerated.ts | 18 + site/static/oauth2allow.html | 168 ++++ 46 files changed, 4008 insertions(+), 167 deletions(-) create mode 100644 coderd/database/migrations/000195_oauth2_provider_codes.down.sql create mode 100644 coderd/database/migrations/000195_oauth2_provider_codes.up.sql create mode 100644 coderd/database/migrations/testdata/fixtures/000195_oauth2_provider_codes.up.sql create mode 100644 enterprise/coderd/identityprovider/authorize.go create mode 100644 enterprise/coderd/identityprovider/middleware.go create mode 100644 enterprise/coderd/identityprovider/revoke.go create mode 100644 enterprise/coderd/identityprovider/secrets.go create mode 100644 enterprise/coderd/identityprovider/tokens.go create mode 100644 site/static/oauth2allow.html diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index ba528c74d6..93a644c303 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -1491,6 +1491,146 @@ const docTemplate = `{ } } }, + "/login/oauth2/authorize": { + "post": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": [ + "Enterprise" + ], + "summary": "OAuth2 authorization request.", + "operationId": "oauth2-authorization-request", + "parameters": [ + { + "type": "string", + "description": "Client ID", + "name": "client_id", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "A random unguessable string", + "name": "state", + "in": "query", + "required": true + }, + { + "enum": [ + "code" + ], + "type": "string", + "description": "Response type", + "name": "response_type", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "Redirect here after authorization", + "name": "redirect_uri", + "in": "query" + }, + { + "type": "string", + "description": "Token scopes (currently ignored)", + "name": "scope", + "in": "query" + } + ], + "responses": { + "302": { + "description": "Found" + } + } + } + }, + "/login/oauth2/tokens": { + "post": { + "produces": [ + "application/json" + ], + "tags": [ + "Enterprise" + ], + "summary": "OAuth2 token exchange.", + "operationId": "oauth2-token-exchange", + "parameters": [ + { + "type": "string", + "description": "Client ID, required if grant_type=authorization_code", + "name": "client_id", + "in": "formData" + }, + { + "type": "string", + "description": "Client secret, required if grant_type=authorization_code", + "name": "client_secret", + "in": "formData" + }, + { + "type": "string", + "description": "Authorization code, required if grant_type=authorization_code", + "name": "code", + "in": "formData" + }, + { + "type": "string", + "description": "Refresh token, required if grant_type=refresh_token", + "name": "refresh_token", + "in": "formData" + }, + { + "enum": [ + "authorization_code", + "refresh_token" + ], + "type": "string", + "description": "Grant type", + "name": "grant_type", + "in": "formData", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/oauth2.Token" + } + } + } + }, + "delete": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": [ + "Enterprise" + ], + "summary": "Delete OAuth2 application tokens.", + "operationId": "delete-oauth2-application-tokens", + "parameters": [ + { + "type": "string", + "description": "Client ID", + "name": "client_id", + "in": "query", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + } + } + } + }, "/oauth2-provider/apps": { "get": { "security": [ @@ -1506,6 +1646,14 @@ const docTemplate = `{ ], "summary": "Get OAuth2 applications.", "operationId": "get-oauth2-applications", + "parameters": [ + { + "type": "string", + "description": "Filter by applications authorized for a user", + "name": "user_id", + "in": "query" + } + ], "responses": { "200": { "description": "OK", @@ -13948,6 +14096,27 @@ const docTemplate = `{ } } }, + "oauth2.Token": { + "type": "object", + "properties": { + "access_token": { + "description": "AccessToken is the token that authorizes and authenticates\nthe requests.", + "type": "string" + }, + "expiry": { + "description": "Expiry is the optional expiration time of the access token.\n\nIf zero, TokenSource implementations will reuse the same\ntoken forever and RefreshToken or equivalent\nmechanisms for that TokenSource will not be used.", + "type": "string" + }, + "refresh_token": { + "description": "RefreshToken is a token that's used by the application\n(as opposed to the user) to refresh the access token\nif it expires.", + "type": "string" + }, + "token_type": { + "description": "TokenType is the type of token.\nThe Type method returns either this or \"Bearer\", the default.", + "type": "string" + } + } + }, "tailcfg.DERPHomeParams": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index d09381cf74..c243356f88 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -1291,6 +1291,133 @@ } } }, + "/login/oauth2/authorize": { + "post": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": ["Enterprise"], + "summary": "OAuth2 authorization request.", + "operationId": "oauth2-authorization-request", + "parameters": [ + { + "type": "string", + "description": "Client ID", + "name": "client_id", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "A random unguessable string", + "name": "state", + "in": "query", + "required": true + }, + { + "enum": ["code"], + "type": "string", + "description": "Response type", + "name": "response_type", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "Redirect here after authorization", + "name": "redirect_uri", + "in": "query" + }, + { + "type": "string", + "description": "Token scopes (currently ignored)", + "name": "scope", + "in": "query" + } + ], + "responses": { + "302": { + "description": "Found" + } + } + } + }, + "/login/oauth2/tokens": { + "post": { + "produces": ["application/json"], + "tags": ["Enterprise"], + "summary": "OAuth2 token exchange.", + "operationId": "oauth2-token-exchange", + "parameters": [ + { + "type": "string", + "description": "Client ID, required if grant_type=authorization_code", + "name": "client_id", + "in": "formData" + }, + { + "type": "string", + "description": "Client secret, required if grant_type=authorization_code", + "name": "client_secret", + "in": "formData" + }, + { + "type": "string", + "description": "Authorization code, required if grant_type=authorization_code", + "name": "code", + "in": "formData" + }, + { + "type": "string", + "description": "Refresh token, required if grant_type=refresh_token", + "name": "refresh_token", + "in": "formData" + }, + { + "enum": ["authorization_code", "refresh_token"], + "type": "string", + "description": "Grant type", + "name": "grant_type", + "in": "formData", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/oauth2.Token" + } + } + } + }, + "delete": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": ["Enterprise"], + "summary": "Delete OAuth2 application tokens.", + "operationId": "delete-oauth2-application-tokens", + "parameters": [ + { + "type": "string", + "description": "Client ID", + "name": "client_id", + "in": "query", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + } + } + } + }, "/oauth2-provider/apps": { "get": { "security": [ @@ -1302,6 +1429,14 @@ "tags": ["Enterprise"], "summary": "Get OAuth2 applications.", "operationId": "get-oauth2-applications", + "parameters": [ + { + "type": "string", + "description": "Filter by applications authorized for a user", + "name": "user_id", + "in": "query" + } + ], "responses": { "200": { "description": "OK", @@ -12716,6 +12851,27 @@ } } }, + "oauth2.Token": { + "type": "object", + "properties": { + "access_token": { + "description": "AccessToken is the token that authorizes and authenticates\nthe requests.", + "type": "string" + }, + "expiry": { + "description": "Expiry is the optional expiration time of the access token.\n\nIf zero, TokenSource implementations will reuse the same\ntoken forever and RefreshToken or equivalent\nmechanisms for that TokenSource will not be used.", + "type": "string" + }, + "refresh_token": { + "description": "RefreshToken is a token that's used by the application\n(as opposed to the user) to refresh the access token\nif it expires.", + "type": "string" + }, + "token_type": { + "description": "TokenType is the type of token.\nThe Type method returns either this or \"Bearer\", the default.", + "type": "string" + } + } + }, "tailcfg.DERPHomeParams": { "type": "object", "properties": { diff --git a/coderd/coderdtest/oidctest/helper.go b/coderd/coderdtest/oidctest/helper.go index abf29d4fa2..beb1243e2c 100644 --- a/coderd/coderdtest/oidctest/helper.go +++ b/coderd/coderdtest/oidctest/helper.go @@ -1,14 +1,17 @@ package oidctest import ( + "context" "database/sql" "encoding/json" "net/http" + "net/url" "testing" "time" "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -114,3 +117,52 @@ func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *coders _, err := user.User(testutil.Context(t, testutil.WaitShort), "me") require.NoError(t, err, "user must be able to be fetched") } + +// OAuth2GetCode emulates a user clicking "allow" on the IDP page. When doing +// unit tests, it's easier to skip this step sometimes. It does make an actual +// request to the IDP, so it should be equivalent to doing this "manually" with +// actual requests. +func OAuth2GetCode(rawAuthURL string, doRequest func(req *http.Request) (*http.Response, error)) (string, error) { + authURL, err := url.Parse(rawAuthURL) + if err != nil { + return "", xerrors.Errorf("failed to parse auth URL: %w", err) + } + + r, err := http.NewRequestWithContext(context.Background(), http.MethodGet, rawAuthURL, nil) + if err != nil { + return "", xerrors.Errorf("failed to create auth request: %w", err) + } + + expCode := http.StatusTemporaryRedirect + resp, err := doRequest(r) + if err != nil { + return "", xerrors.Errorf("request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != expCode { + return "", codersdk.ReadBodyAsError(resp) + } + + to := resp.Header.Get("Location") + if to == "" { + return "", xerrors.Errorf("expected redirect location") + } + + toURL, err := url.Parse(to) + if err != nil { + return "", xerrors.Errorf("failed to parse redirect location: %w", err) + } + + code := toURL.Query().Get("code") + if code == "" { + return "", xerrors.Errorf("expected code in redirect location") + } + + state := authURL.Query().Get("state") + newState := toURL.Query().Get("state") + if newState != state { + return "", xerrors.Errorf("expected state %q, got %q", state, newState) + } + return code, nil +} diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index cc0fe97434..a185332e87 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -534,37 +534,18 @@ func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthI // unit tests, it's easier to skip this step sometimes. It does make an actual // request to the IDP, so it should be equivalent to doing this "manually" with // actual requests. -func (f *FakeIDP) CreateAuthCode(t testing.TB, state string, opts ...func(r *http.Request)) string { +func (f *FakeIDP) CreateAuthCode(t testing.TB, state string) string { // We need to store some claims, because this is also an OIDC provider, and // it expects some claims to be present. f.stateToIDTokenClaims.Store(state, jwt.MapClaims{}) - u := f.cfg.AuthCodeURL(state) - r, err := http.NewRequestWithContext(context.Background(), http.MethodPost, u, nil) - require.NoError(t, err, "failed to create auth request") - - for _, opt := range opts { - opt(r) - } - - rw := httptest.NewRecorder() - f.handler.ServeHTTP(rw, r) - resp := rw.Result() - defer resp.Body.Close() - - require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode, "expected redirect") - to := resp.Header.Get("Location") - require.NotEmpty(t, to, "expected redirect location") - - toURL, err := url.Parse(to) - require.NoError(t, err, "failed to parse redirect location") - - code := toURL.Query().Get("code") - require.NotEmpty(t, code, "expected code in redirect location") - - newState := toURL.Query().Get("state") - require.Equal(t, state, newState, "expected state to match") - + code, err := OAuth2GetCode(f.cfg.AuthCodeURL(state), func(req *http.Request) (*http.Response, error) { + rw := httptest.NewRecorder() + f.handler.ServeHTTP(rw, req) + resp := rw.Result() + return resp, nil + }) + require.NoError(t, err, "failed to get auth code") return code } @@ -1071,7 +1052,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler { f.logger.Info(r.Context(), "http call device auth") p := httpapi.NewQueryParamParser() - p.Required("client_id") + p.RequiredNotEmpty("client_id") clientID := p.String(r.URL.Query(), "", "client_id") _ = p.String(r.URL.Query(), "", "scopes") if len(p.Errors) > 0 { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index d4cb4b15c7..135703bb0b 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -801,6 +801,25 @@ func (q *querier) DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) return q.db.DeleteOAuth2ProviderAppByID(ctx, id) } +func (q *querier) DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error { + code, err := q.db.GetOAuth2ProviderAppCodeByID(ctx, id) + if err != nil { + return err + } + if err := q.authorizeContext(ctx, rbac.ActionDelete, code); err != nil { + return err + } + return q.db.DeleteOAuth2ProviderAppCodeByID(ctx, id) +} + +func (q *querier) DeleteOAuth2ProviderAppCodesByAppAndUserID(ctx context.Context, arg database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, + rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(arg.UserID.String())); err != nil { + return err + } + return q.db.DeleteOAuth2ProviderAppCodesByAppAndUserID(ctx, arg) +} + func (q *querier) DeleteOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) error { if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceOAuth2ProviderAppSecret); err != nil { return err @@ -808,6 +827,14 @@ func (q *querier) DeleteOAuth2ProviderAppSecretByID(ctx context.Context, id uuid return q.db.DeleteOAuth2ProviderAppSecretByID(ctx, id) } +func (q *querier) DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx context.Context, arg database.DeleteOAuth2ProviderAppTokensByAppAndUserIDParams) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, + rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(arg.UserID.String())); err != nil { + return err + } + return q.db.DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx, arg) +} + func (q *querier) DeleteOldProvisionerDaemons(ctx context.Context) error { if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceSystem); err != nil { return err @@ -1175,6 +1202,14 @@ func (q *querier) GetOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) (d return q.db.GetOAuth2ProviderAppByID(ctx, id) } +func (q *querier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderAppCode, error) { + return fetch(q.log, q.auth, q.db.GetOAuth2ProviderAppCodeByID)(ctx, id) +} + +func (q *querier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (database.OAuth2ProviderAppCode, error) { + return fetch(q.log, q.auth, q.db.GetOAuth2ProviderAppCodeByPrefix)(ctx, secretPrefix) +} + func (q *querier) GetOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderAppSecret, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceOAuth2ProviderAppSecret); err != nil { return database.OAuth2ProviderAppSecret{}, err @@ -1182,6 +1217,10 @@ func (q *querier) GetOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UU return q.db.GetOAuth2ProviderAppSecretByID(ctx, id) } +func (q *querier) GetOAuth2ProviderAppSecretByPrefix(ctx context.Context, secretPrefix []byte) (database.OAuth2ProviderAppSecret, error) { + return fetch(q.log, q.auth, q.db.GetOAuth2ProviderAppSecretByPrefix)(ctx, secretPrefix) +} + func (q *querier) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, appID uuid.UUID) ([]database.OAuth2ProviderAppSecret, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceOAuth2ProviderAppSecret); err != nil { return []database.OAuth2ProviderAppSecret{}, err @@ -1189,6 +1228,22 @@ func (q *querier) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, appID return q.db.GetOAuth2ProviderAppSecretsByAppID(ctx, appID) } +func (q *querier) GetOAuth2ProviderAppTokenByPrefix(ctx context.Context, hashPrefix []byte) (database.OAuth2ProviderAppToken, error) { + token, err := q.db.GetOAuth2ProviderAppTokenByPrefix(ctx, hashPrefix) + if err != nil { + return database.OAuth2ProviderAppToken{}, err + } + // The user ID is on the API key so that has to be fetched. + key, err := q.db.GetAPIKeyByID(ctx, token.APIKeyID) + if err != nil { + return database.OAuth2ProviderAppToken{}, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(key.UserID.String())); err != nil { + return database.OAuth2ProviderAppToken{}, err + } + return token, nil +} + func (q *querier) GetOAuth2ProviderApps(ctx context.Context) ([]database.OAuth2ProviderApp, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceOAuth2ProviderApp); err != nil { return []database.OAuth2ProviderApp{}, err @@ -1196,6 +1251,15 @@ func (q *querier) GetOAuth2ProviderApps(ctx context.Context) ([]database.OAuth2P return q.db.GetOAuth2ProviderApps(ctx) } +func (q *querier) GetOAuth2ProviderAppsByUserID(ctx context.Context, userID uuid.UUID) ([]database.GetOAuth2ProviderAppsByUserIDRow, error) { + // This authz check is to make sure the caller can read all their own tokens. + if err := q.authorizeContext(ctx, rbac.ActionRead, + rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(userID.String())); err != nil { + return []database.GetOAuth2ProviderAppsByUserIDRow{}, err + } + return q.db.GetOAuth2ProviderAppsByUserID(ctx, userID) +} + func (q *querier) GetOAuthSigningKey(ctx context.Context) (string, error) { if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil { return "", err @@ -2242,6 +2306,14 @@ func (q *querier) InsertOAuth2ProviderApp(ctx context.Context, arg database.Inse return q.db.InsertOAuth2ProviderApp(ctx, arg) } +func (q *querier) InsertOAuth2ProviderAppCode(ctx context.Context, arg database.InsertOAuth2ProviderAppCodeParams) (database.OAuth2ProviderAppCode, error) { + if err := q.authorizeContext(ctx, rbac.ActionCreate, + rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(arg.UserID.String())); err != nil { + return database.OAuth2ProviderAppCode{}, err + } + return q.db.InsertOAuth2ProviderAppCode(ctx, arg) +} + func (q *querier) InsertOAuth2ProviderAppSecret(ctx context.Context, arg database.InsertOAuth2ProviderAppSecretParams) (database.OAuth2ProviderAppSecret, error) { if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceOAuth2ProviderAppSecret); err != nil { return database.OAuth2ProviderAppSecret{}, err @@ -2249,6 +2321,17 @@ func (q *querier) InsertOAuth2ProviderAppSecret(ctx context.Context, arg databas return q.db.InsertOAuth2ProviderAppSecret(ctx, arg) } +func (q *querier) InsertOAuth2ProviderAppToken(ctx context.Context, arg database.InsertOAuth2ProviderAppTokenParams) (database.OAuth2ProviderAppToken, error) { + key, err := q.db.GetAPIKeyByID(ctx, arg.APIKeyID) + if err != nil { + return database.OAuth2ProviderAppToken{}, err + } + if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(key.UserID.String())); err != nil { + return database.OAuth2ProviderAppToken{}, err + } + return q.db.InsertOAuth2ProviderAppToken(ctx, arg) +} + func (q *querier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { return insert(q.log, q.auth, rbac.ResourceOrganization, q.db.InsertOrganization)(ctx, arg) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index a483a1fe96..275eb27040 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2316,6 +2316,34 @@ func (s *MethodTestSuite) TestOAuth2ProviderApps() { app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) check.Args(app.ID).Asserts(rbac.ResourceOAuth2ProviderApp, rbac.ActionRead).Returns(app) })) + s.Run("GetOAuth2ProviderAppsByUserID", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{ + UserID: user.ID, + }) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + _ = dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + secret := dbgen.OAuth2ProviderAppSecret(s.T(), db, database.OAuth2ProviderAppSecret{ + AppID: app.ID, + }) + for i := 0; i < 5; i++ { + _ = dbgen.OAuth2ProviderAppToken(s.T(), db, database.OAuth2ProviderAppToken{ + AppSecretID: secret.ID, + APIKeyID: key.ID, + }) + } + check.Args(user.ID).Asserts(rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(user.ID.String()), rbac.ActionRead).Returns([]database.GetOAuth2ProviderAppsByUserIDRow{ + { + OAuth2ProviderApp: database.OAuth2ProviderApp{ + ID: app.ID, + CallbackURL: app.CallbackURL, + Icon: app.Icon, + Name: app.Name, + }, + TokenCount: 5, + }, + }) + })) s.Run("InsertOAuth2ProviderApp", s.Subtest(func(db database.Store, check *expects) { check.Args(database.InsertOAuth2ProviderAppParams{}).Asserts(rbac.ResourceOAuth2ProviderApp, rbac.ActionCreate) })) @@ -2361,6 +2389,13 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppSecrets() { }) check.Args(secret.ID).Asserts(rbac.ResourceOAuth2ProviderAppSecret, rbac.ActionRead).Returns(secret) })) + s.Run("GetOAuth2ProviderAppSecretByPrefix", s.Subtest(func(db database.Store, check *expects) { + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + secret := dbgen.OAuth2ProviderAppSecret(s.T(), db, database.OAuth2ProviderAppSecret{ + AppID: app.ID, + }) + check.Args(secret.SecretPrefix).Asserts(rbac.ResourceOAuth2ProviderAppSecret, rbac.ActionRead).Returns(secret) + })) s.Run("InsertOAuth2ProviderAppSecret", s.Subtest(func(db database.Store, check *expects) { app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) check.Args(database.InsertOAuth2ProviderAppSecretParams{ @@ -2386,3 +2421,107 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppSecrets() { check.Args(secret.ID).Asserts(rbac.ResourceOAuth2ProviderAppSecret, rbac.ActionDelete) })) } + +func (s *MethodTestSuite) TestOAuth2ProviderAppCodes() { + s.Run("GetOAuth2ProviderAppCodeByID", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + code := dbgen.OAuth2ProviderAppCode(s.T(), db, database.OAuth2ProviderAppCode{ + AppID: app.ID, + UserID: user.ID, + }) + check.Args(code.ID).Asserts(code, rbac.ActionRead).Returns(code) + })) + s.Run("GetOAuth2ProviderAppCodeByPrefix", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + code := dbgen.OAuth2ProviderAppCode(s.T(), db, database.OAuth2ProviderAppCode{ + AppID: app.ID, + UserID: user.ID, + }) + check.Args(code.SecretPrefix).Asserts(code, rbac.ActionRead).Returns(code) + })) + s.Run("InsertOAuth2ProviderAppCode", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + check.Args(database.InsertOAuth2ProviderAppCodeParams{ + AppID: app.ID, + UserID: user.ID, + }).Asserts(rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(user.ID.String()), rbac.ActionCreate) + })) + s.Run("DeleteOAuth2ProviderAppCodeByID", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + code := dbgen.OAuth2ProviderAppCode(s.T(), db, database.OAuth2ProviderAppCode{ + AppID: app.ID, + UserID: user.ID, + }) + check.Args(code.ID).Asserts(code, rbac.ActionDelete) + })) + s.Run("DeleteOAuth2ProviderAppCodesByAppAndUserID", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + for i := 0; i < 5; i++ { + _ = dbgen.OAuth2ProviderAppCode(s.T(), db, database.OAuth2ProviderAppCode{ + AppID: app.ID, + UserID: user.ID, + }) + } + check.Args(database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams{ + AppID: app.ID, + UserID: user.ID, + }).Asserts(rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(user.ID.String()), rbac.ActionDelete) + })) +} + +func (s *MethodTestSuite) TestOAuth2ProviderAppTokens() { + s.Run("InsertOAuth2ProviderAppToken", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{ + UserID: user.ID, + }) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + secret := dbgen.OAuth2ProviderAppSecret(s.T(), db, database.OAuth2ProviderAppSecret{ + AppID: app.ID, + }) + check.Args(database.InsertOAuth2ProviderAppTokenParams{ + AppSecretID: secret.ID, + APIKeyID: key.ID, + }).Asserts(rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(user.ID.String()), rbac.ActionCreate) + })) + s.Run("GetOAuth2ProviderAppTokenByPrefix", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{ + UserID: user.ID, + }) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + secret := dbgen.OAuth2ProviderAppSecret(s.T(), db, database.OAuth2ProviderAppSecret{ + AppID: app.ID, + }) + token := dbgen.OAuth2ProviderAppToken(s.T(), db, database.OAuth2ProviderAppToken{ + AppSecretID: secret.ID, + APIKeyID: key.ID, + }) + check.Args(token.HashPrefix).Asserts(rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(user.ID.String()), rbac.ActionRead) + })) + s.Run("DeleteOAuth2ProviderAppTokensByAppAndUserID", s.Subtest(func(db database.Store, check *expects) { + user := dbgen.User(s.T(), db, database.User{}) + key, _ := dbgen.APIKey(s.T(), db, database.APIKey{ + UserID: user.ID, + }) + app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{}) + secret := dbgen.OAuth2ProviderAppSecret(s.T(), db, database.OAuth2ProviderAppSecret{ + AppID: app.ID, + }) + for i := 0; i < 5; i++ { + _ = dbgen.OAuth2ProviderAppToken(s.T(), db, database.OAuth2ProviderAppToken{ + AppSecretID: secret.ID, + APIKeyID: key.ID, + }) + } + check.Args(database.DeleteOAuth2ProviderAppTokensByAppAndUserIDParams{ + AppID: app.ID, + UserID: user.ID, + }).Asserts(rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(user.ID.String()), rbac.ActionDelete) + })) +} diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 4ab1c1e526..c24f4cb826 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -707,6 +707,7 @@ func OAuth2ProviderAppSecret(t testing.TB, db database.Store, seed database.OAut app, err := db.InsertOAuth2ProviderAppSecret(genCtx, database.InsertOAuth2ProviderAppSecretParams{ ID: takeFirst(seed.ID, uuid.New()), CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()), + SecretPrefix: takeFirstSlice(seed.SecretPrefix, []byte("prefix")), HashedSecret: takeFirstSlice(seed.HashedSecret, []byte("hashed-secret")), DisplaySecret: takeFirst(seed.DisplaySecret, "secret"), AppID: takeFirst(seed.AppID, uuid.New()), @@ -715,6 +716,34 @@ func OAuth2ProviderAppSecret(t testing.TB, db database.Store, seed database.OAut return app } +func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2ProviderAppCode) database.OAuth2ProviderAppCode { + code, err := db.InsertOAuth2ProviderAppCode(genCtx, database.InsertOAuth2ProviderAppCodeParams{ + ID: takeFirst(seed.ID, uuid.New()), + CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()), + ExpiresAt: takeFirst(seed.CreatedAt, dbtime.Now()), + SecretPrefix: takeFirstSlice(seed.SecretPrefix, []byte("prefix")), + HashedSecret: takeFirstSlice(seed.HashedSecret, []byte("hashed-secret")), + AppID: takeFirst(seed.AppID, uuid.New()), + UserID: takeFirst(seed.UserID, uuid.New()), + }) + require.NoError(t, err, "insert oauth2 app code") + return code +} + +func OAuth2ProviderAppToken(t testing.TB, db database.Store, seed database.OAuth2ProviderAppToken) database.OAuth2ProviderAppToken { + token, err := db.InsertOAuth2ProviderAppToken(genCtx, database.InsertOAuth2ProviderAppTokenParams{ + ID: takeFirst(seed.ID, uuid.New()), + CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()), + ExpiresAt: takeFirst(seed.CreatedAt, dbtime.Now()), + HashPrefix: takeFirstSlice(seed.HashPrefix, []byte("prefix")), + RefreshHash: takeFirstSlice(seed.RefreshHash, []byte("hashed-secret")), + AppSecretID: takeFirst(seed.AppSecretID, uuid.New()), + APIKeyID: takeFirst(seed.APIKeyID, uuid.New().String()), + }) + require.NoError(t, err, "insert oauth2 app token") + return token +} + func must[V any](v V, err error) V { if err != nil { panic(err) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 5fc53981e9..638fbef175 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -1,6 +1,7 @@ package dbmem import ( + "bytes" "context" "database/sql" "encoding/json" @@ -133,6 +134,8 @@ type data struct { licenses []database.License oauth2ProviderApps []database.OAuth2ProviderApp oauth2ProviderAppSecrets []database.OAuth2ProviderAppSecret + oauth2ProviderAppCodes []database.OAuth2ProviderAppCode + oauth2ProviderAppTokens []database.OAuth2ProviderAppToken parameterSchemas []database.ParameterSchema provisionerDaemons []database.ProvisionerDaemon provisionerJobLogs []database.ProvisionerJobLog @@ -1165,19 +1168,72 @@ func (q *FakeQuerier) DeleteOAuth2ProviderAppByID(_ context.Context, id uuid.UUI q.mutex.Lock() defer q.mutex.Unlock() - for index, app := range q.oauth2ProviderApps { - if app.ID == id { - q.oauth2ProviderApps[index] = q.oauth2ProviderApps[len(q.oauth2ProviderApps)-1] - q.oauth2ProviderApps = q.oauth2ProviderApps[:len(q.oauth2ProviderApps)-1] + index := slices.IndexFunc(q.oauth2ProviderApps, func(app database.OAuth2ProviderApp) bool { + return app.ID == id + }) - secrets := []database.OAuth2ProviderAppSecret{} - for _, secret := range q.oauth2ProviderAppSecrets { - if secret.AppID != id { - secrets = append(secrets, secret) - } - } - q.oauth2ProviderAppSecrets = secrets + if index < 0 { + return sql.ErrNoRows + } + q.oauth2ProviderApps[index] = q.oauth2ProviderApps[len(q.oauth2ProviderApps)-1] + q.oauth2ProviderApps = q.oauth2ProviderApps[:len(q.oauth2ProviderApps)-1] + + // Cascade delete secrets associated with the deleted app. + var deletedSecretIDs []uuid.UUID + q.oauth2ProviderAppSecrets = slices.DeleteFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { + matches := secret.AppID == id + if matches { + deletedSecretIDs = append(deletedSecretIDs, secret.ID) + } + return matches + }) + + // Cascade delete tokens through the deleted secrets. + var keyIDsToDelete []string + q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { + matches := slice.Contains(deletedSecretIDs, token.AppSecretID) + if matches { + keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) + } + return matches + }) + + // Cascade delete API keys linked to the deleted tokens. + q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { + return slices.Contains(keyIDsToDelete, key.ID) + }) + + return nil +} + +func (q *FakeQuerier) DeleteOAuth2ProviderAppCodeByID(_ context.Context, id uuid.UUID) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, code := range q.oauth2ProviderAppCodes { + if code.ID == id { + q.oauth2ProviderAppCodes[index] = q.oauth2ProviderAppCodes[len(q.oauth2ProviderAppCodes)-1] + q.oauth2ProviderAppCodes = q.oauth2ProviderAppCodes[:len(q.oauth2ProviderAppCodes)-1] + return nil + } + } + return sql.ErrNoRows +} + +func (q *FakeQuerier) DeleteOAuth2ProviderAppCodesByAppAndUserID(_ context.Context, arg database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, code := range q.oauth2ProviderAppCodes { + if code.AppID == arg.AppID && code.UserID == arg.UserID { + q.oauth2ProviderAppCodes[index] = q.oauth2ProviderAppCodes[len(q.oauth2ProviderAppCodes)-1] + q.oauth2ProviderAppCodes = q.oauth2ProviderAppCodes[:len(q.oauth2ProviderAppCodes)-1] return nil } } @@ -1188,14 +1244,68 @@ func (q *FakeQuerier) DeleteOAuth2ProviderAppSecretByID(_ context.Context, id uu q.mutex.Lock() defer q.mutex.Unlock() - for index, secret := range q.oauth2ProviderAppSecrets { - if secret.ID == id { - q.oauth2ProviderAppSecrets[index] = q.oauth2ProviderAppSecrets[len(q.oauth2ProviderAppSecrets)-1] - q.oauth2ProviderAppSecrets = q.oauth2ProviderAppSecrets[:len(q.oauth2ProviderAppSecrets)-1] - return nil - } + index := slices.IndexFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { + return secret.ID == id + }) + + if index < 0 { + return sql.ErrNoRows } - return sql.ErrNoRows + + q.oauth2ProviderAppSecrets[index] = q.oauth2ProviderAppSecrets[len(q.oauth2ProviderAppSecrets)-1] + q.oauth2ProviderAppSecrets = q.oauth2ProviderAppSecrets[:len(q.oauth2ProviderAppSecrets)-1] + + // Cascade delete tokens created through the deleted secret. + var keyIDsToDelete []string + q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { + matches := token.AppSecretID == id + if matches { + keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) + } + return matches + }) + + // Cascade delete API keys linked to the deleted tokens. + q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { + return slices.Contains(keyIDsToDelete, key.ID) + }) + + return nil +} + +func (q *FakeQuerier) DeleteOAuth2ProviderAppTokensByAppAndUserID(_ context.Context, arg database.DeleteOAuth2ProviderAppTokensByAppAndUserIDParams) error { + err := validateDatabaseType(arg) + if err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + var keyIDsToDelete []string + q.oauth2ProviderAppTokens = slices.DeleteFunc(q.oauth2ProviderAppTokens, func(token database.OAuth2ProviderAppToken) bool { + // Join secrets and keys to see if the token matches. + secretIdx := slices.IndexFunc(q.oauth2ProviderAppSecrets, func(secret database.OAuth2ProviderAppSecret) bool { + return secret.ID == token.AppSecretID + }) + keyIdx := slices.IndexFunc(q.apiKeys, func(key database.APIKey) bool { + return key.ID == token.APIKeyID + }) + matches := secretIdx != -1 && + q.oauth2ProviderAppSecrets[secretIdx].AppID == arg.AppID && + keyIdx != -1 && q.apiKeys[keyIdx].UserID == arg.UserID + if matches { + keyIDsToDelete = append(keyIDsToDelete, token.APIKeyID) + } + return matches + }) + + // Cascade delete API keys linked to the deleted tokens. + q.apiKeys = slices.DeleteFunc(q.apiKeys, func(key database.APIKey) bool { + return slices.Contains(keyIDsToDelete, key.ID) + }) + + return nil } func (q *FakeQuerier) DeleteOldProvisionerDaemons(_ context.Context) error { @@ -2138,6 +2248,30 @@ func (q *FakeQuerier) GetOAuth2ProviderAppByID(_ context.Context, id uuid.UUID) return database.OAuth2ProviderApp{}, sql.ErrNoRows } +func (q *FakeQuerier) GetOAuth2ProviderAppCodeByID(_ context.Context, id uuid.UUID) (database.OAuth2ProviderAppCode, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, code := range q.oauth2ProviderAppCodes { + if code.ID == id { + return code, nil + } + } + return database.OAuth2ProviderAppCode{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetOAuth2ProviderAppCodeByPrefix(_ context.Context, secretPrefix []byte) (database.OAuth2ProviderAppCode, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, code := range q.oauth2ProviderAppCodes { + if bytes.Equal(code.SecretPrefix, secretPrefix) { + return code, nil + } + } + return database.OAuth2ProviderAppCode{}, sql.ErrNoRows +} + func (q *FakeQuerier) GetOAuth2ProviderAppSecretByID(_ context.Context, id uuid.UUID) (database.OAuth2ProviderAppSecret, error) { q.mutex.Lock() defer q.mutex.Unlock() @@ -2150,6 +2284,18 @@ func (q *FakeQuerier) GetOAuth2ProviderAppSecretByID(_ context.Context, id uuid. return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows } +func (q *FakeQuerier) GetOAuth2ProviderAppSecretByPrefix(_ context.Context, secretPrefix []byte) (database.OAuth2ProviderAppSecret, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, secret := range q.oauth2ProviderAppSecrets { + if bytes.Equal(secret.SecretPrefix, secretPrefix) { + return secret, nil + } + } + return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows +} + func (q *FakeQuerier) GetOAuth2ProviderAppSecretsByAppID(_ context.Context, appID uuid.UUID) ([]database.OAuth2ProviderAppSecret, error) { q.mutex.Lock() defer q.mutex.Unlock() @@ -2178,6 +2324,18 @@ func (q *FakeQuerier) GetOAuth2ProviderAppSecretsByAppID(_ context.Context, appI return []database.OAuth2ProviderAppSecret{}, sql.ErrNoRows } +func (q *FakeQuerier) GetOAuth2ProviderAppTokenByPrefix(_ context.Context, hashPrefix []byte) (database.OAuth2ProviderAppToken, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, token := range q.oauth2ProviderAppTokens { + if bytes.Equal(token.HashPrefix, hashPrefix) { + return token, nil + } + } + return database.OAuth2ProviderAppToken{}, sql.ErrNoRows +} + func (q *FakeQuerier) GetOAuth2ProviderApps(_ context.Context) ([]database.OAuth2ProviderApp, error) { q.mutex.Lock() defer q.mutex.Unlock() @@ -2188,6 +2346,42 @@ func (q *FakeQuerier) GetOAuth2ProviderApps(_ context.Context) ([]database.OAuth return q.oauth2ProviderApps, nil } +func (q *FakeQuerier) GetOAuth2ProviderAppsByUserID(_ context.Context, userID uuid.UUID) ([]database.GetOAuth2ProviderAppsByUserIDRow, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + rows := []database.GetOAuth2ProviderAppsByUserIDRow{} + for _, app := range q.oauth2ProviderApps { + tokens := []database.OAuth2ProviderAppToken{} + for _, secret := range q.oauth2ProviderAppSecrets { + if secret.AppID == app.ID { + for _, token := range q.oauth2ProviderAppTokens { + if token.AppSecretID == secret.ID { + keyIdx := slices.IndexFunc(q.apiKeys, func(key database.APIKey) bool { + return key.ID == token.APIKeyID + }) + if keyIdx != -1 && q.apiKeys[keyIdx].UserID == userID { + tokens = append(tokens, token) + } + } + } + } + } + if len(tokens) > 0 { + rows = append(rows, database.GetOAuth2ProviderAppsByUserIDRow{ + OAuth2ProviderApp: database.OAuth2ProviderApp{ + CallbackURL: app.CallbackURL, + ID: app.ID, + Icon: app.Icon, + Name: app.Name, + }, + TokenCount: int64(len(tokens)), + }) + } + } + return rows, nil +} + func (q *FakeQuerier) GetOAuthSigningKey(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -5240,6 +5434,34 @@ func (q *FakeQuerier) InsertOAuth2ProviderApp(_ context.Context, arg database.In return app, nil } +func (q *FakeQuerier) InsertOAuth2ProviderAppCode(_ context.Context, arg database.InsertOAuth2ProviderAppCodeParams) (database.OAuth2ProviderAppCode, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.OAuth2ProviderAppCode{}, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, app := range q.oauth2ProviderApps { + if app.ID == arg.AppID { + code := database.OAuth2ProviderAppCode{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + ExpiresAt: arg.ExpiresAt, + SecretPrefix: arg.SecretPrefix, + HashedSecret: arg.HashedSecret, + UserID: arg.UserID, + AppID: arg.AppID, + } + q.oauth2ProviderAppCodes = append(q.oauth2ProviderAppCodes, code) + return code, nil + } + } + + return database.OAuth2ProviderAppCode{}, sql.ErrNoRows +} + func (q *FakeQuerier) InsertOAuth2ProviderAppSecret(_ context.Context, arg database.InsertOAuth2ProviderAppSecretParams) (database.OAuth2ProviderAppSecret, error) { err := validateDatabaseType(arg) if err != nil { @@ -5254,6 +5476,7 @@ func (q *FakeQuerier) InsertOAuth2ProviderAppSecret(_ context.Context, arg datab secret := database.OAuth2ProviderAppSecret{ ID: arg.ID, CreatedAt: arg.CreatedAt, + SecretPrefix: arg.SecretPrefix, HashedSecret: arg.HashedSecret, DisplaySecret: arg.DisplaySecret, AppID: arg.AppID, @@ -5266,6 +5489,35 @@ func (q *FakeQuerier) InsertOAuth2ProviderAppSecret(_ context.Context, arg datab return database.OAuth2ProviderAppSecret{}, sql.ErrNoRows } +func (q *FakeQuerier) InsertOAuth2ProviderAppToken(_ context.Context, arg database.InsertOAuth2ProviderAppTokenParams) (database.OAuth2ProviderAppToken, error) { + err := validateDatabaseType(arg) + if err != nil { + return database.OAuth2ProviderAppToken{}, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, secret := range q.oauth2ProviderAppSecrets { + if secret.ID == arg.AppSecretID { + //nolint:gosimple // Go wants database.OAuth2ProviderAppToken(arg), but we cannot be sure the structs will remain identical. + token := database.OAuth2ProviderAppToken{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + ExpiresAt: arg.ExpiresAt, + HashPrefix: arg.HashPrefix, + RefreshHash: arg.RefreshHash, + APIKeyID: arg.APIKeyID, + AppSecretID: arg.AppSecretID, + } + q.oauth2ProviderAppTokens = append(q.oauth2ProviderAppTokens, token) + return token, nil + } + } + + return database.OAuth2ProviderAppToken{}, sql.ErrNoRows +} + func (q *FakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { if err := validateDatabaseType(arg); err != nil { return database.Organization{}, err @@ -6372,6 +6624,7 @@ func (q *FakeQuerier) UpdateOAuth2ProviderAppSecretByID(_ context.Context, arg d newSecret := database.OAuth2ProviderAppSecret{ ID: arg.ID, CreatedAt: secret.CreatedAt, + SecretPrefix: secret.SecretPrefix, HashedSecret: secret.HashedSecret, DisplaySecret: secret.DisplaySecret, AppID: secret.AppID, diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index e99ae24d9a..efbf018480 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -225,6 +225,20 @@ func (m metricsStore) DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.U return r0 } +func (m metricsStore) DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error { + start := time.Now() + r0 := m.s.DeleteOAuth2ProviderAppCodeByID(ctx, id) + m.queryLatencies.WithLabelValues("DeleteOAuth2ProviderAppCodeByID").Observe(time.Since(start).Seconds()) + return r0 +} + +func (m metricsStore) DeleteOAuth2ProviderAppCodesByAppAndUserID(ctx context.Context, arg database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams) error { + start := time.Now() + r0 := m.s.DeleteOAuth2ProviderAppCodesByAppAndUserID(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteOAuth2ProviderAppCodesByAppAndUserID").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) DeleteOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) error { start := time.Now() r0 := m.s.DeleteOAuth2ProviderAppSecretByID(ctx, id) @@ -232,6 +246,13 @@ func (m metricsStore) DeleteOAuth2ProviderAppSecretByID(ctx context.Context, id return r0 } +func (m metricsStore) DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx context.Context, arg database.DeleteOAuth2ProviderAppTokensByAppAndUserIDParams) error { + start := time.Now() + r0 := m.s.DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteOAuth2ProviderAppTokensByAppAndUserID").Observe(time.Since(start).Seconds()) + return r0 +} + func (m metricsStore) DeleteOldProvisionerDaemons(ctx context.Context) error { start := time.Now() r0 := m.s.DeleteOldProvisionerDaemons(ctx) @@ -615,6 +636,20 @@ func (m metricsStore) GetOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID return r0, r1 } +func (m metricsStore) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderAppCode, error) { + start := time.Now() + r0, r1 := m.s.GetOAuth2ProviderAppCodeByID(ctx, id) + m.queryLatencies.WithLabelValues("GetOAuth2ProviderAppCodeByID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m metricsStore) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (database.OAuth2ProviderAppCode, error) { + start := time.Now() + r0, r1 := m.s.GetOAuth2ProviderAppCodeByPrefix(ctx, secretPrefix) + m.queryLatencies.WithLabelValues("GetOAuth2ProviderAppCodeByPrefix").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderAppSecret, error) { start := time.Now() r0, r1 := m.s.GetOAuth2ProviderAppSecretByID(ctx, id) @@ -622,6 +657,13 @@ func (m metricsStore) GetOAuth2ProviderAppSecretByID(ctx context.Context, id uui return r0, r1 } +func (m metricsStore) GetOAuth2ProviderAppSecretByPrefix(ctx context.Context, secretPrefix []byte) (database.OAuth2ProviderAppSecret, error) { + start := time.Now() + r0, r1 := m.s.GetOAuth2ProviderAppSecretByPrefix(ctx, secretPrefix) + m.queryLatencies.WithLabelValues("GetOAuth2ProviderAppSecretByPrefix").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, appID uuid.UUID) ([]database.OAuth2ProviderAppSecret, error) { start := time.Now() r0, r1 := m.s.GetOAuth2ProviderAppSecretsByAppID(ctx, appID) @@ -629,6 +671,13 @@ func (m metricsStore) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, ap return r0, r1 } +func (m metricsStore) GetOAuth2ProviderAppTokenByPrefix(ctx context.Context, hashPrefix []byte) (database.OAuth2ProviderAppToken, error) { + start := time.Now() + r0, r1 := m.s.GetOAuth2ProviderAppTokenByPrefix(ctx, hashPrefix) + m.queryLatencies.WithLabelValues("GetOAuth2ProviderAppTokenByPrefix").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetOAuth2ProviderApps(ctx context.Context) ([]database.OAuth2ProviderApp, error) { start := time.Now() r0, r1 := m.s.GetOAuth2ProviderApps(ctx) @@ -636,6 +685,13 @@ func (m metricsStore) GetOAuth2ProviderApps(ctx context.Context) ([]database.OAu return r0, r1 } +func (m metricsStore) GetOAuth2ProviderAppsByUserID(ctx context.Context, userID uuid.UUID) ([]database.GetOAuth2ProviderAppsByUserIDRow, error) { + start := time.Now() + r0, r1 := m.s.GetOAuth2ProviderAppsByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetOAuth2ProviderAppsByUserID").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetOAuthSigningKey(ctx context.Context) (string, error) { start := time.Now() r0, r1 := m.s.GetOAuthSigningKey(ctx) @@ -1425,6 +1481,13 @@ func (m metricsStore) InsertOAuth2ProviderApp(ctx context.Context, arg database. return r0, r1 } +func (m metricsStore) InsertOAuth2ProviderAppCode(ctx context.Context, arg database.InsertOAuth2ProviderAppCodeParams) (database.OAuth2ProviderAppCode, error) { + start := time.Now() + r0, r1 := m.s.InsertOAuth2ProviderAppCode(ctx, arg) + m.queryLatencies.WithLabelValues("InsertOAuth2ProviderAppCode").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) InsertOAuth2ProviderAppSecret(ctx context.Context, arg database.InsertOAuth2ProviderAppSecretParams) (database.OAuth2ProviderAppSecret, error) { start := time.Now() r0, r1 := m.s.InsertOAuth2ProviderAppSecret(ctx, arg) @@ -1432,6 +1495,13 @@ func (m metricsStore) InsertOAuth2ProviderAppSecret(ctx context.Context, arg dat return r0, r1 } +func (m metricsStore) InsertOAuth2ProviderAppToken(ctx context.Context, arg database.InsertOAuth2ProviderAppTokenParams) (database.OAuth2ProviderAppToken, error) { + start := time.Now() + r0, r1 := m.s.InsertOAuth2ProviderAppToken(ctx, arg) + m.queryLatencies.WithLabelValues("InsertOAuth2ProviderAppToken").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { start := time.Now() organization, err := m.s.InsertOrganization(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 062df2b886..5bcca0b1a2 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -342,6 +342,34 @@ func (mr *MockStoreMockRecorder) DeleteOAuth2ProviderAppByID(arg0, arg1 any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOAuth2ProviderAppByID", reflect.TypeOf((*MockStore)(nil).DeleteOAuth2ProviderAppByID), arg0, arg1) } +// DeleteOAuth2ProviderAppCodeByID mocks base method. +func (m *MockStore) DeleteOAuth2ProviderAppCodeByID(arg0 context.Context, arg1 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteOAuth2ProviderAppCodeByID", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteOAuth2ProviderAppCodeByID indicates an expected call of DeleteOAuth2ProviderAppCodeByID. +func (mr *MockStoreMockRecorder) DeleteOAuth2ProviderAppCodeByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOAuth2ProviderAppCodeByID", reflect.TypeOf((*MockStore)(nil).DeleteOAuth2ProviderAppCodeByID), arg0, arg1) +} + +// DeleteOAuth2ProviderAppCodesByAppAndUserID mocks base method. +func (m *MockStore) DeleteOAuth2ProviderAppCodesByAppAndUserID(arg0 context.Context, arg1 database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteOAuth2ProviderAppCodesByAppAndUserID", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteOAuth2ProviderAppCodesByAppAndUserID indicates an expected call of DeleteOAuth2ProviderAppCodesByAppAndUserID. +func (mr *MockStoreMockRecorder) DeleteOAuth2ProviderAppCodesByAppAndUserID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOAuth2ProviderAppCodesByAppAndUserID", reflect.TypeOf((*MockStore)(nil).DeleteOAuth2ProviderAppCodesByAppAndUserID), arg0, arg1) +} + // DeleteOAuth2ProviderAppSecretByID mocks base method. func (m *MockStore) DeleteOAuth2ProviderAppSecretByID(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() @@ -356,6 +384,20 @@ func (mr *MockStoreMockRecorder) DeleteOAuth2ProviderAppSecretByID(arg0, arg1 an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOAuth2ProviderAppSecretByID", reflect.TypeOf((*MockStore)(nil).DeleteOAuth2ProviderAppSecretByID), arg0, arg1) } +// DeleteOAuth2ProviderAppTokensByAppAndUserID mocks base method. +func (m *MockStore) DeleteOAuth2ProviderAppTokensByAppAndUserID(arg0 context.Context, arg1 database.DeleteOAuth2ProviderAppTokensByAppAndUserIDParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteOAuth2ProviderAppTokensByAppAndUserID", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteOAuth2ProviderAppTokensByAppAndUserID indicates an expected call of DeleteOAuth2ProviderAppTokensByAppAndUserID. +func (mr *MockStoreMockRecorder) DeleteOAuth2ProviderAppTokensByAppAndUserID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOAuth2ProviderAppTokensByAppAndUserID", reflect.TypeOf((*MockStore)(nil).DeleteOAuth2ProviderAppTokensByAppAndUserID), arg0, arg1) +} + // DeleteOldProvisionerDaemons mocks base method. func (m *MockStore) DeleteOldProvisionerDaemons(arg0 context.Context) error { m.ctrl.T.Helper() @@ -1219,6 +1261,36 @@ func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppByID(arg0, arg1 any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppByID", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppByID), arg0, arg1) } +// GetOAuth2ProviderAppCodeByID mocks base method. +func (m *MockStore) GetOAuth2ProviderAppCodeByID(arg0 context.Context, arg1 uuid.UUID) (database.OAuth2ProviderAppCode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOAuth2ProviderAppCodeByID", arg0, arg1) + ret0, _ := ret[0].(database.OAuth2ProviderAppCode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOAuth2ProviderAppCodeByID indicates an expected call of GetOAuth2ProviderAppCodeByID. +func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppCodeByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppCodeByID", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppCodeByID), arg0, arg1) +} + +// GetOAuth2ProviderAppCodeByPrefix mocks base method. +func (m *MockStore) GetOAuth2ProviderAppCodeByPrefix(arg0 context.Context, arg1 []byte) (database.OAuth2ProviderAppCode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOAuth2ProviderAppCodeByPrefix", arg0, arg1) + ret0, _ := ret[0].(database.OAuth2ProviderAppCode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOAuth2ProviderAppCodeByPrefix indicates an expected call of GetOAuth2ProviderAppCodeByPrefix. +func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppCodeByPrefix(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppCodeByPrefix", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppCodeByPrefix), arg0, arg1) +} + // GetOAuth2ProviderAppSecretByID mocks base method. func (m *MockStore) GetOAuth2ProviderAppSecretByID(arg0 context.Context, arg1 uuid.UUID) (database.OAuth2ProviderAppSecret, error) { m.ctrl.T.Helper() @@ -1234,6 +1306,21 @@ func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppSecretByID(arg0, arg1 any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppSecretByID", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppSecretByID), arg0, arg1) } +// GetOAuth2ProviderAppSecretByPrefix mocks base method. +func (m *MockStore) GetOAuth2ProviderAppSecretByPrefix(arg0 context.Context, arg1 []byte) (database.OAuth2ProviderAppSecret, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOAuth2ProviderAppSecretByPrefix", arg0, arg1) + ret0, _ := ret[0].(database.OAuth2ProviderAppSecret) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOAuth2ProviderAppSecretByPrefix indicates an expected call of GetOAuth2ProviderAppSecretByPrefix. +func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppSecretByPrefix(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppSecretByPrefix", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppSecretByPrefix), arg0, arg1) +} + // GetOAuth2ProviderAppSecretsByAppID mocks base method. func (m *MockStore) GetOAuth2ProviderAppSecretsByAppID(arg0 context.Context, arg1 uuid.UUID) ([]database.OAuth2ProviderAppSecret, error) { m.ctrl.T.Helper() @@ -1249,6 +1336,21 @@ func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppSecretsByAppID(arg0, arg1 a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppSecretsByAppID", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppSecretsByAppID), arg0, arg1) } +// GetOAuth2ProviderAppTokenByPrefix mocks base method. +func (m *MockStore) GetOAuth2ProviderAppTokenByPrefix(arg0 context.Context, arg1 []byte) (database.OAuth2ProviderAppToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOAuth2ProviderAppTokenByPrefix", arg0, arg1) + ret0, _ := ret[0].(database.OAuth2ProviderAppToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOAuth2ProviderAppTokenByPrefix indicates an expected call of GetOAuth2ProviderAppTokenByPrefix. +func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppTokenByPrefix(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppTokenByPrefix", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppTokenByPrefix), arg0, arg1) +} + // GetOAuth2ProviderApps mocks base method. func (m *MockStore) GetOAuth2ProviderApps(arg0 context.Context) ([]database.OAuth2ProviderApp, error) { m.ctrl.T.Helper() @@ -1264,6 +1366,21 @@ func (mr *MockStoreMockRecorder) GetOAuth2ProviderApps(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderApps", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderApps), arg0) } +// GetOAuth2ProviderAppsByUserID mocks base method. +func (m *MockStore) GetOAuth2ProviderAppsByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.GetOAuth2ProviderAppsByUserIDRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOAuth2ProviderAppsByUserID", arg0, arg1) + ret0, _ := ret[0].([]database.GetOAuth2ProviderAppsByUserIDRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOAuth2ProviderAppsByUserID indicates an expected call of GetOAuth2ProviderAppsByUserID. +func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppsByUserID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppsByUserID", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppsByUserID), arg0, arg1) +} + // GetOAuthSigningKey mocks base method. func (m *MockStore) GetOAuthSigningKey(arg0 context.Context) (string, error) { m.ctrl.T.Helper() @@ -2999,6 +3116,21 @@ func (mr *MockStoreMockRecorder) InsertOAuth2ProviderApp(arg0, arg1 any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOAuth2ProviderApp", reflect.TypeOf((*MockStore)(nil).InsertOAuth2ProviderApp), arg0, arg1) } +// InsertOAuth2ProviderAppCode mocks base method. +func (m *MockStore) InsertOAuth2ProviderAppCode(arg0 context.Context, arg1 database.InsertOAuth2ProviderAppCodeParams) (database.OAuth2ProviderAppCode, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertOAuth2ProviderAppCode", arg0, arg1) + ret0, _ := ret[0].(database.OAuth2ProviderAppCode) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertOAuth2ProviderAppCode indicates an expected call of InsertOAuth2ProviderAppCode. +func (mr *MockStoreMockRecorder) InsertOAuth2ProviderAppCode(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOAuth2ProviderAppCode", reflect.TypeOf((*MockStore)(nil).InsertOAuth2ProviderAppCode), arg0, arg1) +} + // InsertOAuth2ProviderAppSecret mocks base method. func (m *MockStore) InsertOAuth2ProviderAppSecret(arg0 context.Context, arg1 database.InsertOAuth2ProviderAppSecretParams) (database.OAuth2ProviderAppSecret, error) { m.ctrl.T.Helper() @@ -3014,6 +3146,21 @@ func (mr *MockStoreMockRecorder) InsertOAuth2ProviderAppSecret(arg0, arg1 any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOAuth2ProviderAppSecret", reflect.TypeOf((*MockStore)(nil).InsertOAuth2ProviderAppSecret), arg0, arg1) } +// InsertOAuth2ProviderAppToken mocks base method. +func (m *MockStore) InsertOAuth2ProviderAppToken(arg0 context.Context, arg1 database.InsertOAuth2ProviderAppTokenParams) (database.OAuth2ProviderAppToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertOAuth2ProviderAppToken", arg0, arg1) + ret0, _ := ret[0].(database.OAuth2ProviderAppToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertOAuth2ProviderAppToken indicates an expected call of InsertOAuth2ProviderAppToken. +func (mr *MockStoreMockRecorder) InsertOAuth2ProviderAppToken(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOAuth2ProviderAppToken", reflect.TypeOf((*MockStore)(nil).InsertOAuth2ProviderAppToken), arg0, arg1) +} + // InsertOrganization mocks base method. func (m *MockStore) InsertOrganization(arg0 context.Context, arg1 database.InsertOrganizationParams) (database.Organization, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index d7c699dd93..1818a03b8b 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -67,7 +67,8 @@ CREATE TYPE login_type AS ENUM ( 'github', 'oidc', 'token', - 'none' + 'none', + 'oauth2_provider_app' ); COMMENT ON TYPE login_type IS 'Specifies the method of authentication. "none" is a special case in which no authentication method is allowed.'; @@ -187,6 +188,17 @@ CREATE TYPE workspace_transition AS ENUM ( 'delete' ); +CREATE FUNCTION delete_deleted_oauth2_provider_app_token_api_key() RETURNS trigger + LANGUAGE plpgsql + AS $$ +DECLARE +BEGIN + DELETE FROM api_keys + WHERE id = OLD.api_key_id; + RETURN OLD; +END; +$$; + CREATE FUNCTION delete_deleted_user_resources() RETURNS trigger LANGUAGE plpgsql AS $$ @@ -490,17 +502,42 @@ CREATE SEQUENCE licenses_id_seq ALTER SEQUENCE licenses_id_seq OWNED BY licenses.id; +CREATE TABLE oauth2_provider_app_codes ( + id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + expires_at timestamp with time zone NOT NULL, + secret_prefix bytea NOT NULL, + hashed_secret bytea NOT NULL, + user_id uuid NOT NULL, + app_id uuid NOT NULL +); + +COMMENT ON TABLE oauth2_provider_app_codes IS 'Codes are meant to be exchanged for access tokens.'; + CREATE TABLE oauth2_provider_app_secrets ( id uuid NOT NULL, created_at timestamp with time zone NOT NULL, last_used_at timestamp with time zone, hashed_secret bytea NOT NULL, display_secret text NOT NULL, - app_id uuid NOT NULL + app_id uuid NOT NULL, + secret_prefix bytea NOT NULL ); COMMENT ON COLUMN oauth2_provider_app_secrets.display_secret IS 'The tail end of the original secret so secrets can be differentiated.'; +CREATE TABLE oauth2_provider_app_tokens ( + id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + expires_at timestamp with time zone NOT NULL, + hash_prefix bytea NOT NULL, + refresh_hash bytea NOT NULL, + app_secret_id uuid NOT NULL, + api_key_id text NOT NULL +); + +COMMENT ON COLUMN oauth2_provider_app_tokens.refresh_hash IS 'Refresh tokens provide a way to refresh an access token (API key). An expired API key can be refreshed if this token is not yet expired, meaning this expiry can outlive an API key.'; + CREATE TABLE oauth2_provider_apps ( id uuid NOT NULL, created_at timestamp with time zone NOT NULL, @@ -1354,12 +1391,24 @@ ALTER TABLE ONLY licenses ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id); -ALTER TABLE ONLY oauth2_provider_app_secrets - ADD CONSTRAINT oauth2_provider_app_secrets_app_id_hashed_secret_key UNIQUE (app_id, hashed_secret); +ALTER TABLE ONLY oauth2_provider_app_codes + ADD CONSTRAINT oauth2_provider_app_codes_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY oauth2_provider_app_codes + ADD CONSTRAINT oauth2_provider_app_codes_secret_prefix_key UNIQUE (secret_prefix); ALTER TABLE ONLY oauth2_provider_app_secrets ADD CONSTRAINT oauth2_provider_app_secrets_pkey PRIMARY KEY (id); +ALTER TABLE ONLY oauth2_provider_app_secrets + ADD CONSTRAINT oauth2_provider_app_secrets_secret_prefix_key UNIQUE (secret_prefix); + +ALTER TABLE ONLY oauth2_provider_app_tokens + ADD CONSTRAINT oauth2_provider_app_tokens_hash_prefix_key UNIQUE (hash_prefix); + +ALTER TABLE ONLY oauth2_provider_app_tokens + ADD CONSTRAINT oauth2_provider_app_tokens_pkey PRIMARY KEY (id); + ALTER TABLE ONLY oauth2_provider_apps ADD CONSTRAINT oauth2_provider_apps_name_key UNIQUE (name); @@ -1572,6 +1621,8 @@ CREATE TRIGGER tailnet_notify_peer_change AFTER INSERT OR DELETE OR UPDATE ON ta CREATE TRIGGER tailnet_notify_tunnel_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_tunnels FOR EACH ROW EXECUTE FUNCTION tailnet_notify_tunnel_change(); +CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_provider_app_tokens FOR EACH ROW EXECUTE FUNCTION delete_deleted_oauth2_provider_app_token_api_key(); + CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted(); CREATE TRIGGER trigger_update_users AFTER INSERT OR UPDATE ON users FOR EACH ROW WHEN ((new.deleted = true)) EXECUTE FUNCTION delete_deleted_user_resources(); @@ -1605,9 +1656,21 @@ ALTER TABLE ONLY jfrog_xray_scans ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE; +ALTER TABLE ONLY oauth2_provider_app_codes + ADD CONSTRAINT oauth2_provider_app_codes_app_id_fkey FOREIGN KEY (app_id) REFERENCES oauth2_provider_apps(id) ON DELETE CASCADE; + +ALTER TABLE ONLY oauth2_provider_app_codes + ADD CONSTRAINT oauth2_provider_app_codes_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY oauth2_provider_app_secrets ADD CONSTRAINT oauth2_provider_app_secrets_app_id_fkey FOREIGN KEY (app_id) REFERENCES oauth2_provider_apps(id) ON DELETE CASCADE; +ALTER TABLE ONLY oauth2_provider_app_tokens + ADD CONSTRAINT oauth2_provider_app_tokens_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE CASCADE; + +ALTER TABLE ONLY oauth2_provider_app_tokens + ADD CONSTRAINT oauth2_provider_app_tokens_app_secret_id_fkey FOREIGN KEY (app_secret_id) REFERENCES oauth2_provider_app_secrets(id) ON DELETE CASCADE; + ALTER TABLE ONLY organization_members ADD CONSTRAINT organization_members_organization_id_uuid_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 8428d48968..ad9ef76cbb 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -15,7 +15,11 @@ const ( ForeignKeyGroupsOrganizationID ForeignKeyConstraint = "groups_organization_id_fkey" // ALTER TABLE ONLY groups ADD CONSTRAINT groups_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; ForeignKeyJfrogXrayScansAgentID ForeignKeyConstraint = "jfrog_xray_scans_agent_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE; ForeignKeyJfrogXrayScansWorkspaceID ForeignKeyConstraint = "jfrog_xray_scans_workspace_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE; + ForeignKeyOauth2ProviderAppCodesAppID ForeignKeyConstraint = "oauth2_provider_app_codes_app_id_fkey" // ALTER TABLE ONLY oauth2_provider_app_codes ADD CONSTRAINT oauth2_provider_app_codes_app_id_fkey FOREIGN KEY (app_id) REFERENCES oauth2_provider_apps(id) ON DELETE CASCADE; + ForeignKeyOauth2ProviderAppCodesUserID ForeignKeyConstraint = "oauth2_provider_app_codes_user_id_fkey" // ALTER TABLE ONLY oauth2_provider_app_codes ADD CONSTRAINT oauth2_provider_app_codes_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyOauth2ProviderAppSecretsAppID ForeignKeyConstraint = "oauth2_provider_app_secrets_app_id_fkey" // ALTER TABLE ONLY oauth2_provider_app_secrets ADD CONSTRAINT oauth2_provider_app_secrets_app_id_fkey FOREIGN KEY (app_id) REFERENCES oauth2_provider_apps(id) ON DELETE CASCADE; + ForeignKeyOauth2ProviderAppTokensAPIKeyID ForeignKeyConstraint = "oauth2_provider_app_tokens_api_key_id_fkey" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT oauth2_provider_app_tokens_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE CASCADE; + ForeignKeyOauth2ProviderAppTokensAppSecretID ForeignKeyConstraint = "oauth2_provider_app_tokens_app_secret_id_fkey" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT oauth2_provider_app_tokens_app_secret_id_fkey FOREIGN KEY (app_secret_id) REFERENCES oauth2_provider_app_secrets(id) ON DELETE CASCADE; ForeignKeyOrganizationMembersOrganizationIDUUID ForeignKeyConstraint = "organization_members_organization_id_uuid_fkey" // ALTER TABLE ONLY organization_members ADD CONSTRAINT organization_members_organization_id_uuid_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; ForeignKeyOrganizationMembersUserIDUUID ForeignKeyConstraint = "organization_members_user_id_uuid_fkey" // ALTER TABLE ONLY organization_members ADD CONSTRAINT organization_members_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyParameterSchemasJobID ForeignKeyConstraint = "parameter_schemas_job_id_fkey" // ALTER TABLE ONLY parameter_schemas ADD CONSTRAINT parameter_schemas_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000195_oauth2_provider_codes.down.sql b/coderd/database/migrations/000195_oauth2_provider_codes.down.sql new file mode 100644 index 0000000000..320e088a95 --- /dev/null +++ b/coderd/database/migrations/000195_oauth2_provider_codes.down.sql @@ -0,0 +1,18 @@ +DROP TRIGGER IF EXISTS trigger_delete_oauth2_provider_app_token ON oauth2_provider_app_tokens; +DROP FUNCTION IF EXISTS delete_deleted_oauth2_provider_app_token_api_key; + +DROP TABLE oauth2_provider_app_tokens; +DROP TABLE oauth2_provider_app_codes; + +-- It is not possible to drop enum values from enum types, so the UP on +-- login_type has "IF NOT EXISTS". + +-- The constraints on the secret prefix (which is used as an id embedded in the +-- secret) are dropped, but avoid completely reverting back to the previous +-- behavior since that will render existing secrets unusable once upgraded +-- again. OAuth2 is blocked outside of development mode in previous versions, +-- so users will not be able to create broken secrets. This is really just to +-- make sure tests keep working (say for a bisect). +ALTER TABLE ONLY oauth2_provider_app_secrets + DROP CONSTRAINT oauth2_provider_app_secrets_secret_prefix_key, + ALTER COLUMN secret_prefix DROP NOT NULL; diff --git a/coderd/database/migrations/000195_oauth2_provider_codes.up.sql b/coderd/database/migrations/000195_oauth2_provider_codes.up.sql new file mode 100644 index 0000000000..d21d947d07 --- /dev/null +++ b/coderd/database/migrations/000195_oauth2_provider_codes.up.sql @@ -0,0 +1,65 @@ +CREATE TABLE oauth2_provider_app_codes ( + id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + expires_at timestamp with time zone NOT NULL, + secret_prefix bytea NOT NULL, + hashed_secret bytea NOT NULL, + user_id uuid NOT NULL REFERENCES users (id) ON DELETE CASCADE, + app_id uuid NOT NULL REFERENCES oauth2_provider_apps (id) ON DELETE CASCADE, + PRIMARY KEY (id), + UNIQUE(secret_prefix) +); + +COMMENT ON TABLE oauth2_provider_app_codes IS 'Codes are meant to be exchanged for access tokens.'; + +CREATE TABLE oauth2_provider_app_tokens ( + id uuid NOT NULL, + created_at timestamp with time zone NOT NULL, + expires_at timestamp with time zone NOT NULL, + hash_prefix bytea NOT NULL, + refresh_hash bytea NOT NULL, + app_secret_id uuid NOT NULL REFERENCES oauth2_provider_app_secrets (id) ON DELETE CASCADE, + api_key_id text NOT NULL REFERENCES api_keys (id) ON DELETE CASCADE, + PRIMARY KEY (id), + UNIQUE(hash_prefix) +); + +COMMENT ON COLUMN oauth2_provider_app_tokens.refresh_hash IS 'Refresh tokens provide a way to refresh an access token (API key). An expired API key can be refreshed if this token is not yet expired, meaning this expiry can outlive an API key.'; + +-- When we delete a token, delete the API key associated with it. +CREATE FUNCTION delete_deleted_oauth2_provider_app_token_api_key() RETURNS trigger + LANGUAGE plpgsql + AS $$ +DECLARE +BEGIN + DELETE FROM api_keys + WHERE id = OLD.api_key_id; + RETURN OLD; +END; +$$; + +CREATE TRIGGER trigger_delete_oauth2_provider_app_token +AFTER DELETE ON oauth2_provider_app_tokens +FOR EACH ROW +EXECUTE PROCEDURE delete_deleted_oauth2_provider_app_token_api_key(); + +ALTER TYPE login_type ADD VALUE IF NOT EXISTS 'oauth2_provider_app'; + +-- Switch to an ID we will prefix to the raw secret that we give to the user +-- (instead of matching on the entire secret as the ID, since they will be +-- salted and we can no longer do that). OAuth2 is blocked outside of +-- development mode so there should be no production secrets unless they +-- previously upgraded, in which case they keep their original prefixes and will +-- be fine. Add a random ID for the development mode case so the upgrade does +-- not fail, at least. +ALTER TABLE ONLY oauth2_provider_app_secrets + ADD COLUMN IF NOT EXISTS secret_prefix bytea NULL; + +UPDATE oauth2_provider_app_secrets + SET secret_prefix = substr(md5(random()::text), 0, 10)::bytea + WHERE secret_prefix IS NULL; + +ALTER TABLE ONLY oauth2_provider_app_secrets + ALTER COLUMN secret_prefix SET NOT NULL, + ADD CONSTRAINT oauth2_provider_app_secrets_secret_prefix_key UNIQUE (secret_prefix), + DROP CONSTRAINT IF EXISTS oauth2_provider_app_secrets_app_id_hashed_secret_key; diff --git a/coderd/database/migrations/testdata/fixtures/000195_oauth2_provider_codes.up.sql b/coderd/database/migrations/testdata/fixtures/000195_oauth2_provider_codes.up.sql new file mode 100644 index 0000000000..d764f7908c --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000195_oauth2_provider_codes.up.sql @@ -0,0 +1,23 @@ +INSERT INTO oauth2_provider_app_codes + (id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id) +VALUES ( + 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00', + '2023-06-15 10:23:54+00', + CAST('abcdefg' AS bytea), + CAST('abcdefg' AS bytea), + '0ed9befc-4911-4ccf-a8e2-559bf72daa94', + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' +); + +INSERT INTO oauth2_provider_app_tokens + (id, created_at, expires_at, hash_prefix, refresh_hash, app_secret_id, api_key_id) +VALUES ( + 'd0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:25:33+00', + '2023-12-15 11:40:20+00', + CAST('gfedcba' AS bytea), + CAST('abcdefg' AS bytea), + 'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'peuLZhMXt4' +); diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 24a7e0f0e6..9e77772839 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -290,6 +290,22 @@ func (l License) RBACObject() rbac.Object { return rbac.ResourceLicense.WithIDString(strconv.FormatInt(int64(l.ID), 10)) } +func (c OAuth2ProviderAppCode) RBACObject() rbac.Object { + return rbac.ResourceOAuth2ProviderAppCodeToken.WithOwner(c.UserID.String()) +} + +func (OAuth2ProviderAppSecret) RBACObject() rbac.Object { + return rbac.ResourceOAuth2ProviderAppSecret +} + +func (OAuth2ProviderApp) RBACObject() rbac.Object { + return rbac.ResourceOAuth2ProviderApp +} + +func (a GetOAuth2ProviderAppsByUserIDRow) RBACObject() rbac.Object { + return a.OAuth2ProviderApp.RBACObject() +} + type WorkspaceAgentConnectionStatus struct { Status WorkspaceAgentStatus `json:"status"` FirstConnectedAt *time.Time `json:"first_connected_at"` diff --git a/coderd/database/models.go b/coderd/database/models.go index 7156d772a3..4d65e5e344 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -593,11 +593,12 @@ func AllLogSourceValues() []LogSource { type LoginType string const ( - LoginTypePassword LoginType = "password" - LoginTypeGithub LoginType = "github" - LoginTypeOIDC LoginType = "oidc" - LoginTypeToken LoginType = "token" - LoginTypeNone LoginType = "none" + LoginTypePassword LoginType = "password" + LoginTypeGithub LoginType = "github" + LoginTypeOIDC LoginType = "oidc" + LoginTypeToken LoginType = "token" + LoginTypeNone LoginType = "none" + LoginTypeOAuth2ProviderApp LoginType = "oauth2_provider_app" ) func (e *LoginType) Scan(src interface{}) error { @@ -641,7 +642,8 @@ func (e LoginType) Valid() bool { LoginTypeGithub, LoginTypeOIDC, LoginTypeToken, - LoginTypeNone: + LoginTypeNone, + LoginTypeOAuth2ProviderApp: return true } return false @@ -654,6 +656,7 @@ func AllLoginTypeValues() []LoginType { LoginTypeOIDC, LoginTypeToken, LoginTypeNone, + LoginTypeOAuth2ProviderApp, } } @@ -1807,6 +1810,17 @@ type OAuth2ProviderApp struct { CallbackURL string `db:"callback_url" json:"callback_url"` } +// Codes are meant to be exchanged for access tokens. +type OAuth2ProviderAppCode struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + SecretPrefix []byte `db:"secret_prefix" json:"secret_prefix"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + AppID uuid.UUID `db:"app_id" json:"app_id"` +} + type OAuth2ProviderAppSecret struct { ID uuid.UUID `db:"id" json:"id"` CreatedAt time.Time `db:"created_at" json:"created_at"` @@ -1815,6 +1829,18 @@ type OAuth2ProviderAppSecret struct { // The tail end of the original secret so secrets can be differentiated. DisplaySecret string `db:"display_secret" json:"display_secret"` AppID uuid.UUID `db:"app_id" json:"app_id"` + SecretPrefix []byte `db:"secret_prefix" json:"secret_prefix"` +} + +type OAuth2ProviderAppToken struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + HashPrefix []byte `db:"hash_prefix" json:"hash_prefix"` + // Refresh tokens provide a way to refresh an access token (API key). An expired API key can be refreshed if this token is not yet expired, meaning this expiry can outlive an API key. + RefreshHash []byte `db:"refresh_hash" json:"refresh_hash"` + AppSecretID uuid.UUID `db:"app_secret_id" json:"app_secret_id"` + APIKeyID string `db:"api_key_id" json:"api_key_id"` } type Organization struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 385230518b..92ee81f85f 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -60,7 +60,10 @@ type sqlcQuerier interface { DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error + DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error + DeleteOAuth2ProviderAppCodesByAppAndUserID(ctx context.Context, arg DeleteOAuth2ProviderAppCodesByAppAndUserIDParams) error DeleteOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) error + DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx context.Context, arg DeleteOAuth2ProviderAppTokensByAppAndUserIDParams) error // Delete provisioner daemons that have been created at least a week ago // and have not connected to coderd since a week. // A provisioner daemon with "zeroed" last_seen_at column indicates possible @@ -131,9 +134,14 @@ type sqlcQuerier interface { GetLicenses(ctx context.Context) ([]License, error) GetLogoURL(ctx context.Context) (string, error) GetOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderApp, error) + GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderAppCode, error) + GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (OAuth2ProviderAppCode, error) GetOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderAppSecret, error) + GetOAuth2ProviderAppSecretByPrefix(ctx context.Context, secretPrefix []byte) (OAuth2ProviderAppSecret, error) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, appID uuid.UUID) ([]OAuth2ProviderAppSecret, error) + GetOAuth2ProviderAppTokenByPrefix(ctx context.Context, hashPrefix []byte) (OAuth2ProviderAppToken, error) GetOAuth2ProviderApps(ctx context.Context) ([]OAuth2ProviderApp, error) + GetOAuth2ProviderAppsByUserID(ctx context.Context, userID uuid.UUID) ([]GetOAuth2ProviderAppsByUserIDRow, error) GetOAuthSigningKey(ctx context.Context) (string, error) GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error) GetOrganizationByName(ctx context.Context, name string) (Organization, error) @@ -290,7 +298,9 @@ type sqlcQuerier interface { // If the name conflicts, do nothing. InsertMissingGroups(ctx context.Context, arg InsertMissingGroupsParams) ([]Group, error) InsertOAuth2ProviderApp(ctx context.Context, arg InsertOAuth2ProviderAppParams) (OAuth2ProviderApp, error) + InsertOAuth2ProviderAppCode(ctx context.Context, arg InsertOAuth2ProviderAppCodeParams) (OAuth2ProviderAppCode, error) InsertOAuth2ProviderAppSecret(ctx context.Context, arg InsertOAuth2ProviderAppSecretParams) (OAuth2ProviderAppSecret, error) + InsertOAuth2ProviderAppToken(ctx context.Context, arg InsertOAuth2ProviderAppTokenParams) (OAuth2ProviderAppToken, error) InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error) InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error) InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index f32293dedd..a23c9f9769 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2691,6 +2691,29 @@ func (q *sqlQuerier) DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UU return err } +const deleteOAuth2ProviderAppCodeByID = `-- name: DeleteOAuth2ProviderAppCodeByID :exec +DELETE FROM oauth2_provider_app_codes WHERE id = $1 +` + +func (q *sqlQuerier) DeleteOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) error { + _, err := q.db.ExecContext(ctx, deleteOAuth2ProviderAppCodeByID, id) + return err +} + +const deleteOAuth2ProviderAppCodesByAppAndUserID = `-- name: DeleteOAuth2ProviderAppCodesByAppAndUserID :exec +DELETE FROM oauth2_provider_app_codes WHERE app_id = $1 AND user_id = $2 +` + +type DeleteOAuth2ProviderAppCodesByAppAndUserIDParams struct { + AppID uuid.UUID `db:"app_id" json:"app_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) DeleteOAuth2ProviderAppCodesByAppAndUserID(ctx context.Context, arg DeleteOAuth2ProviderAppCodesByAppAndUserIDParams) error { + _, err := q.db.ExecContext(ctx, deleteOAuth2ProviderAppCodesByAppAndUserID, arg.AppID, arg.UserID) + return err +} + const deleteOAuth2ProviderAppSecretByID = `-- name: DeleteOAuth2ProviderAppSecretByID :exec DELETE FROM oauth2_provider_app_secrets WHERE id = $1 ` @@ -2700,6 +2723,28 @@ func (q *sqlQuerier) DeleteOAuth2ProviderAppSecretByID(ctx context.Context, id u return err } +const deleteOAuth2ProviderAppTokensByAppAndUserID = `-- name: DeleteOAuth2ProviderAppTokensByAppAndUserID :exec +DELETE FROM + oauth2_provider_app_tokens +USING + oauth2_provider_app_secrets, api_keys +WHERE + oauth2_provider_app_secrets.id = oauth2_provider_app_tokens.app_secret_id + AND api_keys.id = oauth2_provider_app_tokens.api_key_id + AND oauth2_provider_app_secrets.app_id = $1 + AND api_keys.user_id = $2 +` + +type DeleteOAuth2ProviderAppTokensByAppAndUserIDParams struct { + AppID uuid.UUID `db:"app_id" json:"app_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx context.Context, arg DeleteOAuth2ProviderAppTokensByAppAndUserIDParams) error { + _, err := q.db.ExecContext(ctx, deleteOAuth2ProviderAppTokensByAppAndUserID, arg.AppID, arg.UserID) + return err +} + const getOAuth2ProviderAppByID = `-- name: GetOAuth2ProviderAppByID :one SELECT id, created_at, updated_at, name, icon, callback_url FROM oauth2_provider_apps WHERE id = $1 ` @@ -2718,8 +2763,46 @@ func (q *sqlQuerier) GetOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) return i, err } +const getOAuth2ProviderAppCodeByID = `-- name: GetOAuth2ProviderAppCodeByID :one +SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id FROM oauth2_provider_app_codes WHERE id = $1 +` + +func (q *sqlQuerier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderAppCode, error) { + row := q.db.QueryRowContext(ctx, getOAuth2ProviderAppCodeByID, id) + var i OAuth2ProviderAppCode + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.ExpiresAt, + &i.SecretPrefix, + &i.HashedSecret, + &i.UserID, + &i.AppID, + ) + return i, err +} + +const getOAuth2ProviderAppCodeByPrefix = `-- name: GetOAuth2ProviderAppCodeByPrefix :one +SELECT id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id FROM oauth2_provider_app_codes WHERE secret_prefix = $1 +` + +func (q *sqlQuerier) GetOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (OAuth2ProviderAppCode, error) { + row := q.db.QueryRowContext(ctx, getOAuth2ProviderAppCodeByPrefix, secretPrefix) + var i OAuth2ProviderAppCode + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.ExpiresAt, + &i.SecretPrefix, + &i.HashedSecret, + &i.UserID, + &i.AppID, + ) + return i, err +} + const getOAuth2ProviderAppSecretByID = `-- name: GetOAuth2ProviderAppSecretByID :one -SELECT id, created_at, last_used_at, hashed_secret, display_secret, app_id FROM oauth2_provider_app_secrets WHERE id = $1 +SELECT id, created_at, last_used_at, hashed_secret, display_secret, app_id, secret_prefix FROM oauth2_provider_app_secrets WHERE id = $1 ` func (q *sqlQuerier) GetOAuth2ProviderAppSecretByID(ctx context.Context, id uuid.UUID) (OAuth2ProviderAppSecret, error) { @@ -2732,12 +2815,32 @@ func (q *sqlQuerier) GetOAuth2ProviderAppSecretByID(ctx context.Context, id uuid &i.HashedSecret, &i.DisplaySecret, &i.AppID, + &i.SecretPrefix, + ) + return i, err +} + +const getOAuth2ProviderAppSecretByPrefix = `-- name: GetOAuth2ProviderAppSecretByPrefix :one +SELECT id, created_at, last_used_at, hashed_secret, display_secret, app_id, secret_prefix FROM oauth2_provider_app_secrets WHERE secret_prefix = $1 +` + +func (q *sqlQuerier) GetOAuth2ProviderAppSecretByPrefix(ctx context.Context, secretPrefix []byte) (OAuth2ProviderAppSecret, error) { + row := q.db.QueryRowContext(ctx, getOAuth2ProviderAppSecretByPrefix, secretPrefix) + var i OAuth2ProviderAppSecret + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.LastUsedAt, + &i.HashedSecret, + &i.DisplaySecret, + &i.AppID, + &i.SecretPrefix, ) return i, err } const getOAuth2ProviderAppSecretsByAppID = `-- name: GetOAuth2ProviderAppSecretsByAppID :many -SELECT id, created_at, last_used_at, hashed_secret, display_secret, app_id FROM oauth2_provider_app_secrets WHERE app_id = $1 ORDER BY (created_at, id) ASC +SELECT id, created_at, last_used_at, hashed_secret, display_secret, app_id, secret_prefix FROM oauth2_provider_app_secrets WHERE app_id = $1 ORDER BY (created_at, id) ASC ` func (q *sqlQuerier) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, appID uuid.UUID) ([]OAuth2ProviderAppSecret, error) { @@ -2756,6 +2859,7 @@ func (q *sqlQuerier) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, app &i.HashedSecret, &i.DisplaySecret, &i.AppID, + &i.SecretPrefix, ); err != nil { return nil, err } @@ -2770,6 +2874,25 @@ func (q *sqlQuerier) GetOAuth2ProviderAppSecretsByAppID(ctx context.Context, app return items, nil } +const getOAuth2ProviderAppTokenByPrefix = `-- name: GetOAuth2ProviderAppTokenByPrefix :one +SELECT id, created_at, expires_at, hash_prefix, refresh_hash, app_secret_id, api_key_id FROM oauth2_provider_app_tokens WHERE hash_prefix = $1 +` + +func (q *sqlQuerier) GetOAuth2ProviderAppTokenByPrefix(ctx context.Context, hashPrefix []byte) (OAuth2ProviderAppToken, error) { + row := q.db.QueryRowContext(ctx, getOAuth2ProviderAppTokenByPrefix, hashPrefix) + var i OAuth2ProviderAppToken + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.ExpiresAt, + &i.HashPrefix, + &i.RefreshHash, + &i.AppSecretID, + &i.APIKeyID, + ) + return i, err +} + const getOAuth2ProviderApps = `-- name: GetOAuth2ProviderApps :many SELECT id, created_at, updated_at, name, icon, callback_url FROM oauth2_provider_apps ORDER BY (name, id) ASC ` @@ -2804,6 +2927,59 @@ func (q *sqlQuerier) GetOAuth2ProviderApps(ctx context.Context) ([]OAuth2Provide return items, nil } +const getOAuth2ProviderAppsByUserID = `-- name: GetOAuth2ProviderAppsByUserID :many +SELECT + COUNT(DISTINCT oauth2_provider_app_tokens.id) as token_count, + oauth2_provider_apps.id, oauth2_provider_apps.created_at, oauth2_provider_apps.updated_at, oauth2_provider_apps.name, oauth2_provider_apps.icon, oauth2_provider_apps.callback_url +FROM oauth2_provider_app_tokens + INNER JOIN oauth2_provider_app_secrets + ON oauth2_provider_app_secrets.id = oauth2_provider_app_tokens.app_secret_id + INNER JOIN oauth2_provider_apps + ON oauth2_provider_apps.id = oauth2_provider_app_secrets.app_id + INNER JOIN api_keys + ON api_keys.id = oauth2_provider_app_tokens.api_key_id +WHERE + api_keys.user_id = $1 +GROUP BY + oauth2_provider_apps.id +` + +type GetOAuth2ProviderAppsByUserIDRow struct { + TokenCount int64 `db:"token_count" json:"token_count"` + OAuth2ProviderApp OAuth2ProviderApp `db:"oauth2_provider_app" json:"oauth2_provider_app"` +} + +func (q *sqlQuerier) GetOAuth2ProviderAppsByUserID(ctx context.Context, userID uuid.UUID) ([]GetOAuth2ProviderAppsByUserIDRow, error) { + rows, err := q.db.QueryContext(ctx, getOAuth2ProviderAppsByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetOAuth2ProviderAppsByUserIDRow + for rows.Next() { + var i GetOAuth2ProviderAppsByUserIDRow + if err := rows.Scan( + &i.TokenCount, + &i.OAuth2ProviderApp.ID, + &i.OAuth2ProviderApp.CreatedAt, + &i.OAuth2ProviderApp.UpdatedAt, + &i.OAuth2ProviderApp.Name, + &i.OAuth2ProviderApp.Icon, + &i.OAuth2ProviderApp.CallbackURL, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertOAuth2ProviderApp = `-- name: InsertOAuth2ProviderApp :one INSERT INTO oauth2_provider_apps ( id, @@ -2852,10 +3028,64 @@ func (q *sqlQuerier) InsertOAuth2ProviderApp(ctx context.Context, arg InsertOAut return i, err } +const insertOAuth2ProviderAppCode = `-- name: InsertOAuth2ProviderAppCode :one +INSERT INTO oauth2_provider_app_codes ( + id, + created_at, + expires_at, + secret_prefix, + hashed_secret, + app_id, + user_id +) VALUES( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) RETURNING id, created_at, expires_at, secret_prefix, hashed_secret, user_id, app_id +` + +type InsertOAuth2ProviderAppCodeParams struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + SecretPrefix []byte `db:"secret_prefix" json:"secret_prefix"` + HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` + AppID uuid.UUID `db:"app_id" json:"app_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) InsertOAuth2ProviderAppCode(ctx context.Context, arg InsertOAuth2ProviderAppCodeParams) (OAuth2ProviderAppCode, error) { + row := q.db.QueryRowContext(ctx, insertOAuth2ProviderAppCode, + arg.ID, + arg.CreatedAt, + arg.ExpiresAt, + arg.SecretPrefix, + arg.HashedSecret, + arg.AppID, + arg.UserID, + ) + var i OAuth2ProviderAppCode + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.ExpiresAt, + &i.SecretPrefix, + &i.HashedSecret, + &i.UserID, + &i.AppID, + ) + return i, err +} + const insertOAuth2ProviderAppSecret = `-- name: InsertOAuth2ProviderAppSecret :one INSERT INTO oauth2_provider_app_secrets ( id, created_at, + secret_prefix, hashed_secret, display_secret, app_id @@ -2864,13 +3094,15 @@ INSERT INTO oauth2_provider_app_secrets ( $2, $3, $4, - $5 -) RETURNING id, created_at, last_used_at, hashed_secret, display_secret, app_id + $5, + $6 +) RETURNING id, created_at, last_used_at, hashed_secret, display_secret, app_id, secret_prefix ` type InsertOAuth2ProviderAppSecretParams struct { ID uuid.UUID `db:"id" json:"id"` CreatedAt time.Time `db:"created_at" json:"created_at"` + SecretPrefix []byte `db:"secret_prefix" json:"secret_prefix"` HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"` DisplaySecret string `db:"display_secret" json:"display_secret"` AppID uuid.UUID `db:"app_id" json:"app_id"` @@ -2880,6 +3112,7 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppSecret(ctx context.Context, arg Inse row := q.db.QueryRowContext(ctx, insertOAuth2ProviderAppSecret, arg.ID, arg.CreatedAt, + arg.SecretPrefix, arg.HashedSecret, arg.DisplaySecret, arg.AppID, @@ -2892,6 +3125,60 @@ func (q *sqlQuerier) InsertOAuth2ProviderAppSecret(ctx context.Context, arg Inse &i.HashedSecret, &i.DisplaySecret, &i.AppID, + &i.SecretPrefix, + ) + return i, err +} + +const insertOAuth2ProviderAppToken = `-- name: InsertOAuth2ProviderAppToken :one +INSERT INTO oauth2_provider_app_tokens ( + id, + created_at, + expires_at, + hash_prefix, + refresh_hash, + app_secret_id, + api_key_id +) VALUES( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) RETURNING id, created_at, expires_at, hash_prefix, refresh_hash, app_secret_id, api_key_id +` + +type InsertOAuth2ProviderAppTokenParams struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + HashPrefix []byte `db:"hash_prefix" json:"hash_prefix"` + RefreshHash []byte `db:"refresh_hash" json:"refresh_hash"` + AppSecretID uuid.UUID `db:"app_secret_id" json:"app_secret_id"` + APIKeyID string `db:"api_key_id" json:"api_key_id"` +} + +func (q *sqlQuerier) InsertOAuth2ProviderAppToken(ctx context.Context, arg InsertOAuth2ProviderAppTokenParams) (OAuth2ProviderAppToken, error) { + row := q.db.QueryRowContext(ctx, insertOAuth2ProviderAppToken, + arg.ID, + arg.CreatedAt, + arg.ExpiresAt, + arg.HashPrefix, + arg.RefreshHash, + arg.AppSecretID, + arg.APIKeyID, + ) + var i OAuth2ProviderAppToken + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.ExpiresAt, + &i.HashPrefix, + &i.RefreshHash, + &i.AppSecretID, + &i.APIKeyID, ) return i, err } @@ -2936,7 +3223,7 @@ func (q *sqlQuerier) UpdateOAuth2ProviderAppByID(ctx context.Context, arg Update const updateOAuth2ProviderAppSecretByID = `-- name: UpdateOAuth2ProviderAppSecretByID :one UPDATE oauth2_provider_app_secrets SET last_used_at = $2 -WHERE id = $1 RETURNING id, created_at, last_used_at, hashed_secret, display_secret, app_id +WHERE id = $1 RETURNING id, created_at, last_used_at, hashed_secret, display_secret, app_id, secret_prefix ` type UpdateOAuth2ProviderAppSecretByIDParams struct { @@ -2954,6 +3241,7 @@ func (q *sqlQuerier) UpdateOAuth2ProviderAppSecretByID(ctx context.Context, arg &i.HashedSecret, &i.DisplaySecret, &i.AppID, + &i.SecretPrefix, ) return i, err } diff --git a/coderd/database/queries/oauth2.sql b/coderd/database/queries/oauth2.sql index cd9a150d0b..e2ccd6111e 100644 --- a/coderd/database/queries/oauth2.sql +++ b/coderd/database/queries/oauth2.sql @@ -38,10 +38,14 @@ SELECT * FROM oauth2_provider_app_secrets WHERE id = $1; -- name: GetOAuth2ProviderAppSecretsByAppID :many SELECT * FROM oauth2_provider_app_secrets WHERE app_id = $1 ORDER BY (created_at, id) ASC; +-- name: GetOAuth2ProviderAppSecretByPrefix :one +SELECT * FROM oauth2_provider_app_secrets WHERE secret_prefix = $1; + -- name: InsertOAuth2ProviderAppSecret :one INSERT INTO oauth2_provider_app_secrets ( id, created_at, + secret_prefix, hashed_secret, display_secret, app_id @@ -50,7 +54,8 @@ INSERT INTO oauth2_provider_app_secrets ( $2, $3, $4, - $5 + $5, + $6 ) RETURNING *; -- name: UpdateOAuth2ProviderAppSecretByID :one @@ -60,3 +65,83 @@ WHERE id = $1 RETURNING *; -- name: DeleteOAuth2ProviderAppSecretByID :exec DELETE FROM oauth2_provider_app_secrets WHERE id = $1; + +-- name: GetOAuth2ProviderAppCodeByID :one +SELECT * FROM oauth2_provider_app_codes WHERE id = $1; + +-- name: GetOAuth2ProviderAppCodeByPrefix :one +SELECT * FROM oauth2_provider_app_codes WHERE secret_prefix = $1; + +-- name: InsertOAuth2ProviderAppCode :one +INSERT INTO oauth2_provider_app_codes ( + id, + created_at, + expires_at, + secret_prefix, + hashed_secret, + app_id, + user_id +) VALUES( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) RETURNING *; + +-- name: DeleteOAuth2ProviderAppCodeByID :exec +DELETE FROM oauth2_provider_app_codes WHERE id = $1; + +-- name: DeleteOAuth2ProviderAppCodesByAppAndUserID :exec +DELETE FROM oauth2_provider_app_codes WHERE app_id = $1 AND user_id = $2; + +-- name: InsertOAuth2ProviderAppToken :one +INSERT INTO oauth2_provider_app_tokens ( + id, + created_at, + expires_at, + hash_prefix, + refresh_hash, + app_secret_id, + api_key_id +) VALUES( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) RETURNING *; + +-- name: GetOAuth2ProviderAppTokenByPrefix :one +SELECT * FROM oauth2_provider_app_tokens WHERE hash_prefix = $1; + +-- name: GetOAuth2ProviderAppsByUserID :many +SELECT + COUNT(DISTINCT oauth2_provider_app_tokens.id) as token_count, + sqlc.embed(oauth2_provider_apps) +FROM oauth2_provider_app_tokens + INNER JOIN oauth2_provider_app_secrets + ON oauth2_provider_app_secrets.id = oauth2_provider_app_tokens.app_secret_id + INNER JOIN oauth2_provider_apps + ON oauth2_provider_apps.id = oauth2_provider_app_secrets.app_id + INNER JOIN api_keys + ON api_keys.id = oauth2_provider_app_tokens.api_key_id +WHERE + api_keys.user_id = $1 +GROUP BY + oauth2_provider_apps.id; + +-- name: DeleteOAuth2ProviderAppTokensByAppAndUserID :exec +DELETE FROM + oauth2_provider_app_tokens +USING + oauth2_provider_app_secrets, api_keys +WHERE + oauth2_provider_app_secrets.id = oauth2_provider_app_tokens.app_secret_id + AND api_keys.id = oauth2_provider_app_tokens.api_key_id + AND oauth2_provider_app_secrets.app_id = $1 + AND api_keys.user_id = $2; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 49140d597a..621946e7b4 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -99,4 +99,8 @@ sql: display_app_ssh_helper: DisplayAppSSHHelper oauth2_provider_app: OAuth2ProviderApp oauth2_provider_app_secret: OAuth2ProviderAppSecret + oauth2_provider_app_code: OAuth2ProviderAppCode + oauth2_provider_app_token: OAuth2ProviderAppToken + api_key_id: APIKeyID callback_url: CallbackURL + login_type_oauth2_provider_app: LoginTypeOAuth2ProviderApp diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index fa1efffb81..498fc24a80 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -22,8 +22,12 @@ const ( UniqueJfrogXrayScansPkey UniqueConstraint = "jfrog_xray_scans_pkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_pkey PRIMARY KEY (agent_id, workspace_id); UniqueLicensesJWTKey UniqueConstraint = "licenses_jwt_key" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt); UniqueLicensesPkey UniqueConstraint = "licenses_pkey" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id); - UniqueOauth2ProviderAppSecretsAppIDHashedSecretKey UniqueConstraint = "oauth2_provider_app_secrets_app_id_hashed_secret_key" // ALTER TABLE ONLY oauth2_provider_app_secrets ADD CONSTRAINT oauth2_provider_app_secrets_app_id_hashed_secret_key UNIQUE (app_id, hashed_secret); + UniqueOauth2ProviderAppCodesPkey UniqueConstraint = "oauth2_provider_app_codes_pkey" // ALTER TABLE ONLY oauth2_provider_app_codes ADD CONSTRAINT oauth2_provider_app_codes_pkey PRIMARY KEY (id); + UniqueOauth2ProviderAppCodesSecretPrefixKey UniqueConstraint = "oauth2_provider_app_codes_secret_prefix_key" // ALTER TABLE ONLY oauth2_provider_app_codes ADD CONSTRAINT oauth2_provider_app_codes_secret_prefix_key UNIQUE (secret_prefix); UniqueOauth2ProviderAppSecretsPkey UniqueConstraint = "oauth2_provider_app_secrets_pkey" // ALTER TABLE ONLY oauth2_provider_app_secrets ADD CONSTRAINT oauth2_provider_app_secrets_pkey PRIMARY KEY (id); + UniqueOauth2ProviderAppSecretsSecretPrefixKey UniqueConstraint = "oauth2_provider_app_secrets_secret_prefix_key" // ALTER TABLE ONLY oauth2_provider_app_secrets ADD CONSTRAINT oauth2_provider_app_secrets_secret_prefix_key UNIQUE (secret_prefix); + UniqueOauth2ProviderAppTokensHashPrefixKey UniqueConstraint = "oauth2_provider_app_tokens_hash_prefix_key" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT oauth2_provider_app_tokens_hash_prefix_key UNIQUE (hash_prefix); + UniqueOauth2ProviderAppTokensPkey UniqueConstraint = "oauth2_provider_app_tokens_pkey" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT oauth2_provider_app_tokens_pkey PRIMARY KEY (id); UniqueOauth2ProviderAppsNameKey UniqueConstraint = "oauth2_provider_apps_name_key" // ALTER TABLE ONLY oauth2_provider_apps ADD CONSTRAINT oauth2_provider_apps_name_key UNIQUE (name); UniqueOauth2ProviderAppsPkey UniqueConstraint = "oauth2_provider_apps_pkey" // ALTER TABLE ONLY oauth2_provider_apps ADD CONSTRAINT oauth2_provider_apps_pkey PRIMARY KEY (id); UniqueOrganizationMembersPkey UniqueConstraint = "organization_members_pkey" // ALTER TABLE ONLY organization_members ADD CONSTRAINT organization_members_pkey PRIMARY KEY (organization_id, user_id); diff --git a/coderd/httpapi/queryparams.go b/coderd/httpapi/queryparams.go index 9b7daf2310..822cfea22d 100644 --- a/coderd/httpapi/queryparams.go +++ b/coderd/httpapi/queryparams.go @@ -23,16 +23,16 @@ type QueryParamParser struct { // Parsed is a map of all query params that were parsed. This is useful // for checking if extra query params were passed in. Parsed map[string]bool - // RequiredParams is a map of all query params that are required. This is useful + // RequiredNotEmptyParams is a map of all query params that are required. This is useful // for forcing a value to be provided. - RequiredParams map[string]bool + RequiredNotEmptyParams map[string]bool } func NewQueryParamParser() *QueryParamParser { return &QueryParamParser{ - Errors: []codersdk.ValidationError{}, - Parsed: map[string]bool{}, - RequiredParams: map[string]bool{}, + Errors: []codersdk.ValidationError{}, + Parsed: map[string]bool{}, + RequiredNotEmptyParams: map[string]bool{}, } } @@ -90,8 +90,10 @@ func (p *QueryParamParser) Boolean(vals url.Values, def bool, queryParam string) return v } -func (p *QueryParamParser) Required(queryParam string) *QueryParamParser { - p.RequiredParams[queryParam] = true +func (p *QueryParamParser) RequiredNotEmpty(queryParam ...string) *QueryParamParser { + for _, q := range queryParam { + p.RequiredNotEmptyParams[q] = true + } return p } @@ -121,6 +123,27 @@ func (p *QueryParamParser) UUIDs(vals url.Values, def []uuid.UUID, queryParam st }) } +func (p *QueryParamParser) RedirectURL(vals url.Values, base *url.URL, queryParam string) *url.URL { + v, err := parseQueryParam(p, vals, url.Parse, base, queryParam) + if err != nil { + p.Errors = append(p.Errors, codersdk.ValidationError{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q must be a valid url: %s", queryParam, err.Error()), + }) + } + + // It can be a sub-directory but not a sub-domain, as we have apps on + // sub-domains and that seems too dangerous. + if v.Host != base.Host || !strings.HasPrefix(v.Path, base.Path) { + p.Errors = append(p.Errors, codersdk.ValidationError{ + Field: queryParam, + Detail: fmt.Sprintf("Query param %q must be a subset of %s", queryParam, base), + }) + } + + return v +} + func (p *QueryParamParser) Time(vals url.Values, def time.Time, queryParam, layout string) time.Time { return p.timeWithMutate(vals, def, queryParam, layout, nil) } @@ -233,10 +256,10 @@ func ParseCustomList[T any](parser *QueryParamParser, vals url.Values, def []T, func parseQueryParam[T any](parser *QueryParamParser, vals url.Values, parse func(v string) (T, error), def T, queryParam string) (T, error) { parser.addParsed(queryParam) // If the query param is required and not present, return an error. - if parser.RequiredParams[queryParam] && (!vals.Has(queryParam)) { + if parser.RequiredNotEmptyParams[queryParam] && (!vals.Has(queryParam) || vals.Get(queryParam) == "") { parser.Errors = append(parser.Errors, codersdk.ValidationError{ Field: queryParam, - Detail: fmt.Sprintf("Query param %q is required", queryParam), + Detail: fmt.Sprintf("Query param %q is required and cannot be empty", queryParam), }) return def, nil } diff --git a/coderd/httpapi/queryparams_test.go b/coderd/httpapi/queryparams_test.go index f919b478df..b9773bfa25 100644 --- a/coderd/httpapi/queryparams_test.go +++ b/coderd/httpapi/queryparams_test.go @@ -320,9 +320,14 @@ func TestParseQueryParams(t *testing.T) { t.Parallel() parser := httpapi.NewQueryParamParser() - parser.Required("test_value") + parser.RequiredNotEmpty("test_value") parser.UUID(url.Values{}, uuid.New(), "test_value") require.Len(t, parser.Errors, 1) + + parser = httpapi.NewQueryParamParser() + parser.RequiredNotEmpty("test_value") + parser.String(url.Values{"test_value": {""}}, "", "test_value") + require.Len(t, parser.Errors, 1) }) } diff --git a/coderd/httpmw/oauth2.go b/coderd/httpmw/oauth2.go index dbb763bc9d..98baaae4c4 100644 --- a/coderd/httpmw/oauth2.go +++ b/coderd/httpmw/oauth2.go @@ -6,6 +6,8 @@ import ( "net/http" "reflect" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" "golang.org/x/oauth2" "github.com/coder/coder/v2/coderd/database" @@ -194,9 +196,44 @@ func ExtractOAuth2ProviderApp(db database.Store) func(http.Handler) http.Handler return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - appID, ok := ParseUUIDParam(rw, r, "app") - if !ok { - return + + // App can come from a URL param, query param, or form value. + paramID := "app" + var appID uuid.UUID + if chi.URLParam(r, paramID) != "" { + var ok bool + appID, ok = ParseUUIDParam(rw, r, "app") + if !ok { + return + } + } else { + // If not provided by the url, then it is provided according to the + // oauth 2 spec. This can occur with query params, or in the body as + // form parameters. + // This also depends on if you are doing a POST (tokens) or GET (authorize). + paramAppID := r.URL.Query().Get("client_id") + if paramAppID == "" { + // Check the form params! + if r.ParseForm() == nil { + paramAppID = r.Form.Get("client_id") + } + } + if paramAppID == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing OAuth2 client ID.", + }) + return + } + + var err error + appID, err = uuid.Parse(paramAppID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid OAuth2 client ID.", + Detail: err.Error(), + }) + return + } } app, err := db.GetOAuth2ProviderAppByID(ctx, appID) diff --git a/coderd/insights.go b/coderd/insights.go index 4f29e2ef85..214eae5510 100644 --- a/coderd/insights.go +++ b/coderd/insights.go @@ -72,8 +72,8 @@ func (api *API) insightsUserActivity(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() p := httpapi.NewQueryParamParser(). - Required("start_time"). - Required("end_time") + RequiredNotEmpty("start_time"). + RequiredNotEmpty("end_time") vals := r.URL.Query() var ( // The QueryParamParser does not preserve timezone, so we need @@ -161,8 +161,8 @@ func (api *API) insightsUserLatency(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() p := httpapi.NewQueryParamParser(). - Required("start_time"). - Required("end_time") + RequiredNotEmpty("start_time"). + RequiredNotEmpty("end_time") vals := r.URL.Query() var ( // The QueryParamParser does not preserve timezone, so we need @@ -253,8 +253,8 @@ func (api *API) insightsTemplates(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() p := httpapi.NewQueryParamParser(). - Required("start_time"). - Required("end_time") + RequiredNotEmpty("start_time"). + RequiredNotEmpty("end_time") vals := r.URL.Query() var ( // The QueryParamParser does not preserve timezone, so we need diff --git a/coderd/rbac/object.go b/coderd/rbac/object.go index ace060b314..51b6da339c 100644 --- a/coderd/rbac/object.go +++ b/coderd/rbac/object.go @@ -213,12 +213,20 @@ var ( Type: "oauth2_app", } - // ResourceOAuth2ProviderAppSecrets CRUD. + // ResourceOAuth2ProviderAppSecret CRUD. // create/delete = Make or delete an OAuth2 app secret. // update = Update last used date. // read = Read OAuth2 app hashed or truncated secret. ResourceOAuth2ProviderAppSecret = Object{ - Type: "oauth2_app_secrets", + Type: "oauth2_app_secret", + } + + // ResourceOAuth2ProviderAppCodeToken CRUD. + // create/delete = Make or delete an OAuth2 app code or token. + // update = None + // read = Check if OAuth2 app code or token exists. + ResourceOAuth2ProviderAppCodeToken = Object{ + Type: "oauth2_app_code_token", } ) diff --git a/coderd/rbac/object_gen.go b/coderd/rbac/object_gen.go index 4668f56b06..b1cac5704e 100644 --- a/coderd/rbac/object_gen.go +++ b/coderd/rbac/object_gen.go @@ -12,6 +12,7 @@ func AllResources() []Object { ResourceGroup, ResourceLicense, ResourceOAuth2ProviderApp, + ResourceOAuth2ProviderAppCodeToken, ResourceOAuth2ProviderAppSecret, ResourceOrgRoleAssignment, ResourceOrganization, diff --git a/coderd/rbac/roles.go b/coderd/rbac/roles.go index d6a53d5b9b..ebe122386f 100644 --- a/coderd/rbac/roles.go +++ b/coderd/rbac/roles.go @@ -148,6 +148,8 @@ func ReloadBuiltinRoles(opts *RoleOptions) { ResourceRoleAssignment.Type: {ActionRead}, // All users can see the provisioner daemons. ResourceProvisionerDaemon.Type: {ActionRead}, + // All users can see OAuth2 provider applications. + ResourceOAuth2ProviderApp.Type: {ActionRead}, }), Org: map[string][]Permission{}, User: append(allPermsExcept(ResourceWorkspaceDormant, ResourceUser, ResourceOrganizationMember), diff --git a/coderd/users.go b/coderd/users.go index 62739c1d16..be4e46ea7f 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -583,7 +583,7 @@ func (api *API) userByName(rw http.ResponseWriter, r *http.Request) { func (api *API) userAutofillParameters(rw http.ResponseWriter, r *http.Request) { user := httpmw.UserParam(r) - p := httpapi.NewQueryParamParser().Required("template_id") + p := httpapi.NewQueryParamParser().RequiredNotEmpty("template_id") templateID := p.UUID(r.URL.Query(), uuid.UUID{}, "template_id") p.ErrorExcessParams(r.URL.Query()) if len(p.Errors) > 0 { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 984cfc62a3..70d8a64efa 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -636,7 +636,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { values := r.URL.Query() parser := httpapi.NewQueryParamParser() - reconnect := parser.Required("reconnect").UUID(values, uuid.New(), "reconnect") + reconnect := parser.RequiredNotEmpty("reconnect").UUID(values, uuid.New(), "reconnect") height := parser.UInt(values, 80, "height") width := parser.UInt(values, 80, "width") if len(parser.Errors) > 0 { diff --git a/codersdk/oauth2.go b/codersdk/oauth2.go index 318743959d..d4d20c26b1 100644 --- a/codersdk/oauth2.go +++ b/codersdk/oauth2.go @@ -28,10 +28,21 @@ type OAuth2AppEndpoints struct { DeviceAuth string `json:"device_authorization"` } +type OAuth2ProviderAppFilter struct { + UserID uuid.UUID `json:"user_id,omitempty" format:"uuid"` +} + // OAuth2ProviderApps returns the applications configured to authenticate using // Coder as an OAuth2 provider. -func (c *Client) OAuth2ProviderApps(ctx context.Context) ([]OAuth2ProviderApp, error) { - res, err := c.Request(ctx, http.MethodGet, "/api/v2/oauth2-provider/apps", nil) +func (c *Client) OAuth2ProviderApps(ctx context.Context, filter OAuth2ProviderAppFilter) ([]OAuth2ProviderApp, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/v2/oauth2-provider/apps", nil, + func(r *http.Request) { + if filter.UserID != uuid.Nil { + q := r.URL.Query() + q.Set("user_id", filter.UserID.String()) + r.URL.RawQuery = q.Encode() + } + }) if err != nil { return []OAuth2ProviderApp{}, err } @@ -168,3 +179,51 @@ func (c *Client) DeleteOAuth2ProviderAppSecret(ctx context.Context, appID uuid.U } return nil } + +type OAuth2ProviderGrantType string + +const ( + OAuth2ProviderGrantTypeAuthorizationCode OAuth2ProviderGrantType = "authorization_code" + OAuth2ProviderGrantTypeRefreshToken OAuth2ProviderGrantType = "refresh_token" +) + +func (e OAuth2ProviderGrantType) Valid() bool { + switch e { + case OAuth2ProviderGrantTypeAuthorizationCode, OAuth2ProviderGrantTypeRefreshToken: + return true + } + return false +} + +type OAuth2ProviderResponseType string + +const ( + OAuth2ProviderResponseTypeCode OAuth2ProviderResponseType = "code" +) + +func (e OAuth2ProviderResponseType) Valid() bool { + //nolint:gocritic,revive // More cases might be added later. + switch e { + case OAuth2ProviderResponseTypeCode: + return true + } + return false +} + +// RevokeOAuth2ProviderApp completely revokes an app's access for the +// authenticated user. +func (c *Client) RevokeOAuth2ProviderApp(ctx context.Context, appID uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, "/login/oauth2/tokens", nil, func(r *http.Request) { + q := r.URL.Query() + q.Set("client_id", appID.String()) + r.URL.RawQuery = q.Encode() + }) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} diff --git a/docs/api/enterprise.md b/docs/api/enterprise.md index cb100f346f..3e80637666 100644 --- a/docs/api/enterprise.md +++ b/docs/api/enterprise.md @@ -534,6 +534,127 @@ curl -X DELETE http://coder-server:8080/api/v2/licenses/{id} \ To perform this operation, you must be authenticated. [Learn more](authentication.md). +## OAuth2 authorization request. + +### Code samples + +```shell +# Example request using curl +curl -X POST http://coder-server:8080/api/v2/login/oauth2/authorize?client_id=string&state=string&response_type=code \ + -H 'Coder-Session-Token: API_KEY' +``` + +`POST /login/oauth2/authorize` + +### Parameters + +| Name | In | Type | Required | Description | +| --------------- | ----- | ------ | -------- | --------------------------------- | +| `client_id` | query | string | true | Client ID | +| `state` | query | string | true | A random unguessable string | +| `response_type` | query | string | true | Response type | +| `redirect_uri` | query | string | false | Redirect here after authorization | +| `scope` | query | string | false | Token scopes (currently ignored) | + +#### Enumerated Values + +| Parameter | Value | +| --------------- | ------ | +| `response_type` | `code` | + +### Responses + +| Status | Meaning | Description | Schema | +| ------ | ---------------------------------------------------------- | ----------- | ------ | +| 302 | [Found](https://tools.ietf.org/html/rfc7231#section-6.4.3) | Found | | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + +## OAuth2 token exchange. + +### Code samples + +```shell +# Example request using curl +curl -X POST http://coder-server:8080/api/v2/login/oauth2/tokens \ + -H 'Accept: application/json' +``` + +`POST /login/oauth2/tokens` + +> Body parameter + +```yaml +client_id: string +client_secret: string +code: string +refresh_token: string +grant_type: authorization_code +``` + +### Parameters + +| Name | In | Type | Required | Description | +| ----------------- | ---- | ------ | -------- | ------------------------------------------------------------- | +| `body` | body | object | false | | +| `» client_id` | body | string | false | Client ID, required if grant_type=authorization_code | +| `» client_secret` | body | string | false | Client secret, required if grant_type=authorization_code | +| `» code` | body | string | false | Authorization code, required if grant_type=authorization_code | +| `» refresh_token` | body | string | false | Refresh token, required if grant_type=refresh_token | +| `» grant_type` | body | string | true | Grant type | + +#### Enumerated Values + +| Parameter | Value | +| -------------- | -------------------- | +| `» grant_type` | `authorization_code` | +| `» grant_type` | `refresh_token` | + +### Example responses + +> 200 Response + +```json +{ + "access_token": "string", + "expiry": "string", + "refresh_token": "string", + "token_type": "string" +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +| ------ | ------------------------------------------------------- | ----------- | -------------------------------------- | +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [oauth2.Token](schemas.md#oauth2token) | + +## Delete OAuth2 application tokens. + +### Code samples + +```shell +# Example request using curl +curl -X DELETE http://coder-server:8080/api/v2/login/oauth2/tokens?client_id=string \ + -H 'Coder-Session-Token: API_KEY' +``` + +`DELETE /login/oauth2/tokens` + +### Parameters + +| Name | In | Type | Required | Description | +| ----------- | ----- | ------ | -------- | ----------- | +| `client_id` | query | string | true | Client ID | + +### Responses + +| Status | Meaning | Description | Schema | +| ------ | --------------------------------------------------------------- | ----------- | ------ | +| 204 | [No Content](https://tools.ietf.org/html/rfc7231#section-6.3.5) | No Content | | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + ## Get OAuth2 applications. ### Code samples @@ -547,6 +668,12 @@ curl -X GET http://coder-server:8080/api/v2/oauth2-provider/apps \ `GET /oauth2-provider/apps` +### Parameters + +| Name | In | Type | Required | Description | +| --------- | ----- | ------ | -------- | -------------------------------------------- | +| `user_id` | query | string | false | Filter by applications authorized for a user | + ### Example responses > 200 Response diff --git a/docs/api/schemas.md b/docs/api/schemas.md index be5724e361..6c07f1f98a 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -8722,6 +8722,27 @@ _None_ | `udp` | boolean | false | | a UDP STUN round trip completed | | `upnP` | string | false | | Upnp is whether UPnP appears present on the LAN. Empty means not checked. | +## oauth2.Token + +```json +{ + "access_token": "string", + "expiry": "string", + "refresh_token": "string", + "token_type": "string" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +| ------------------------------------------------------------------------------------------------------------------------------------------------------- | ------ | -------- | ------------ | --------------------------------------------------------------------------------------------------------------------------- | +| `access_token` | string | false | | Access token is the token that authorizes and authenticates the requests. | +| `expiry` | string | false | | Expiry is the optional expiration time of the access token. | +| If zero, TokenSource implementations will reuse the same token forever and RefreshToken or equivalent mechanisms for that TokenSource will not be used. | +| `refresh_token` | string | false | | Refresh token is a token that's used by the application (as opposed to the user) to refresh the access token if it expires. | +| `token_type` | string | false | | Token type is the type of token. The Type method returns either this or "Bearer", the default. | + ## tailcfg.DERPHomeParams ```json diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 95611f671d..3c90b9992c 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -167,6 +167,28 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { return nil, xerrors.Errorf("failed to get deployment ID: %w", err) } + api.AGPL.RootHandler.Group(func(r chi.Router) { + r.Use( + api.oAuth2ProviderMiddleware, + // Fetch the app as system because in the /tokens route there will be no + // authenticated user. + httpmw.AsAuthzSystem(httpmw.ExtractOAuth2ProviderApp(options.Database)), + ) + // Oauth2 linking routes do not make sense under the /api/v2 path. + r.Route("/login", func(r chi.Router) { + r.Route("/oauth2", func(r chi.Router) { + r.Group(func(r chi.Router) { + r.Use(apiKeyMiddleware) + r.Get("/authorize", api.postOAuth2ProviderAppAuthorize()) + r.Delete("/tokens", api.deleteOAuth2ProviderAppTokens()) + }) + // The /tokens endpoint will be called from an unauthorized client so we + // cannot require an API key. + r.Post("/tokens", api.postOAuth2ProviderAppToken()) + }) + }) + }) + api.AGPL.APIHandler.Group(func(r chi.Router) { r.Get("/entitlements", api.serveEntitlements) // /regions overrides the AGPL /regions endpoint diff --git a/enterprise/coderd/identityprovider/authorize.go b/enterprise/coderd/identityprovider/authorize.go new file mode 100644 index 0000000000..f41a0842e9 --- /dev/null +++ b/enterprise/coderd/identityprovider/authorize.go @@ -0,0 +1,140 @@ +package identityprovider + +import ( + "database/sql" + "errors" + "net/http" + "net/url" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" +) + +type authorizeParams struct { + clientID string + redirectURL *url.URL + responseType codersdk.OAuth2ProviderResponseType + scope []string + state string +} + +func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizeParams, []codersdk.ValidationError, error) { + p := httpapi.NewQueryParamParser() + vals := r.URL.Query() + + p.RequiredNotEmpty("state", "response_type", "client_id") + + params := authorizeParams{ + clientID: p.String(vals, "", "client_id"), + redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"), + responseType: httpapi.ParseCustom(p, vals, "", "response_type", httpapi.ParseEnum[codersdk.OAuth2ProviderResponseType]), + scope: p.Strings(vals, []string{}, "scope"), + state: p.String(vals, "", "state"), + } + + // We add "redirected" when coming from the authorize page. + _ = p.String(vals, "", "redirected") + + p.ErrorExcessParams(vals) + if len(p.Errors) > 0 { + return authorizeParams{}, p.Errors, xerrors.Errorf("invalid query params: %w", p.Errors) + } + return params, nil, nil +} + +// Authorize displays an HTML page for authorizing an application when the user +// has first been redirected to this path and generates a code and redirects to +// the app's callback URL after the user clicks "allow" on that page, which is +// detected via the origin and referer headers. +func Authorize(db database.Store, accessURL *url.URL) http.HandlerFunc { + handler := func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + app := httpmw.OAuth2ProviderApp(r) + + callbackURL, err := url.Parse(app.CallbackURL) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to validate query parameters.", + Detail: err.Error(), + }) + return + } + + params, validationErrs, err := extractAuthorizeParams(r, callbackURL) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid query params.", + Detail: err.Error(), + Validations: validationErrs, + }) + return + } + + // TODO: Ignoring scope for now, but should look into implementing. + code, err := GenerateSecret() + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to generate OAuth2 app authorization code.", + }) + return + } + err = db.InTx(func(tx database.Store) error { + // Delete any previous codes. + err = tx.DeleteOAuth2ProviderAppCodesByAppAndUserID(ctx, database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams{ + AppID: app.ID, + UserID: apiKey.UserID, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return xerrors.Errorf("delete oauth2 app codes: %w", err) + } + + // Insert the new code. + _, err = tx.InsertOAuth2ProviderAppCode(ctx, database.InsertOAuth2ProviderAppCodeParams{ + ID: uuid.New(), + CreatedAt: dbtime.Now(), + // TODO: Configurable expiration? Ten minutes matches GitHub. + // This timeout is only for the code that will be exchanged for the + // access token, not the access token itself. It does not need to be + // long-lived because normally it will be exchanged immediately after it + // is received. If the application does wait before exchanging the + // token (for example suppose they ask the user to confirm and the user + // has left) then they can just retry immediately and get a new code. + ExpiresAt: dbtime.Now().Add(time.Duration(10) * time.Minute), + SecretPrefix: []byte(code.Prefix), + HashedSecret: []byte(code.Hashed), + AppID: app.ID, + UserID: apiKey.UserID, + }) + if err != nil { + return xerrors.Errorf("insert oauth2 authorization code: %w", err) + } + + return nil + }, nil) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to generate OAuth2 authorization code.", + Detail: err.Error(), + }) + return + } + + newQuery := params.redirectURL.Query() + newQuery.Add("code", code.Formatted) + newQuery.Add("state", params.state) + params.redirectURL.RawQuery = newQuery.Encode() + + http.Redirect(rw, r, params.redirectURL.String(), http.StatusTemporaryRedirect) + } + + // Always wrap with its custom mw. + return authorizeMW(accessURL)(http.HandlerFunc(handler)).ServeHTTP +} diff --git a/enterprise/coderd/identityprovider/middleware.go b/enterprise/coderd/identityprovider/middleware.go new file mode 100644 index 0000000000..640ea8652e --- /dev/null +++ b/enterprise/coderd/identityprovider/middleware.go @@ -0,0 +1,149 @@ +package identityprovider + +import ( + "net/http" + "net/url" + + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/site" +) + +// authorizeMW serves to remove some code from the primary authorize handler. +// It decides when to show the html allow page, and when to just continue. +func authorizeMW(accessURL *url.URL) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + origin := r.Header.Get(httpmw.OriginHeader) + originU, err := url.Parse(origin) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid origin header.", + Detail: err.Error(), + }) + return + } + + referer := r.Referer() + refererU, err := url.Parse(referer) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid referer header.", + Detail: err.Error(), + }) + return + } + + app := httpmw.OAuth2ProviderApp(r) + ua := httpmw.UserAuthorization(r) + + // url.Parse() allows empty URLs, which is fine because the origin is not + // always set by browsers (or other tools like cURL). If the origin does + // exist, we will make sure it matches. We require `referer` to be set at + // a minimum in order to detect whether "allow" has been pressed, however. + cameFromSelf := (origin == "" || originU.Hostname() == accessURL.Hostname()) && + refererU.Hostname() == accessURL.Hostname() && + refererU.Path == "/login/oauth2/authorize" + + // If we were redirected here from this same page it means the user + // pressed the allow button so defer to the authorize handler which + // generates the code, otherwise show the HTML allow page. + // TODO: Skip this step if the user has already clicked allow before, and + // we can just reuse the token. + if cameFromSelf { + next.ServeHTTP(rw, r) + return + } + + // TODO: For now only browser-based auth flow is officially supported but + // in a future PR we should support a cURL-based flow where we output text + // instead of HTML. + if r.URL.Query().Get("redirected") != "" { + // When the user first comes into the page, referer might be blank which + // is OK. But if they click "allow" and their browser has *still* not + // sent the referer header, we have no way of telling whether they + // actually clicked the button. "Redirected" means they *might* have + // pressed it, but it could also mean an app added it for them as part + // of their redirect, so we cannot use it as a replacement for referer + // and the best we can do is error. + if referer == "" { + site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ + Status: http.StatusInternalServerError, + HideStatus: false, + Title: "Referer header missing", + Description: "We cannot continue authorization because your client has not sent the referer header.", + RetryEnabled: false, + DashboardURL: accessURL.String(), + Warnings: nil, + }) + return + } + site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ + Status: http.StatusInternalServerError, + HideStatus: false, + Title: "Oauth Redirect Loop", + Description: "Oauth redirect loop detected.", + RetryEnabled: false, + DashboardURL: accessURL.String(), + Warnings: nil, + }) + return + } + + callbackURL, err := url.Parse(app.CallbackURL) + if err != nil { + site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ + Status: http.StatusInternalServerError, + HideStatus: false, + Title: "Internal Server Error", + Description: err.Error(), + RetryEnabled: false, + DashboardURL: accessURL.String(), + Warnings: nil, + }) + return + } + + // Extract the form parameters for two reasons: + // 1. We need the redirect URI to build the cancel URI. + // 2. Since validation will run once the user clicks "allow", it is + // better to validate now to avoid wasting the user's time clicking a + // button that will just error anyway. + params, validationErrs, err := extractAuthorizeParams(r, callbackURL) + if err != nil { + errStr := make([]string, len(validationErrs)) + for i, err := range validationErrs { + errStr[i] = err.Detail + } + site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ + Status: http.StatusBadRequest, + HideStatus: false, + Title: "Invalid Query Parameters", + Description: "One or more query parameters are missing or invalid.", + RetryEnabled: false, + DashboardURL: accessURL.String(), + Warnings: errStr, + }) + return + } + + cancel := params.redirectURL + cancelQuery := params.redirectURL.Query() + cancelQuery.Add("error", "access_denied") + cancel.RawQuery = cancelQuery.Encode() + + redirect := r.URL + vals := redirect.Query() + vals.Add("redirected", "true") // For loop detection. + r.URL.RawQuery = vals.Encode() + site.RenderOAuthAllowPage(rw, r, site.RenderOAuthAllowData{ + AppIcon: app.Icon, + AppName: app.Name, + CancelURI: cancel.String(), + RedirectURI: r.URL.String(), + Username: ua.ActorName, + }) + }) + } +} diff --git a/enterprise/coderd/identityprovider/revoke.go b/enterprise/coderd/identityprovider/revoke.go new file mode 100644 index 0000000000..cddc150bbe --- /dev/null +++ b/enterprise/coderd/identityprovider/revoke.go @@ -0,0 +1,44 @@ +package identityprovider + +import ( + "database/sql" + "errors" + "net/http" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw" +) + +func RevokeApp(db database.Store) http.HandlerFunc { + return func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + app := httpmw.OAuth2ProviderApp(r) + + err := db.InTx(func(tx database.Store) error { + err := tx.DeleteOAuth2ProviderAppCodesByAppAndUserID(ctx, database.DeleteOAuth2ProviderAppCodesByAppAndUserIDParams{ + AppID: app.ID, + UserID: apiKey.UserID, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + err = tx.DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx, database.DeleteOAuth2ProviderAppTokensByAppAndUserIDParams{ + AppID: app.ID, + UserID: apiKey.UserID, + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + return nil + }, nil) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + httpapi.Write(ctx, rw, http.StatusNoContent, nil) + } +} diff --git a/enterprise/coderd/identityprovider/secrets.go b/enterprise/coderd/identityprovider/secrets.go new file mode 100644 index 0000000000..72524b3d2a --- /dev/null +++ b/enterprise/coderd/identityprovider/secrets.go @@ -0,0 +1,77 @@ +package identityprovider + +import ( + "fmt" + "strings" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/userpassword" + "github.com/coder/coder/v2/cryptorand" +) + +type OAuth2ProviderAppSecret struct { + // Formatted contains the secret. This value is owned by the client, not the + // server. It is formatted to include the prefix. + Formatted string + // Prefix is the ID of this secret owned by the server. When a client uses a + // secret, this is the matching string to do a lookup on the hashed value. We + // cannot use the hashed value directly because the server does not store the + // salt. + Prefix string + // Hashed is the server stored hash(secret,salt,...). Used for verifying a + // secret. + Hashed string +} + +// GenerateSecret generates a secret to be used as a client secret, refresh +// token, or authorization code. +func GenerateSecret() (OAuth2ProviderAppSecret, error) { + // 40 characters matches the length of GitHub's client secrets. + secret, err := cryptorand.String(40) + if err != nil { + return OAuth2ProviderAppSecret{}, err + } + + // This ID is prefixed to the secret so it can be used to look up the secret + // when the user provides it, since we cannot just re-hash it to match as we + // will not have the salt. + prefix, err := cryptorand.String(10) + if err != nil { + return OAuth2ProviderAppSecret{}, err + } + + hashed, err := userpassword.Hash(secret) + if err != nil { + return OAuth2ProviderAppSecret{}, err + } + + return OAuth2ProviderAppSecret{ + Formatted: fmt.Sprintf("coder_%s_%s", prefix, secret), + Prefix: prefix, + Hashed: hashed, + }, nil +} + +type parsedSecret struct { + prefix string + secret string +} + +// parseSecret extracts the ID and original secret from a secret. +func parseSecret(secret string) (parsedSecret, error) { + parts := strings.Split(secret, "_") + if len(parts) != 3 { + return parsedSecret{}, xerrors.Errorf("incorrect number of parts: %d", len(parts)) + } + if parts[0] != "coder" { + return parsedSecret{}, xerrors.Errorf("incorrect scheme: %s", parts[0]) + } + if len(parts[1]) == 0 { + return parsedSecret{}, xerrors.Errorf("prefix is invalid") + } + if len(parts[2]) == 0 { + return parsedSecret{}, xerrors.Errorf("invalid") + } + return parsedSecret{parts[1], parts[2]}, nil +} diff --git a/enterprise/coderd/identityprovider/tokens.go b/enterprise/coderd/identityprovider/tokens.go new file mode 100644 index 0000000000..0673eb7d1a --- /dev/null +++ b/enterprise/coderd/identityprovider/tokens.go @@ -0,0 +1,378 @@ +package identityprovider + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/google/uuid" + "golang.org/x/oauth2" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/apikey" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/userpassword" + "github.com/coder/coder/v2/codersdk" +) + +var ( + // errBadSecret means the user provided a bad secret. + errBadSecret = xerrors.New("Invalid client secret") + // errBadCode means the user provided a bad code. + errBadCode = xerrors.New("Invalid code") + // errBadToken means the user provided a bad token. + errBadToken = xerrors.New("Invalid token") +) + +type tokenParams struct { + clientID string + clientSecret string + code string + grantType codersdk.OAuth2ProviderGrantType + redirectURL *url.URL + refreshToken string +} + +func extractTokenParams(r *http.Request, callbackURL *url.URL) (tokenParams, []codersdk.ValidationError, error) { + p := httpapi.NewQueryParamParser() + err := r.ParseForm() + if err != nil { + return tokenParams{}, nil, xerrors.Errorf("parse form: %w", err) + } + + vals := r.Form + p.RequiredNotEmpty("grant_type") + grantType := httpapi.ParseCustom(p, vals, "", "grant_type", httpapi.ParseEnum[codersdk.OAuth2ProviderGrantType]) + switch grantType { + case codersdk.OAuth2ProviderGrantTypeRefreshToken: + p.RequiredNotEmpty("refresh_token") + case codersdk.OAuth2ProviderGrantTypeAuthorizationCode: + p.RequiredNotEmpty("client_secret", "client_id", "code") + } + + params := tokenParams{ + clientID: p.String(vals, "", "client_id"), + clientSecret: p.String(vals, "", "client_secret"), + code: p.String(vals, "", "code"), + grantType: grantType, + redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"), + refreshToken: p.String(vals, "", "refresh_token"), + } + + p.ErrorExcessParams(vals) + if len(p.Errors) > 0 { + return tokenParams{}, p.Errors, xerrors.Errorf("invalid query params: %w", p.Errors) + } + return params, nil, nil +} + +func Tokens(db database.Store, defaultLifetime time.Duration) http.HandlerFunc { + return func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + app := httpmw.OAuth2ProviderApp(r) + + callbackURL, err := url.Parse(app.CallbackURL) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to validate form values.", + Detail: err.Error(), + }) + return + } + + params, validationErrs, err := extractTokenParams(r, callbackURL) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid query params.", + Detail: err.Error(), + Validations: validationErrs, + }) + return + } + + var token oauth2.Token + //nolint:gocritic,revive // More cases will be added later. + switch params.grantType { + // TODO: Client creds, device code. + case codersdk.OAuth2ProviderGrantTypeRefreshToken: + token, err = refreshTokenGrant(ctx, db, app, defaultLifetime, params) + case codersdk.OAuth2ProviderGrantTypeAuthorizationCode: + token, err = authorizationCodeGrant(ctx, db, app, defaultLifetime, params) + default: + // Grant types are validated by the parser, so getting through here means + // the developer added a type but forgot to add a case here. + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Unhandled grant type.", + Detail: fmt.Sprintf("Grant type %q is unhandled", params.grantType), + }) + return + } + + if errors.Is(err, errBadCode) || errors.Is(err, errBadSecret) { + httpapi.Write(r.Context(), rw, http.StatusUnauthorized, codersdk.Response{ + Message: err.Error(), + }) + return + } + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to exchange token", + Detail: err.Error(), + }) + return + } + + // Some client libraries allow this to be "application/x-www-form-urlencoded". We can implement that upon + // request. The same libraries should also accept JSON. If implemented, choose based on "Accept" header. + httpapi.Write(ctx, rw, http.StatusOK, token) + } +} + +func authorizationCodeGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, defaultLifetime time.Duration, params tokenParams) (oauth2.Token, error) { + // Validate the client secret. + secret, err := parseSecret(params.clientSecret) + if err != nil { + return oauth2.Token{}, errBadSecret + } + //nolint:gocritic // Users cannot read secrets so we must use the system. + dbSecret, err := db.GetOAuth2ProviderAppSecretByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(secret.prefix)) + if errors.Is(err, sql.ErrNoRows) { + return oauth2.Token{}, errBadSecret + } + if err != nil { + return oauth2.Token{}, err + } + equal, err := userpassword.Compare(string(dbSecret.HashedSecret), secret.secret) + if err != nil { + return oauth2.Token{}, xerrors.Errorf("unable to compare secret: %w", err) + } + if !equal { + return oauth2.Token{}, errBadSecret + } + + // Validate the authorization code. + code, err := parseSecret(params.code) + if err != nil { + return oauth2.Token{}, errBadCode + } + //nolint:gocritic // There is no user yet so we must use the system. + dbCode, err := db.GetOAuth2ProviderAppCodeByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(code.prefix)) + if errors.Is(err, sql.ErrNoRows) { + return oauth2.Token{}, errBadCode + } + if err != nil { + return oauth2.Token{}, err + } + equal, err = userpassword.Compare(string(dbCode.HashedSecret), code.secret) + if err != nil { + return oauth2.Token{}, xerrors.Errorf("unable to compare code: %w", err) + } + if !equal { + return oauth2.Token{}, errBadCode + } + + // Ensure the code has not expired. + if dbCode.ExpiresAt.Before(dbtime.Now()) { + return oauth2.Token{}, errBadCode + } + + // Generate a refresh token. + refreshToken, err := GenerateSecret() + if err != nil { + return oauth2.Token{}, err + } + + // Generate the API key we will swap for the code. + // TODO: We are ignoring scopes for now. + tokenName := fmt.Sprintf("%s_%s_oauth_session_token", dbCode.UserID, app.ID) + key, sessionToken, err := apikey.Generate(apikey.CreateParams{ + UserID: dbCode.UserID, + LoginType: database.LoginTypeOAuth2ProviderApp, + // TODO: This is just the lifetime for api keys, maybe have its own config + // settings. #11693 + DefaultLifetime: defaultLifetime, + // For now, we allow only one token per app and user at a time. + TokenName: tokenName, + }) + if err != nil { + return oauth2.Token{}, err + } + + // Grab the user roles so we can perform the exchange as the user. + //nolint:gocritic // In the token exchange, there is no user actor. + roles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), dbCode.UserID) + if err != nil { + return oauth2.Token{}, err + } + userSubj := rbac.Subject{ + ID: dbCode.UserID.String(), + Roles: rbac.RoleNames(roles.Roles), + Groups: roles.Groups, + Scope: rbac.ScopeAll, + } + + // Do the actual token exchange in the database. + err = db.InTx(func(tx database.Store) error { + ctx := dbauthz.As(ctx, userSubj) + err = tx.DeleteOAuth2ProviderAppCodeByID(ctx, dbCode.ID) + if err != nil { + return xerrors.Errorf("delete oauth2 app code: %w", err) + } + + // Delete the previous key, if any. + prevKey, err := tx.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{ + UserID: dbCode.UserID, + TokenName: tokenName, + }) + if err == nil { + err = tx.DeleteAPIKeyByID(ctx, prevKey.ID) + } + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return xerrors.Errorf("delete api key by name: %w", err) + } + + newKey, err := tx.InsertAPIKey(ctx, key) + if err != nil { + return xerrors.Errorf("insert oauth2 access token: %w", err) + } + + _, err = tx.InsertOAuth2ProviderAppToken(ctx, database.InsertOAuth2ProviderAppTokenParams{ + ID: uuid.New(), + CreatedAt: dbtime.Now(), + ExpiresAt: key.ExpiresAt, + HashPrefix: []byte(refreshToken.Prefix), + RefreshHash: []byte(refreshToken.Hashed), + AppSecretID: dbSecret.ID, + APIKeyID: newKey.ID, + }) + if err != nil { + return xerrors.Errorf("insert oauth2 refresh token: %w", err) + } + return nil + }, nil) + if err != nil { + return oauth2.Token{}, err + } + + return oauth2.Token{ + AccessToken: sessionToken, + TokenType: "Bearer", + RefreshToken: refreshToken.Formatted, + Expiry: key.ExpiresAt, + }, nil +} + +func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, defaultLifetime time.Duration, params tokenParams) (oauth2.Token, error) { + // Validate the token. + token, err := parseSecret(params.refreshToken) + if err != nil { + return oauth2.Token{}, errBadToken + } + //nolint:gocritic // There is no user yet so we must use the system. + dbToken, err := db.GetOAuth2ProviderAppTokenByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(token.prefix)) + if errors.Is(err, sql.ErrNoRows) { + return oauth2.Token{}, errBadToken + } + if err != nil { + return oauth2.Token{}, err + } + equal, err := userpassword.Compare(string(dbToken.RefreshHash), token.secret) + if err != nil { + return oauth2.Token{}, xerrors.Errorf("unable to compare token: %w", err) + } + if !equal { + return oauth2.Token{}, errBadToken + } + + // Ensure the token has not expired. + if dbToken.ExpiresAt.Before(dbtime.Now()) { + return oauth2.Token{}, errBadToken + } + + // Grab the user roles so we can perform the refresh as the user. + //nolint:gocritic // There is no user yet so we must use the system. + prevKey, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), dbToken.APIKeyID) + if err != nil { + return oauth2.Token{}, err + } + //nolint:gocritic // There is no user yet so we must use the system. + roles, err := db.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), prevKey.UserID) + if err != nil { + return oauth2.Token{}, err + } + userSubj := rbac.Subject{ + ID: prevKey.UserID.String(), + Roles: rbac.RoleNames(roles.Roles), + Groups: roles.Groups, + Scope: rbac.ScopeAll, + } + + // Generate a new refresh token. + refreshToken, err := GenerateSecret() + if err != nil { + return oauth2.Token{}, err + } + + // Generate the new API key. + // TODO: We are ignoring scopes for now. + tokenName := fmt.Sprintf("%s_%s_oauth_session_token", prevKey.UserID, app.ID) + key, sessionToken, err := apikey.Generate(apikey.CreateParams{ + UserID: prevKey.UserID, + LoginType: database.LoginTypeOAuth2ProviderApp, + // TODO: This is just the lifetime for api keys, maybe have its own config + // settings. #11693 + DefaultLifetime: defaultLifetime, + // For now, we allow only one token per app and user at a time. + TokenName: tokenName, + }) + if err != nil { + return oauth2.Token{}, err + } + + // Replace the token. + err = db.InTx(func(tx database.Store) error { + ctx := dbauthz.As(ctx, userSubj) + err = tx.DeleteAPIKeyByID(ctx, prevKey.ID) // This cascades to the token. + if err != nil { + return xerrors.Errorf("delete oauth2 app token: %w", err) + } + + newKey, err := tx.InsertAPIKey(ctx, key) + if err != nil { + return xerrors.Errorf("insert oauth2 access token: %w", err) + } + + _, err = tx.InsertOAuth2ProviderAppToken(ctx, database.InsertOAuth2ProviderAppTokenParams{ + ID: uuid.New(), + CreatedAt: dbtime.Now(), + ExpiresAt: key.ExpiresAt, + HashPrefix: []byte(refreshToken.Prefix), + RefreshHash: []byte(refreshToken.Hashed), + AppSecretID: dbToken.AppSecretID, + APIKeyID: newKey.ID, + }) + if err != nil { + return xerrors.Errorf("insert oauth2 refresh token: %w", err) + } + return nil + }, nil) + if err != nil { + return oauth2.Token{}, err + } + + return oauth2.Token{ + AccessToken: sessionToken, + TokenType: "Bearer", + RefreshToken: refreshToken.Formatted, + Expiry: key.ExpiresAt, + }, nil +} diff --git a/enterprise/coderd/jfrog.go b/enterprise/coderd/jfrog.go index 7195aee908..9262c673eb 100644 --- a/enterprise/coderd/jfrog.go +++ b/enterprise/coderd/jfrog.go @@ -67,8 +67,8 @@ func (api *API) jFrogXrayScan(rw http.ResponseWriter, r *http.Request) { ctx = r.Context() vals = r.URL.Query() p = httpapi.NewQueryParamParser() - wsID = p.Required("workspace_id").UUID(vals, uuid.UUID{}, "workspace_id") - agentID = p.Required("agent_id").UUID(vals, uuid.UUID{}, "agent_id") + wsID = p.RequiredNotEmpty("workspace_id").UUID(vals, uuid.UUID{}, "workspace_id") + agentID = p.RequiredNotEmpty("agent_id").UUID(vals, uuid.UUID{}, "agent_id") ) if len(p.Errors) > 0 { diff --git a/enterprise/coderd/oauth2.go b/enterprise/coderd/oauth2.go index 3ebb39aaee..0f016d6533 100644 --- a/enterprise/coderd/oauth2.go +++ b/enterprise/coderd/oauth2.go @@ -1,7 +1,7 @@ package coderd import ( - "crypto/sha256" + "fmt" "net/http" "github.com/google/uuid" @@ -13,7 +13,7 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/cryptorand" + "github.com/coder/coder/v2/enterprise/coderd/identityprovider" ) func (api *API) oAuth2ProviderMiddleware(next http.Handler) http.Handler { @@ -45,16 +45,43 @@ func (api *API) oAuth2ProviderMiddleware(next http.Handler) http.Handler { // @Security CoderSessionToken // @Produce json // @Tags Enterprise +// @Param user_id query string false "Filter by applications authorized for a user" // @Success 200 {array} codersdk.OAuth2ProviderApp // @Router /oauth2-provider/apps [get] func (api *API) oAuth2ProviderApps(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - dbApps, err := api.Database.GetOAuth2ProviderApps(ctx) + + rawUserID := r.URL.Query().Get("user_id") + if rawUserID == "" { + dbApps, err := api.Database.GetOAuth2ProviderApps(ctx) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + httpapi.Write(ctx, rw, http.StatusOK, db2sdk.OAuth2ProviderApps(api.AccessURL, dbApps)) + return + } + + userID, err := uuid.Parse(rawUserID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid user UUID", + Detail: fmt.Sprintf("queried user_id=%q", userID), + }) + return + } + + userApps, err := api.Database.GetOAuth2ProviderAppsByUserID(ctx, userID) if err != nil { httpapi.InternalServerError(rw, err) return } - httpapi.Write(ctx, rw, http.StatusOK, db2sdk.OAuth2ProviderApps(api.AccessURL, dbApps)) + + var sdkApps []codersdk.OAuth2ProviderApp + for _, app := range userApps { + sdkApps = append(sdkApps, db2sdk.OAuth2ProviderApp(api.AccessURL, app.OAuth2ProviderApp)) + } + httpapi.Write(ctx, rw, http.StatusOK, sdkApps) } // @Summary Get OAuth2 application. @@ -130,7 +157,7 @@ func (api *API) putOAuth2ProviderApp(rw http.ResponseWriter, r *http.Request) { }) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error creating OAuth2 application.", + Message: "Internal error updating OAuth2 application.", Detail: err.Error(), }) return @@ -200,23 +227,23 @@ func (api *API) oAuth2ProviderAppSecrets(rw http.ResponseWriter, r *http.Request func (api *API) postOAuth2ProviderAppSecret(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() app := httpmw.OAuth2ProviderApp(r) - // 40 characters matches the length of GitHub's client secrets. - rawSecret, err := cryptorand.String(40) + secret, err := identityprovider.GenerateSecret() if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to generate OAuth2 client secret.", + Detail: err.Error(), }) return } - hashed := sha256.Sum256([]byte(rawSecret)) - secret, err := api.Database.InsertOAuth2ProviderAppSecret(ctx, database.InsertOAuth2ProviderAppSecretParams{ + dbSecret, err := api.Database.InsertOAuth2ProviderAppSecret(ctx, database.InsertOAuth2ProviderAppSecretParams{ ID: uuid.New(), CreatedAt: dbtime.Now(), - HashedSecret: hashed[:], + SecretPrefix: []byte(secret.Prefix), + HashedSecret: []byte(secret.Hashed), // DisplaySecret is the last six characters of the original unhashed secret. // This is done so they can be differentiated and it matches how GitHub // displays their client secrets. - DisplaySecret: rawSecret[len(rawSecret)-6:], + DisplaySecret: secret.Formatted[len(secret.Formatted)-6:], AppID: app.ID, }) if err != nil { @@ -227,8 +254,8 @@ func (api *API) postOAuth2ProviderAppSecret(rw http.ResponseWriter, r *http.Requ return } httpapi.Write(ctx, rw, http.StatusOK, codersdk.OAuth2ProviderAppSecretFull{ - ID: secret.ID, - ClientSecretFull: rawSecret, + ID: dbSecret.ID, + ClientSecretFull: secret.Formatted, }) } @@ -253,3 +280,44 @@ func (api *API) deleteOAuth2ProviderAppSecret(rw http.ResponseWriter, r *http.Re } httpapi.Write(ctx, rw, http.StatusNoContent, nil) } + +// @Summary OAuth2 authorization request. +// @ID oauth2-authorization-request +// @Security CoderSessionToken +// @Tags Enterprise +// @Param client_id query string true "Client ID" +// @Param state query string true "A random unguessable string" +// @Param response_type query codersdk.OAuth2ProviderResponseType true "Response type" +// @Param redirect_uri query string false "Redirect here after authorization" +// @Param scope query string false "Token scopes (currently ignored)" +// @Success 302 +// @Router /login/oauth2/authorize [post] +func (api *API) postOAuth2ProviderAppAuthorize() http.HandlerFunc { + return identityprovider.Authorize(api.Database, api.AccessURL) +} + +// @Summary OAuth2 token exchange. +// @ID oauth2-token-exchange +// @Produce json +// @Tags Enterprise +// @Param client_id formData string false "Client ID, required if grant_type=authorization_code" +// @Param client_secret formData string false "Client secret, required if grant_type=authorization_code" +// @Param code formData string false "Authorization code, required if grant_type=authorization_code" +// @Param refresh_token formData string false "Refresh token, required if grant_type=refresh_token" +// @Param grant_type formData codersdk.OAuth2ProviderGrantType true "Grant type" +// @Success 200 {object} oauth2.Token +// @Router /login/oauth2/tokens [post] +func (api *API) postOAuth2ProviderAppToken() http.HandlerFunc { + return identityprovider.Tokens(api.Database, api.DeploymentValues.SessionDuration.Value()) +} + +// @Summary Delete OAuth2 application tokens. +// @ID delete-oauth2-application-tokens +// @Security CoderSessionToken +// @Tags Enterprise +// @Param client_id query string true "Client ID" +// @Success 204 +// @Router /login/oauth2/tokens [delete] +func (api *API) deleteOAuth2ProviderAppTokens() http.HandlerFunc { + return identityprovider.RevokeApp(api.Database) +} diff --git a/enterprise/coderd/oauth2_test.go b/enterprise/coderd/oauth2_test.go index 8a2e4df7bc..94d221882a 100644 --- a/enterprise/coderd/oauth2_test.go +++ b/enterprise/coderd/oauth2_test.go @@ -1,19 +1,34 @@ package coderd_test import ( - "strconv" + "context" + "fmt" + "net/http" + "net/url" + "path" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "github.com/coder/coder/v2/coderd/apikey" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/userpassword" + "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/identityprovider" "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/testutil" ) -func TestOAuthApps(t *testing.T) { +func TestOAuth2ProviderApps(t *testing.T) { t.Parallel() t.Run("Validation", func(t *testing.T) { @@ -25,7 +40,7 @@ func TestOAuthApps(t *testing.T) { }, }}) - ctx := testutil.Context(t, testutil.WaitLong) + topCtx := testutil.Context(t, testutil.WaitLong) tests := []struct { name string @@ -128,7 +143,7 @@ func TestOAuthApps(t *testing.T) { CallbackURL: "http://coder.com", } //nolint:gocritic // OAauth2 app management requires owner permission. - _, err := client.PostOAuth2ProviderApp(ctx, req) + _, err := client.PostOAuth2ProviderApp(topCtx, req) require.NoError(t, err) // Generate an application for testing PUTs. @@ -137,13 +152,14 @@ func TestOAuthApps(t *testing.T) { CallbackURL: "http://coder.com", } //nolint:gocritic // OAauth2 app management requires owner permission. - existingApp, err := client.PostOAuth2ProviderApp(ctx, req) + existingApp, err := client.PostOAuth2ProviderApp(topCtx, req) require.NoError(t, err) for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) //nolint:gocritic // OAauth2 app management requires owner permission. _, err := client.PostOAuth2ProviderApp(ctx, test.req) @@ -162,71 +178,62 @@ func TestOAuthApps(t *testing.T) { t.Run("DeleteNonExisting", func(t *testing.T) { t.Parallel() - client, _ := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ + client, owner := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ codersdk.FeatureOAuth2Provider: 1, }, }}) + another, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) ctx := testutil.Context(t, testutil.WaitLong) - //nolint:gocritic // OAauth2 app management requires owner permission. - _, err := client.OAuth2ProviderApp(ctx, uuid.New()) + _, err := another.OAuth2ProviderApp(ctx, uuid.New()) require.Error(t, err) }) t.Run("OK", func(t *testing.T) { t.Parallel() - client, _ := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ + client, owner := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ codersdk.FeatureOAuth2Provider: 1, }, }}) + another, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) ctx := testutil.Context(t, testutil.WaitLong) // No apps yet. - //nolint:gocritic // OAauth2 app management requires owner permission. - apps, err := client.OAuth2ProviderApps(ctx) + apps, err := another.OAuth2ProviderApps(ctx, codersdk.OAuth2ProviderAppFilter{}) require.NoError(t, err) require.Len(t, apps, 0) // Should be able to add apps. - expected := []codersdk.OAuth2ProviderApp{} - for i := 0; i < 5; i++ { - postReq := codersdk.PostOAuth2ProviderAppRequest{ - Name: "foo-" + strconv.Itoa(i), - CallbackURL: "http://" + strconv.Itoa(i) + ".localhost:3000", - } - //nolint:gocritic // OAauth2 app management requires owner permission. - app, err := client.PostOAuth2ProviderApp(ctx, postReq) - require.NoError(t, err) - require.Equal(t, postReq.Name, app.Name) - require.Equal(t, postReq.CallbackURL, app.CallbackURL) - expected = append(expected, app) + expected := generateApps(ctx, t, client, "get-apps") + expectedOrder := []codersdk.OAuth2ProviderApp{ + expected.Default, expected.NoPort, expected.Subdomain, + expected.Extra[0], expected.Extra[1], } // Should get all the apps now. - //nolint:gocritic // OAauth2 app management requires owner permission. - apps, err = client.OAuth2ProviderApps(ctx) + apps, err = another.OAuth2ProviderApps(ctx, codersdk.OAuth2ProviderAppFilter{}) require.NoError(t, err) require.Len(t, apps, 5) - require.Equal(t, expected, apps) + require.Equal(t, expectedOrder, apps) // Should be able to keep the same name when updating. req := codersdk.PutOAuth2ProviderAppRequest{ - Name: expected[0].Name, + Name: expected.Default.Name, CallbackURL: "http://coder.com", Icon: "test", } //nolint:gocritic // OAauth2 app management requires owner permission. - newApp, err := client.PutOAuth2ProviderApp(ctx, expected[0].ID, req) + newApp, err := client.PutOAuth2ProviderApp(ctx, expected.Default.ID, req) require.NoError(t, err) require.Equal(t, req.Name, newApp.Name) require.Equal(t, req.CallbackURL, newApp.CallbackURL) require.Equal(t, req.Icon, newApp.Icon) - require.Equal(t, expected[0].ID, newApp.ID) + require.Equal(t, expected.Default.ID, newApp.ID) // Should be able to update name. req = codersdk.PutOAuth2ProviderAppRequest{ @@ -235,34 +242,50 @@ func TestOAuthApps(t *testing.T) { Icon: "test", } //nolint:gocritic // OAauth2 app management requires owner permission. - newApp, err = client.PutOAuth2ProviderApp(ctx, expected[0].ID, req) + newApp, err = client.PutOAuth2ProviderApp(ctx, expected.Default.ID, req) require.NoError(t, err) require.Equal(t, req.Name, newApp.Name) require.Equal(t, req.CallbackURL, newApp.CallbackURL) require.Equal(t, req.Icon, newApp.Icon) - require.Equal(t, expected[0].ID, newApp.ID) + require.Equal(t, expected.Default.ID, newApp.ID) // Should be able to get a single app. - //nolint:gocritic // OAauth2 app management requires owner permission. - got, err := client.OAuth2ProviderApp(ctx, expected[0].ID) + got, err := another.OAuth2ProviderApp(ctx, expected.Default.ID) require.NoError(t, err) require.Equal(t, newApp, got) // Should be able to delete an app. //nolint:gocritic // OAauth2 app management requires owner permission. - err = client.DeleteOAuth2ProviderApp(ctx, expected[0].ID) + err = client.DeleteOAuth2ProviderApp(ctx, expected.Default.ID) require.NoError(t, err) // Should show the new count. - //nolint:gocritic // OAauth2 app management requires owner permission. - newApps, err := client.OAuth2ProviderApps(ctx) + newApps, err := another.OAuth2ProviderApps(ctx, codersdk.OAuth2ProviderAppFilter{}) require.NoError(t, err) require.Len(t, newApps, 4) - require.Equal(t, expected[1:], newApps) + + require.Equal(t, expectedOrder[1:], newApps) + }) + + t.Run("ByUser", func(t *testing.T) { + t.Parallel() + client, owner := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureOAuth2Provider: 1, + }, + }}) + another, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) + ctx := testutil.Context(t, testutil.WaitLong) + _ = generateApps(ctx, t, client, "by-user") + apps, err := another.OAuth2ProviderApps(ctx, codersdk.OAuth2ProviderAppFilter{ + UserID: user.ID, + }) + require.NoError(t, err) + require.Len(t, apps, 0) }) } -func TestOAuthAppSecrets(t *testing.T) { +func TestOAuth2ProviderAppSecrets(t *testing.T) { t.Parallel() client, _ := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ @@ -274,19 +297,7 @@ func TestOAuthAppSecrets(t *testing.T) { topCtx := testutil.Context(t, testutil.WaitLong) // Make some apps. - //nolint:gocritic // OAauth2 app management requires owner permission. - app1, err := client.PostOAuth2ProviderApp(topCtx, codersdk.PostOAuth2ProviderAppRequest{ - Name: "razzle-dazzle", - CallbackURL: "http://localhost", - }) - require.NoError(t, err) - - //nolint:gocritic // OAauth2 app management requires owner permission. - app2, err := client.PostOAuth2ProviderApp(topCtx, codersdk.PostOAuth2ProviderAppRequest{ - Name: "razzle-dazzle-the-sequel", - CallbackURL: "http://localhost", - }) - require.NoError(t, err) + apps := generateApps(topCtx, t, client, "app-secrets") t.Run("DeleteNonExisting", func(t *testing.T) { t.Parallel() @@ -294,7 +305,7 @@ func TestOAuthAppSecrets(t *testing.T) { // Should not be able to create secrets for a non-existent app. //nolint:gocritic // OAauth2 app management requires owner permission. - _, err = client.OAuth2ProviderAppSecrets(ctx, uuid.New()) + _, err := client.OAuth2ProviderAppSecrets(ctx, uuid.New()) require.Error(t, err) // Should not be able to delete non-existing secrets when there is no app. @@ -304,16 +315,16 @@ func TestOAuthAppSecrets(t *testing.T) { // Should not be able to delete non-existing secrets when the app exists. //nolint:gocritic // OAauth2 app management requires owner permission. - err = client.DeleteOAuth2ProviderAppSecret(ctx, app1.ID, uuid.New()) + err = client.DeleteOAuth2ProviderAppSecret(ctx, apps.Default.ID, uuid.New()) require.Error(t, err) // Should not be able to delete an existing secret with the wrong app ID. //nolint:gocritic // OAauth2 app management requires owner permission. - secret, err := client.PostOAuth2ProviderAppSecret(ctx, app2.ID) + secret, err := client.PostOAuth2ProviderAppSecret(ctx, apps.NoPort.ID) require.NoError(t, err) //nolint:gocritic // OAauth2 app management requires owner permission. - err = client.DeleteOAuth2ProviderAppSecret(ctx, app1.ID, secret.ID) + err = client.DeleteOAuth2ProviderAppSecret(ctx, apps.Default.ID, secret.ID) require.Error(t, err) }) @@ -323,26 +334,26 @@ func TestOAuthAppSecrets(t *testing.T) { // No secrets yet. //nolint:gocritic // OAauth2 app management requires owner permission. - secrets, err := client.OAuth2ProviderAppSecrets(ctx, app1.ID) + secrets, err := client.OAuth2ProviderAppSecrets(ctx, apps.Default.ID) require.NoError(t, err) require.Len(t, secrets, 0) // Should be able to create secrets. for i := 0; i < 5; i++ { //nolint:gocritic // OAauth2 app management requires owner permission. - secret, err := client.PostOAuth2ProviderAppSecret(ctx, app1.ID) + secret, err := client.PostOAuth2ProviderAppSecret(ctx, apps.Default.ID) require.NoError(t, err) require.NotEmpty(t, secret.ClientSecretFull) require.True(t, len(secret.ClientSecretFull) > 6) //nolint:gocritic // OAauth2 app management requires owner permission. - _, err = client.PostOAuth2ProviderAppSecret(ctx, app2.ID) + _, err = client.PostOAuth2ProviderAppSecret(ctx, apps.NoPort.ID) require.NoError(t, err) } // Should get secrets now, but only for the one app. //nolint:gocritic // OAauth2 app management requires owner permission. - secrets, err = client.OAuth2ProviderAppSecrets(ctx, app1.ID) + secrets, err = client.OAuth2ProviderAppSecrets(ctx, apps.Default.ID) require.NoError(t, err) require.Len(t, secrets, 5) for _, secret := range secrets { @@ -351,19 +362,779 @@ func TestOAuthAppSecrets(t *testing.T) { // Should be able to delete a secret. //nolint:gocritic // OAauth2 app management requires owner permission. - err = client.DeleteOAuth2ProviderAppSecret(ctx, app1.ID, secrets[0].ID) + err = client.DeleteOAuth2ProviderAppSecret(ctx, apps.Default.ID, secrets[0].ID) require.NoError(t, err) - secrets, err = client.OAuth2ProviderAppSecrets(ctx, app1.ID) + secrets, err = client.OAuth2ProviderAppSecrets(ctx, apps.Default.ID) require.NoError(t, err) require.Len(t, secrets, 4) // No secrets once the app is deleted. //nolint:gocritic // OAauth2 app management requires owner permission. - err = client.DeleteOAuth2ProviderApp(ctx, app1.ID) + err = client.DeleteOAuth2ProviderApp(ctx, apps.Default.ID) require.NoError(t, err) //nolint:gocritic // OAauth2 app management requires owner permission. - _, err = client.OAuth2ProviderAppSecrets(ctx, app1.ID) + _, err = client.OAuth2ProviderAppSecrets(ctx, apps.Default.ID) require.Error(t, err) }) } + +func TestOAuth2ProviderTokenExchange(t *testing.T) { + t.Parallel() + + db, pubsub := dbtestutil.NewDB(t) + ownerClient, owner := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureOAuth2Provider: 1, + }, + }, + }) + topCtx := testutil.Context(t, testutil.WaitLong) + apps := generateApps(topCtx, t, ownerClient, "token-exchange") + + //nolint:gocritic // OAauth2 app management requires owner permission. + secret, err := ownerClient.PostOAuth2ProviderAppSecret(topCtx, apps.Default.ID) + require.NoError(t, err) + + // The typical oauth2 flow from this point is: + // Create an oauth2.Config using the id, secret, endpoints, and redirect: + // cfg := oauth2.Config{ ... } + // Display url for the user to click: + // userClickURL := cfg.AuthCodeURL("random_state") + // userClickURL looks like: https://idp url/authorize? + // client_id=... + // response_type=code + // redirect_uri=.. (back to backstage url) .. + // scope=... + // state=... + // *1* User clicks "Allow" on provided page above + // The redirect_uri is followed which sends back to backstage with the code and state + // Now backstage has the info to do a cfg.Exchange() in the back to get an access token. + // + // ---NOTE---: If the user has already approved this oauth app, then *1* is optional. + // Coder can just immediately redirect back to backstage without user intervention. + tests := []struct { + name string + app codersdk.OAuth2ProviderApp + // The flow is setup(ctx, client, user) -> preAuth(cfg) -> cfg.AuthCodeURL() -> preToken(cfg) -> cfg.Exchange() + setup func(context.Context, *codersdk.Client, codersdk.User) error + preAuth func(valid *oauth2.Config) + authError string + preToken func(valid *oauth2.Config) + tokenError string + + // If null, assume the code should be valid. + defaultCode *string + // custom allows some more advanced manipulation of the oauth2 exchange. + exchangeMutate []oauth2.AuthCodeOption + }{ + { + name: "AuthInParams", + app: apps.Default, + preAuth: func(valid *oauth2.Config) { + valid.Endpoint.AuthStyle = oauth2.AuthStyleInParams + }, + }, + { + name: "AuthInvalidAppID", + app: apps.Default, + preAuth: func(valid *oauth2.Config) { + valid.ClientID = uuid.NewString() + }, + authError: "Resource not found", + }, + { + name: "TokenInvalidAppID", + app: apps.Default, + preToken: func(valid *oauth2.Config) { + valid.ClientID = uuid.NewString() + }, + tokenError: "Resource not found", + }, + { + name: "InvalidPort", + app: apps.NoPort, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Host = newURL.Hostname() + ":8081" + valid.RedirectURL = newURL.String() + }, + authError: "Invalid query params", + }, + { + name: "WrongAppHost", + app: apps.Default, + preAuth: func(valid *oauth2.Config) { + valid.RedirectURL = apps.NoPort.CallbackURL + }, + authError: "Invalid query params", + }, + { + name: "InvalidHostPrefix", + app: apps.NoPort, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Host = "prefix" + newURL.Hostname() + valid.RedirectURL = newURL.String() + }, + authError: "Invalid query params", + }, + { + name: "InvalidHost", + app: apps.NoPort, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Host = "invalid" + valid.RedirectURL = newURL.String() + }, + authError: "Invalid query params", + }, + { + name: "InvalidHostAndPort", + app: apps.NoPort, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Host = "invalid:8080" + valid.RedirectURL = newURL.String() + }, + authError: "Invalid query params", + }, + { + name: "InvalidPath", + app: apps.Default, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Path = path.Join("/prepend", newURL.Path) + valid.RedirectURL = newURL.String() + }, + authError: "Invalid query params", + }, + { + name: "MissingPath", + app: apps.Default, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Path = "/" + valid.RedirectURL = newURL.String() + }, + authError: "Invalid query params", + }, + { + // TODO: This is valid for now, but should it be? + name: "DifferentProtocol", + app: apps.Default, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Scheme = "https" + valid.RedirectURL = newURL.String() + }, + }, + { + name: "NestedPath", + app: apps.Default, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Path = path.Join(newURL.Path, "nested") + valid.RedirectURL = newURL.String() + }, + }, + { + // Some oauth implementations allow this, but our users can host + // at subdomains. So we should not. + name: "Subdomain", + app: apps.Default, + preAuth: func(valid *oauth2.Config) { + newURL := must(url.Parse(valid.RedirectURL)) + newURL.Host = "sub." + newURL.Host + valid.RedirectURL = newURL.String() + }, + authError: "Invalid query params", + }, + { + name: "NoSecretScheme", + app: apps.Default, + preToken: func(valid *oauth2.Config) { + valid.ClientSecret = "1234_4321" + }, + tokenError: "Invalid client secret", + }, + { + name: "InvalidSecretScheme", + app: apps.Default, + preToken: func(valid *oauth2.Config) { + valid.ClientSecret = "notcoder_1234_4321" + }, + tokenError: "Invalid client secret", + }, + { + name: "MissingSecretSecret", + app: apps.Default, + preToken: func(valid *oauth2.Config) { + valid.ClientSecret = "coder_1234" + }, + tokenError: "Invalid client secret", + }, + { + name: "MissingSecretPrefix", + app: apps.Default, + preToken: func(valid *oauth2.Config) { + valid.ClientSecret = "coder__1234" + }, + tokenError: "Invalid client secret", + }, + { + name: "InvalidSecretPrefix", + app: apps.Default, + preToken: func(valid *oauth2.Config) { + valid.ClientSecret = "coder_1234_4321" + }, + tokenError: "Invalid client secret", + }, + { + name: "MissingSecret", + app: apps.Default, + preToken: func(valid *oauth2.Config) { + valid.ClientSecret = "" + }, + tokenError: "Invalid query params", + }, + { + name: "NoCodeScheme", + app: apps.Default, + defaultCode: ptr.Ref("1234_4321"), + tokenError: "Invalid code", + }, + { + name: "InvalidCodeScheme", + app: apps.Default, + defaultCode: ptr.Ref("notcoder_1234_4321"), + tokenError: "Invalid code", + }, + { + name: "MissingCodeSecret", + app: apps.Default, + defaultCode: ptr.Ref("coder_1234"), + tokenError: "Invalid code", + }, + { + name: "MissingCodePrefix", + app: apps.Default, + defaultCode: ptr.Ref("coder__1234"), + tokenError: "Invalid code", + }, + { + name: "InvalidCodePrefix", + app: apps.Default, + defaultCode: ptr.Ref("coder_1234_4321"), + tokenError: "Invalid code", + }, + { + name: "MissingCode", + app: apps.Default, + defaultCode: ptr.Ref(""), + tokenError: "Invalid query params", + }, + { + name: "InvalidGrantType", + app: apps.Default, + tokenError: "Invalid query params", + exchangeMutate: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("grant_type", "foobar"), + }, + }, + { + name: "EmptyGrantType", + app: apps.Default, + tokenError: "Invalid query params", + exchangeMutate: []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("grant_type", ""), + }, + }, + { + name: "ExpiredCode", + app: apps.Default, + defaultCode: ptr.Ref("coder_prefix_code"), + tokenError: "Invalid code", + setup: func(ctx context.Context, client *codersdk.Client, user codersdk.User) error { + // Insert an expired code. + hashedCode, err := userpassword.Hash("prefix_code") + if err != nil { + return err + } + _, err = db.InsertOAuth2ProviderAppCode(ctx, database.InsertOAuth2ProviderAppCodeParams{ + ID: uuid.New(), + CreatedAt: dbtime.Now().Add(-time.Minute * 11), + ExpiresAt: dbtime.Now().Add(-time.Minute), + SecretPrefix: []byte("prefix"), + HashedSecret: []byte(hashedCode), + AppID: apps.Default.ID, + UserID: user.ID, + }) + return err + }, + }, + { + name: "OK", + app: apps.Default, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + // Each test gets its own user, since we allow only one code per user and + // app at a time and running tests in parallel could clobber each other. + userClient, user := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID) + if test.setup != nil { + err := test.setup(ctx, userClient, user) + require.NoError(t, err) + } + + // Each test gets its own oauth2.Config so they can run in parallel. + // In practice, you would only use 1 as a singleton. + valid := &oauth2.Config{ + ClientID: test.app.ID.String(), + ClientSecret: secret.ClientSecretFull, + Endpoint: oauth2.Endpoint{ + AuthURL: test.app.Endpoints.Authorization, + DeviceAuthURL: test.app.Endpoints.DeviceAuth, + TokenURL: test.app.Endpoints.Token, + // TODO: @emyrk we should support both types. + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: test.app.CallbackURL, + Scopes: []string{}, + } + + if test.preAuth != nil { + test.preAuth(valid) + } + + var code string + if test.defaultCode != nil { + code = *test.defaultCode + } else { + var err error + code, err = authorizationFlow(ctx, userClient, valid) + if test.authError != "" { + require.Error(t, err) + require.ErrorContains(t, err, test.authError) + // If this errors the token exchange will fail. So end here. + return + } + require.NoError(t, err) + } + + // Mutate the valid config for the exchange. + if test.preToken != nil { + test.preToken(valid) + } + + // Do the actual exchange. + token, err := valid.Exchange(ctx, code, test.exchangeMutate...) + if test.tokenError != "" { + require.Error(t, err) + require.ErrorContains(t, err, test.tokenError) + } else { + require.NoError(t, err) + require.NotEmpty(t, token.AccessToken) + require.True(t, time.Now().After(token.Expiry)) + + // Check that the token works. + newClient := codersdk.New(userClient.URL) + newClient.SetSessionToken(token.AccessToken) + + gotUser, err := newClient.User(ctx, codersdk.Me) + require.NoError(t, err) + require.Equal(t, user.ID, gotUser.ID) + } + }) + } +} + +func TestOAuth2ProviderTokenRefresh(t *testing.T) { + t.Parallel() + topCtx := testutil.Context(t, testutil.WaitLong) + + db, pubsub := dbtestutil.NewDB(t) + ownerClient, owner := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureOAuth2Provider: 1, + }, + }, + }) + apps := generateApps(topCtx, t, ownerClient, "token-refresh") + + //nolint:gocritic // OAauth2 app management requires owner permission. + secret, err := ownerClient.PostOAuth2ProviderAppSecret(topCtx, apps.Default.ID) + require.NoError(t, err) + + // One path not tested here is when the token is empty, because Go's OAuth2 + // client library will not even try to make the request. + tests := []struct { + name string + app codersdk.OAuth2ProviderApp + // If null, assume the token should be valid. + defaultToken *string + error string + expires time.Time + }{ + { + name: "NoTokenScheme", + app: apps.Default, + defaultToken: ptr.Ref("1234_4321"), + error: "Invalid token", + }, + { + name: "InvalidTokenScheme", + app: apps.Default, + defaultToken: ptr.Ref("notcoder_1234_4321"), + error: "Invalid token", + }, + { + name: "MissingTokenSecret", + app: apps.Default, + defaultToken: ptr.Ref("coder_1234"), + error: "Invalid token", + }, + { + name: "MissingTokenPrefix", + app: apps.Default, + defaultToken: ptr.Ref("coder__1234"), + error: "Invalid token", + }, + { + name: "InvalidTokenPrefix", + app: apps.Default, + defaultToken: ptr.Ref("coder_1234_4321"), + error: "Invalid token", + }, + { + name: "Expired", + app: apps.Default, + expires: time.Now().Add(time.Minute * -1), + error: "Invalid token", + }, + { + name: "OK", + app: apps.Default, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + userClient, user := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID) + + // Insert the token and its key. + key, sessionToken, err := apikey.Generate(apikey.CreateParams{ + UserID: user.ID, + LoginType: database.LoginTypeOAuth2ProviderApp, + ExpiresAt: time.Now().Add(time.Hour * 10), + }) + require.NoError(t, err) + + newKey, err := db.InsertAPIKey(ctx, key) + require.NoError(t, err) + + token, err := identityprovider.GenerateSecret() + require.NoError(t, err) + + expires := test.expires + if expires.IsZero() { + expires = time.Now().Add(time.Hour * 10) + } + + _, err = db.InsertOAuth2ProviderAppToken(ctx, database.InsertOAuth2ProviderAppTokenParams{ + ID: uuid.New(), + CreatedAt: dbtime.Now(), + ExpiresAt: expires, + HashPrefix: []byte(token.Prefix), + RefreshHash: []byte(token.Hashed), + AppSecretID: secret.ID, + APIKeyID: newKey.ID, + }) + require.NoError(t, err) + + // Check that the key works. + newClient := codersdk.New(userClient.URL) + newClient.SetSessionToken(sessionToken) + gotUser, err := newClient.User(ctx, codersdk.Me) + require.NoError(t, err) + require.Equal(t, user.ID, gotUser.ID) + + cfg := &oauth2.Config{ + ClientID: test.app.ID.String(), + ClientSecret: secret.ClientSecretFull, + Endpoint: oauth2.Endpoint{ + AuthURL: test.app.Endpoints.Authorization, + DeviceAuthURL: test.app.Endpoints.DeviceAuth, + TokenURL: test.app.Endpoints.Token, + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: test.app.CallbackURL, + Scopes: []string{}, + } + + // Test whether it can be refreshed. + refreshToken := token.Formatted + if test.defaultToken != nil { + refreshToken = *test.defaultToken + } + refreshed, err := cfg.TokenSource(ctx, &oauth2.Token{ + AccessToken: sessionToken, + RefreshToken: refreshToken, + Expiry: time.Now().Add(time.Minute * -1), + }).Token() + + if test.error != "" { + require.Error(t, err) + require.ErrorContains(t, err, test.error) + } else { + require.NoError(t, err) + require.NotEmpty(t, refreshed.AccessToken) + + // Old token is now invalid. + _, err = newClient.User(ctx, codersdk.Me) + require.Error(t, err) + require.ErrorContains(t, err, "401") + + // Refresh token is valid. + newClient := codersdk.New(userClient.URL) + newClient.SetSessionToken(refreshed.AccessToken) + + gotUser, err := newClient.User(ctx, codersdk.Me) + require.NoError(t, err) + require.Equal(t, user.ID, gotUser.ID) + } + }) + } +} + +type exchangeSetup struct { + cfg *oauth2.Config + app codersdk.OAuth2ProviderApp + secret codersdk.OAuth2ProviderAppSecretFull + code string +} + +func TestOAuth2ProviderRevoke(t *testing.T) { + t.Parallel() + + client, owner := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureOAuth2Provider: 1, + }, + }}) + + tests := []struct { + name string + // fn performs some action that removes the user's code and token. + fn func(context.Context, *codersdk.Client, exchangeSetup) + // replacesToken specifies whether the action replaces the token or only + // deletes it. + replacesToken bool + }{ + { + name: "DeleteApp", + fn: func(ctx context.Context, _ *codersdk.Client, s exchangeSetup) { + //nolint:gocritic // OAauth2 app management requires owner permission. + err := client.DeleteOAuth2ProviderApp(ctx, s.app.ID) + require.NoError(t, err) + }, + }, + { + name: "DeleteSecret", + fn: func(ctx context.Context, _ *codersdk.Client, s exchangeSetup) { + //nolint:gocritic // OAauth2 app management requires owner permission. + err := client.DeleteOAuth2ProviderAppSecret(ctx, s.app.ID, s.secret.ID) + require.NoError(t, err) + }, + }, + { + name: "DeleteToken", + fn: func(ctx context.Context, client *codersdk.Client, s exchangeSetup) { + err := client.RevokeOAuth2ProviderApp(ctx, s.app.ID) + require.NoError(t, err) + }, + }, + { + name: "OverrideCodeAndToken", + fn: func(ctx context.Context, client *codersdk.Client, s exchangeSetup) { + // Generating a new code should wipe out the old code. + code, err := authorizationFlow(ctx, client, s.cfg) + require.NoError(t, err) + + // Generating a new token should wipe out the old token. + _, err = s.cfg.Exchange(ctx, code) + require.NoError(t, err) + }, + replacesToken: true, + }, + } + + setup := func(ctx context.Context, testClient *codersdk.Client, name string) exchangeSetup { + // We need a new app each time because we only allow one code and token per + // app and user at the moment and because the test might delete the app. + //nolint:gocritic // OAauth2 app management requires owner permission. + app, err := client.PostOAuth2ProviderApp(ctx, codersdk.PostOAuth2ProviderAppRequest{ + Name: name, + CallbackURL: "http://localhost", + }) + require.NoError(t, err) + + // We need a new secret every time because the test might delete the secret. + //nolint:gocritic // OAauth2 app management requires owner permission. + secret, err := client.PostOAuth2ProviderAppSecret(ctx, app.ID) + require.NoError(t, err) + + cfg := &oauth2.Config{ + ClientID: app.ID.String(), + ClientSecret: secret.ClientSecretFull, + Endpoint: oauth2.Endpoint{ + AuthURL: app.Endpoints.Authorization, + DeviceAuthURL: app.Endpoints.DeviceAuth, + TokenURL: app.Endpoints.Token, + AuthStyle: oauth2.AuthStyleInParams, + }, + RedirectURL: app.CallbackURL, + Scopes: []string{}, + } + + // Go through the auth flow to get a code. + code, err := authorizationFlow(ctx, testClient, cfg) + require.NoError(t, err) + + return exchangeSetup{ + cfg: cfg, + app: app, + secret: secret, + code: code, + } + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + testClient, testUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) + + testEntities := setup(ctx, testClient, test.name+"-1") + + // Delete before the exchange completes (code should delete and attempting + // to finish the exchange should fail). + test.fn(ctx, testClient, testEntities) + + // Exchange should fail because the code should be gone. + _, err := testEntities.cfg.Exchange(ctx, testEntities.code) + require.Error(t, err) + + // Try again, this time letting the exchange complete first. + testEntities = setup(ctx, testClient, test.name+"-2") + token, err := testEntities.cfg.Exchange(ctx, testEntities.code) + require.NoError(t, err) + + // Validate the returned access token and that the app is listed. + newClient := codersdk.New(client.URL) + newClient.SetSessionToken(token.AccessToken) + + gotUser, err := newClient.User(ctx, codersdk.Me) + require.NoError(t, err) + require.Equal(t, testUser.ID, gotUser.ID) + + filter := codersdk.OAuth2ProviderAppFilter{UserID: testUser.ID} + apps, err := testClient.OAuth2ProviderApps(ctx, filter) + require.NoError(t, err) + require.Contains(t, apps, testEntities.app) + + // Should not show up for another user. + apps, err = client.OAuth2ProviderApps(ctx, codersdk.OAuth2ProviderAppFilter{UserID: owner.UserID}) + require.NoError(t, err) + require.Len(t, apps, 0) + + // Perform the deletion. + test.fn(ctx, testClient, testEntities) + + // App should no longer show up for the user unless it was replaced. + if !test.replacesToken { + apps, err = testClient.OAuth2ProviderApps(ctx, filter) + require.NoError(t, err) + require.NotContains(t, apps, testEntities.app, fmt.Sprintf("contains %q", testEntities.app.Name)) + } + + // The token should no longer be valid. + _, err = newClient.User(ctx, codersdk.Me) + require.Error(t, err) + require.ErrorContains(t, err, "401") + }) + } +} + +type provisionedApps struct { + Default codersdk.OAuth2ProviderApp + NoPort codersdk.OAuth2ProviderApp + Subdomain codersdk.OAuth2ProviderApp + // For sorting purposes these are included. You will likely never touch them. + Extra []codersdk.OAuth2ProviderApp +} + +func generateApps(ctx context.Context, t *testing.T, client *codersdk.Client, suffix string) provisionedApps { + create := func(name, callback string) codersdk.OAuth2ProviderApp { + name = fmt.Sprintf("%s-%s", name, suffix) + //nolint:gocritic // OAauth2 app management requires owner permission. + app, err := client.PostOAuth2ProviderApp(ctx, codersdk.PostOAuth2ProviderAppRequest{ + Name: name, + CallbackURL: callback, + Icon: "", + }) + require.NoError(t, err) + require.Equal(t, name, app.Name) + require.Equal(t, callback, app.CallbackURL) + return app + } + + return provisionedApps{ + Default: create("razzle-dazzle-a", "http://localhost1:8080/foo/bar"), + NoPort: create("razzle-dazzle-b", "http://localhost2"), + Subdomain: create("razzle-dazzle-z", "http://30.localhost:3000"), + Extra: []codersdk.OAuth2ProviderApp{ + create("second-to-last", "http://20.localhost:3000"), + create("woo-10", "http://10.localhost:3000"), + }, + } +} + +func authorizationFlow(ctx context.Context, client *codersdk.Client, cfg *oauth2.Config) (string, error) { + state := uuid.NewString() + return oidctest.OAuth2GetCode( + cfg.AuthCodeURL(state), + func(req *http.Request) (*http.Response, error) { + // TODO: Would be better if client had a .Do() method. + // TODO: Is this the best way to handle redirects? + client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + return client.Request(ctx, req.Method, req.URL.String(), nil, func(req *http.Request) { + // Set the referer so the request bypasses the HTML page (normally you + // have to click "allow" first, and the way we detect that is using the + // referer header). + req.Header.Set("Referer", req.URL.String()) + }) + }, + ) +} diff --git a/site/site.go b/site/site.go index 4da69e6b3a..7875fa3140 100644 --- a/site/site.go +++ b/site/site.go @@ -51,6 +51,11 @@ var ( errorHTML string errorTemplate *htmltemplate.Template + + //go:embed static/oauth2allow.html + oauthHTML string + + oauthTemplate *htmltemplate.Template ) func init() { @@ -59,6 +64,11 @@ func init() { if err != nil { panic(err) } + + oauthTemplate, err = htmltemplate.New("error").Parse(oauthHTML) + if err != nil { + panic(err) + } } type Options struct { @@ -914,3 +924,31 @@ func (jfs justFilesSystem) Open(name string) (fs.File, error) { return f, nil } + +// RenderOAuthAllowData contains the variables that are found in +// site/static/oauth2allow.html. +type RenderOAuthAllowData struct { + AppIcon string + AppName string + CancelURI string + RedirectURI string + Username string +} + +// RenderOAuthAllowPage renders the static page for a user to "Allow" an create +// a new oauth2 link with an external site. This is when Coder is acting as the +// identity provider. +// +// This has to be done statically because Golang has to handle the full request. +// It cannot defer to the FE typescript easily. +func RenderOAuthAllowPage(rw http.ResponseWriter, r *http.Request, data RenderOAuthAllowData) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + + err := oauthTemplate.Execute(rw, data) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ + Message: "Failed to render oauth page: " + err.Error(), + }) + return + } +} diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 9c72b8c1c9..1d746273a8 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -707,6 +707,11 @@ export interface OAuth2ProviderApp { readonly endpoints: OAuth2AppEndpoints; } +// From codersdk/oauth2.go +export interface OAuth2ProviderAppFilter { + readonly user_id?: string; +} + // From codersdk/oauth2.go export interface OAuth2ProviderAppSecret { readonly id: string; @@ -1988,6 +1993,19 @@ export const LoginTypes: LoginType[] = [ "token", ]; +// From codersdk/oauth2.go +export type OAuth2ProviderGrantType = "authorization_code" | "refresh_token"; +export const OAuth2ProviderGrantTypes: OAuth2ProviderGrantType[] = [ + "authorization_code", + "refresh_token", +]; + +// From codersdk/oauth2.go +export type OAuth2ProviderResponseType = "code"; +export const OAuth2ProviderResponseTypes: OAuth2ProviderResponseType[] = [ + "code", +]; + // From codersdk/provisionerdaemons.go export type ProvisionerJobStatus = | "canceled" diff --git a/site/static/oauth2allow.html b/site/static/oauth2allow.html new file mode 100644 index 0000000000..a7a7aaffc3 --- /dev/null +++ b/site/static/oauth2allow.html @@ -0,0 +1,168 @@ +{{/* This template is used by application handlers to render allowing oauth2 +links */}} + + + + + + + Application {{.AppName}} + + + +
+
+ {{- if .AppIcon }} + +
+
+ {{end}} + + + + + + + + + + + + + + + +
+

Authorize {{ .AppName }}

+

+ Allow {{ .AppName }} to have full access to your + {{ .Username }} account? +

+
+ Allow + Cancel +
+
+ +