feat: allow storing extra oauth token properties in the database (#10152)

This commit is contained in:
Kyle Carberry 2023-10-09 18:49:30 -05:00 committed by GitHub
parent 35538e1051
commit 863c2e7b64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 223 additions and 60 deletions

View File

@ -2251,6 +2251,8 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder
provider.NoRefresh = b
case "SCOPES":
provider.Scopes = strings.Split(v.Value, " ")
case "EXTRA_TOKEN_KEYS":
provider.ExtraTokenKeys = strings.Split(v.Value, " ")
case "APP_INSTALL_URL":
provider.AppInstallURL = v.Value
case "APP_INSTALLATIONS_URL":

6
coderd/apidoc/docs.go generated
View File

@ -8381,6 +8381,12 @@ const docTemplate = `{
"description": "DisplayName is shown in the UI to identify the auth config.",
"type": "string"
},
"extra_token_keys": {
"type": "array",
"items": {
"type": "string"
}
},
"id": {
"description": "ID is a unique identifier for the auth config.\nIt defaults to ` + "`" + `type` + "`" + ` when not provided.",
"type": "string"

View File

@ -7515,6 +7515,12 @@
"description": "DisplayName is shown in the UI to identify the auth config.",
"type": "string"
},
"extra_token_keys": {
"type": "array",
"items": {
"type": "string"
}
},
"id": {
"description": "ID is a unique identifier for the auth config.\nIt defaults to `type` when not provided.",
"type": "string"

View File

@ -68,6 +68,7 @@ type FakeIDP struct {
// "Authorized Redirect URLs". This can be used to emulate that.
hookValidRedirectURL func(redirectURL string) error
hookUserInfo func(email string) (jwt.MapClaims, error)
hookMutateToken func(token map[string]interface{})
fakeCoderd func(req *http.Request) (*http.Response, error)
hookOnRefresh func(email string) error
// Custom authentication for the client. This is useful if you want
@ -112,6 +113,14 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) {
}
}
// WithExtra returns extra fields that be accessed on the returned Oauth Token.
// These extra fields can override the default fields (id_token, access_token, etc).
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookMutateToken = mutateToken
}
}
func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookAuthenticateClient = hook
@ -621,6 +630,9 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
"expires_in": int64((time.Minute * 5).Seconds()),
"id_token": f.encodeClaims(t, claims),
}
if f.hookMutateToken != nil {
f.hookMutateToken(token)
}
// Store the claims for the next refresh
f.refreshIDTokenClaims.Store(refreshToken, claims)

View File

@ -4246,6 +4246,7 @@ func (q *FakeQuerier) InsertExternalAuthLink(_ context.Context, arg database.Ins
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID,
OAuthExpiry: arg.OAuthExpiry,
OAuthExtra: arg.OAuthExtra,
}
q.externalAuthLinks = append(q.externalAuthLinks, gitAuthLink)
return gitAuthLink, nil
@ -5301,6 +5302,7 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID
gitAuthLink.OAuthExpiry = arg.OAuthExpiry
gitAuthLink.OAuthExtra = arg.OAuthExtra
q.externalAuthLinks[index] = gitAuthLink
return gitAuthLink, nil

View File

@ -514,6 +514,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
}
func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAuthLink) database.ExternalAuthLink {
msg := takeFirst(&orig.OAuthExtra, &pqtype.NullRawMessage{})
link, err := db.InsertExternalAuthLink(genCtx, database.InsertExternalAuthLinkParams{
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
UserID: takeFirst(orig.UserID, uuid.New()),
@ -524,6 +525,7 @@ func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAut
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
OAuthExtra: *msg,
})
require.NoError(t, err, "insert external auth link")

View File

@ -359,7 +359,8 @@ CREATE TABLE external_auth_links (
oauth_refresh_token text NOT NULL,
oauth_expiry timestamp with time zone NOT NULL,
oauth_access_token_key_id text,
oauth_refresh_token_key_id text
oauth_refresh_token_key_id text,
oauth_extra jsonb
);
COMMENT ON COLUMN external_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';

View File

@ -0,0 +1 @@
ALTER TABLE external_auth_links DROP COLUMN "oauth_extra";

View File

@ -0,0 +1 @@
ALTER TABLE external_auth_links ADD COLUMN "oauth_extra" jsonb;

View File

@ -1680,7 +1680,8 @@ type ExternalAuthLink struct {
// The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
// The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
}
type File struct {

View File

@ -751,7 +751,7 @@ func (q *sqlQuerier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest strin
}
const getExternalAuthLink = `-- name: GetExternalAuthLink :one
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM external_auth_links WHERE provider_id = $1 AND user_id = $2
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra FROM external_auth_links WHERE provider_id = $1 AND user_id = $2
`
type GetExternalAuthLinkParams struct {
@ -772,12 +772,13 @@ func (q *sqlQuerier) GetExternalAuthLink(ctx context.Context, arg GetExternalAut
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
&i.OAuthExtra,
)
return i, err
}
const getExternalAuthLinksByUserID = `-- name: GetExternalAuthLinksByUserID :many
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM external_auth_links WHERE user_id = $1
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra FROM external_auth_links WHERE user_id = $1
`
func (q *sqlQuerier) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error) {
@ -799,6 +800,7 @@ func (q *sqlQuerier) GetExternalAuthLinksByUserID(ctx context.Context, userID uu
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
&i.OAuthExtra,
); err != nil {
return nil, err
}
@ -823,7 +825,8 @@ INSERT INTO external_auth_links (
oauth_access_token_key_id,
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry
oauth_expiry,
oauth_extra
) VALUES (
$1,
$2,
@ -833,20 +836,22 @@ INSERT INTO external_auth_links (
$6,
$7,
$8,
$9
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
$9,
$10
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra
`
type InsertExternalAuthLinkParams struct {
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
}
func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExternalAuthLinkParams) (ExternalAuthLink, error) {
@ -860,6 +865,7 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter
arg.OAuthRefreshToken,
arg.OAuthRefreshTokenKeyID,
arg.OAuthExpiry,
arg.OAuthExtra,
)
var i ExternalAuthLink
err := row.Scan(
@ -872,6 +878,7 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
&i.OAuthExtra,
)
return i, err
}
@ -883,19 +890,21 @@ UPDATE external_auth_links SET
oauth_access_token_key_id = $5,
oauth_refresh_token = $6,
oauth_refresh_token_key_id = $7,
oauth_expiry = $8
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
oauth_expiry = $8,
oauth_extra = $9
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra
`
type UpdateExternalAuthLinkParams struct {
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
}
func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) {
@ -908,6 +917,7 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter
arg.OAuthRefreshToken,
arg.OAuthRefreshTokenKeyID,
arg.OAuthExpiry,
arg.OAuthExtra,
)
var i ExternalAuthLink
err := row.Scan(
@ -920,6 +930,7 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
&i.OAuthExtra,
)
return i, err
}

View File

@ -14,7 +14,8 @@ INSERT INTO external_auth_links (
oauth_access_token_key_id,
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry
oauth_expiry,
oauth_extra
) VALUES (
$1,
$2,
@ -24,7 +25,8 @@ INSERT INTO external_auth_links (
$6,
$7,
$8,
$9
$9,
$10
) RETURNING *;
-- name: UpdateExternalAuthLink :one
@ -34,5 +36,6 @@ UPDATE external_auth_links SET
oauth_access_token_key_id = $5,
oauth_refresh_token = $6,
oauth_refresh_token_key_id = $7,
oauth_expiry = $8
oauth_expiry = $8,
oauth_extra = $9
WHERE provider_id = $1 AND user_id = $2 RETURNING *;

View File

@ -53,6 +53,7 @@ overrides:
oauth_id_token: OAuthIDToken
oauth_refresh_token: OAuthRefreshToken
oauth_refresh_token_key_id: OAuthRefreshTokenKeyID
oauth_extra: OAuthExtra
parameter_type_system_hcl: ParameterTypeSystemHCL
userstatus: UserStatus
gitsshkey: GitSSHKey

View File

@ -14,6 +14,7 @@ import (
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/sqlc-dev/pqtype"
)
// @Summary Get external auth by ID
@ -132,6 +133,8 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required
OAuthExpiry: token.Expiry,
// No extra data from device auth!
OAuthExtra: pqtype.NullRawMessage{},
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -150,6 +153,7 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: token.Expiry,
OAuthExtra: pqtype.NullRawMessage{},
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -201,7 +205,15 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
apiKey = httpmw.APIKey(r)
)
_, err := api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
extra, err := externalAuthConfig.GenerateTokenExtra(state.Token)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to generate token extra.",
Detail: err.Error(),
})
return
}
_, err = api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
ProviderID: externalAuthConfig.ID,
UserID: apiKey.UserID,
})
@ -224,6 +236,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
OAuthRefreshToken: state.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required
OAuthExpiry: state.Token.Expiry,
OAuthExtra: extra,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -242,6 +255,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
OAuthRefreshToken: state.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: state.Token.Expiry,
OAuthExtra: extra,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{

View File

@ -15,6 +15,7 @@ import (
"golang.org/x/xerrors"
"github.com/google/go-github/v43/github"
"github.com/sqlc-dev/pqtype"
xgithub "golang.org/x/oauth2/github"
"github.com/coder/coder/v2/coderd/database"
@ -44,6 +45,14 @@ type Config struct {
// DisplayIcon is the path to an image that will be displayed to the user.
DisplayIcon string
// ExtraTokenKeys is a list of extra properties to
// store in the database returned from the token endpoint.
//
// e.g. Slack returns `authed_user` in the token which is
// a payload that contains information about the authenticated
// user.
ExtraTokenKeys []string
// NoRefresh stops Coder from using the refresh token
// to renew the access token.
//
@ -69,6 +78,25 @@ type Config struct {
AppInstallationsURL string
}
// GenerateTokenExtra generates the extra token data to store in the database.
func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) {
if len(c.ExtraTokenKeys) == 0 {
return pqtype.NullRawMessage{}, nil
}
extraMap := map[string]interface{}{}
for _, key := range c.ExtraTokenKeys {
extraMap[key] = token.Extra(key)
}
data, err := json.Marshal(extraMap)
if err != nil {
return pqtype.NullRawMessage{}, err
}
return pqtype.NullRawMessage{
RawMessage: data,
Valid: true,
}, nil
}
// RefreshToken automatically refreshes the token if expired and permitted.
// It returns the token and a bool indicating if the token is valid.
func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAuthLink database.ExternalAuthLink) (database.ExternalAuthLink, bool, error) {
@ -101,6 +129,12 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
// we aren't trying to surface an error, we're just trying to obtain a valid token.
return externalAuthLink, false, nil
}
extra, err := c.GenerateTokenExtra(token)
if err != nil {
return externalAuthLink, false, xerrors.Errorf("generate token extra: %w", err)
}
r := retry.New(50*time.Millisecond, 200*time.Millisecond)
// See the comment below why the retry and cancel is required.
retryCtx, retryCtxCancel := context.WithTimeout(ctx, time.Second)
@ -135,6 +169,7 @@ validate:
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: token.Expiry,
OAuthExtra: extra,
})
if err != nil {
return updatedAuthLink, false, xerrors.Errorf("update external auth link: %w", err)
@ -539,6 +574,14 @@ var defaults = map[codersdk.EnhancedExternalAuthProvider]codersdk.ExternalAuthCo
DeviceCodeURL: "https://github.com/login/device/code",
AppInstallationsURL: "https://api.github.com/user/installations",
},
codersdk.EnhancedExternalAuthProviderSlack: {
AuthURL: "https://slack.com/oauth/v2/authorize",
TokenURL: "https://slack.com/api/oauth.v2.access",
DisplayName: "Slack",
DisplayIcon: "/icon/slack.svg",
// See: https://api.slack.com/authentication/oauth-v2#exchanging
ExtraTokenKeys: []string{"authed_user"},
},
}
// jwtConfig is a new OAuth2 config that uses a custom

View File

@ -2,6 +2,7 @@ package externalauth_test
import (
"context"
"encoding/json"
"net/http"
"net/url"
"testing"
@ -43,7 +44,7 @@ func TestRefreshToken(t *testing.T) {
return nil, xerrors.New("should not be called")
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.NoRefresh = true
},
})
@ -74,7 +75,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, nil
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.NoRefresh = true
},
})
@ -117,7 +118,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, xerrors.New(staticError)
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
},
})
@ -142,7 +143,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
},
})
@ -175,7 +176,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
})
@ -205,7 +206,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, nil
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
})
@ -236,7 +237,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, nil
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
DB: db,
@ -260,6 +261,41 @@ func TestRefreshToken(t *testing.T) {
require.NoError(t, err)
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB")
})
t.Run("WithExtra", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
fake, config, link := setupOauth2Test(t, testConfig{
FakeIDPOpts: []oidctest.FakeIDPOpt{
oidctest.WithMutateToken(func(token map[string]interface{}) {
token["authed_user"] = map[string]interface{}{
"access_token": token["access_token"],
}
}),
},
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderSlack.String()
cfg.ExtraTokenKeys = []string{"authed_user"}
cfg.ValidateURL = ""
},
DB: db,
})
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
// Force a refresh
link.OAuthExpiry = expired
updated, ok, err := config.RefreshToken(ctx, db, link)
require.NoError(t, err)
require.True(t, ok)
require.True(t, updated.OAuthExtra.Valid)
extra := map[string]interface{}{}
require.NoError(t, json.Unmarshal(updated.OAuthExtra.RawMessage, &extra))
mapping, ok := extra["authed_user"].(map[string]interface{})
require.True(t, ok)
require.Equal(t, updated.OAuthAccessToken, mapping["access_token"])
})
}
func TestConvertYAML(t *testing.T) {
@ -344,7 +380,7 @@ func TestConvertYAML(t *testing.T) {
type testConfig struct {
FakeIDPOpts []oidctest.FakeIDPOpt
CoderOIDCConfigOpts []func(cfg *coderd.OIDCConfig)
GitConfigOpt func(cfg *externalauth.Config)
ExternalAuthOpt func(cfg *externalauth.Config)
// If DB is passed in, the link will be inserted into the DB.
DB database.Store
}
@ -367,7 +403,7 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext
ID: providerID,
ValidateURL: fake.WellknownConfig().UserInfoURL,
}
settings.GitConfigOpt(config)
settings.ExternalAuthOpt(config)
oauthToken, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
"email": "test@coder.com",

View File

@ -336,6 +336,7 @@ type ExternalAuthConfig struct {
AppInstallationsURL string `json:"app_installations_url"`
NoRefresh bool `json:"no_refresh"`
Scopes []string `json:"scopes"`
ExtraTokenKeys []string `json:"extra_token_keys"`
DeviceFlow bool `json:"device_flow"`
DeviceCodeURL string `json:"device_code_url"`
// Regex allows API requesters to match an auth config by

View File

@ -34,6 +34,7 @@ const (
EnhancedExternalAuthProviderGitHub EnhancedExternalAuthProvider = "github"
EnhancedExternalAuthProviderGitLab EnhancedExternalAuthProvider = "gitlab"
EnhancedExternalAuthProviderBitBucket EnhancedExternalAuthProvider = "bitbucket"
EnhancedExternalAuthProviderSlack EnhancedExternalAuthProvider = "slack"
)
type ExternalAuth struct {

1
docs/api/general.md generated
View File

@ -223,6 +223,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
"device_flow": true,
"display_icon": "string",
"display_name": "string",
"extra_token_keys": ["string"],
"id": "string",
"no_refresh": true,
"regex": "string",

5
docs/api/schemas.md generated
View File

@ -638,6 +638,7 @@ _None_
"device_flow": true,
"display_icon": "string",
"display_name": "string",
"extra_token_keys": ["string"],
"id": "string",
"no_refresh": true,
"regex": "string",
@ -2077,6 +2078,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
"device_flow": true,
"display_icon": "string",
"display_name": "string",
"extra_token_keys": ["string"],
"id": "string",
"no_refresh": true,
"regex": "string",
@ -2444,6 +2446,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
"device_flow": true,
"display_icon": "string",
"display_name": "string",
"extra_token_keys": ["string"],
"id": "string",
"no_refresh": true,
"regex": "string",
@ -2856,6 +2859,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
"device_flow": true,
"display_icon": "string",
"display_name": "string",
"extra_token_keys": ["string"],
"id": "string",
"no_refresh": true,
"regex": "string",
@ -2878,6 +2882,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
| `device_flow` | boolean | false | | |
| `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. |
| `display_name` | string | false | | Display name is shown in the UI to identify the auth config. |
| `extra_token_keys` | array of string | false | | |
| `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. |
| `no_refresh` | boolean | false | | |
| `regex` | string | false | | Regex allows API requesters to match an auth config by a string (e.g. coder.com) instead of by it's type. |

View File

@ -48,26 +48,27 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
}
}
gitAuthLinks, err := cryptTx.GetExternalAuthLinksByUserID(ctx, uid)
externalAuthLinks, err := cryptTx.GetExternalAuthLinksByUserID(ctx, uid)
if err != nil {
return xerrors.Errorf("get git auth links for user: %w", err)
}
for _, gitAuthLink := range gitAuthLinks {
if gitAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && gitAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
for _, externalAuthLink := range externalAuthLinks {
if externalAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && externalAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping external auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := cryptTx.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
ProviderID: gitAuthLink.ProviderID,
ProviderID: externalAuthLink.ProviderID,
UserID: uid,
UpdatedAt: gitAuthLink.UpdatedAt,
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
UpdatedAt: externalAuthLink.UpdatedAt,
OAuthAccessToken: externalAuthLink.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
OAuthRefreshToken: externalAuthLink.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: gitAuthLink.OAuthExpiry,
OAuthExpiry: externalAuthLink.OAuthExpiry,
OAuthExtra: externalAuthLink.OAuthExtra,
}); err != nil {
return xerrors.Errorf("update git auth link user_id=%s provider_id=%s: %w", gitAuthLink.UserID, gitAuthLink.ProviderID, err)
return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err)
}
}
return nil
@ -136,26 +137,27 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
}
}
gitAuthLinks, err := tx.GetExternalAuthLinksByUserID(ctx, uid)
externalAuthLinks, err := tx.GetExternalAuthLinksByUserID(ctx, uid)
if err != nil {
return xerrors.Errorf("get git auth links for user: %w", err)
}
for _, gitAuthLink := range gitAuthLinks {
if !gitAuthLink.OAuthAccessTokenKeyID.Valid && !gitAuthLink.OAuthRefreshTokenKeyID.Valid {
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1))
for _, externalAuthLink := range externalAuthLinks {
if !externalAuthLink.OAuthAccessTokenKeyID.Valid && !externalAuthLink.OAuthRefreshTokenKeyID.Valid {
log.Debug(ctx, "skipping external auth link", slog.F("user_id", uid), slog.F("current", idx+1))
continue
}
if _, err := tx.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
ProviderID: gitAuthLink.ProviderID,
ProviderID: externalAuthLink.ProviderID,
UserID: uid,
UpdatedAt: gitAuthLink.UpdatedAt,
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
UpdatedAt: externalAuthLink.UpdatedAt,
OAuthAccessToken: externalAuthLink.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // we explicitly want to clear the key id
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
OAuthRefreshToken: externalAuthLink.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // we explicitly want to clear the key id
OAuthExpiry: gitAuthLink.OAuthExpiry,
OAuthExpiry: externalAuthLink.OAuthExpiry,
OAuthExtra: externalAuthLink.OAuthExtra,
}); err != nil {
return xerrors.Errorf("update git auth link user_id=%s provider_id=%s: %w", gitAuthLink.UserID, gitAuthLink.ProviderID, err)
return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err)
}
}
return nil

View File

@ -453,6 +453,7 @@ export interface ExternalAuthConfig {
readonly app_installations_url: string;
readonly no_refresh: boolean;
readonly scopes: string[];
readonly extra_token_keys: string[];
readonly device_flow: boolean;
readonly device_code_url: string;
readonly regex: string;
@ -1650,12 +1651,14 @@ export type EnhancedExternalAuthProvider =
| "azure-devops"
| "bitbucket"
| "github"
| "gitlab";
| "gitlab"
| "slack";
export const EnhancedExternalAuthProviders: EnhancedExternalAuthProvider[] = [
"azure-devops",
"bitbucket",
"github",
"gitlab",
"slack",
];
// From codersdk/deployment.go

View File

@ -19,6 +19,7 @@ const meta: Meta<typeof ExternalAuthSettingsPageView> = {
app_installations_url: "",
no_refresh: false,
scopes: [],
extra_token_keys: [],
device_flow: true,
device_code_url: "",
display_icon: "",

View File

@ -55,6 +55,7 @@
"ruby.png",
"rubymine.svg",
"rust.svg",
"slack.svg",
"swift.svg",
"tensorflow.svg",
"terminal.svg",

View File

@ -0,0 +1,6 @@
<svg width="127" height="127" xmlns="http://www.w3.org/2000/svg">
<path d="M27.2 80c0 7.3-5.9 13.2-13.2 13.2C6.7 93.2.8 87.3.8 80c0-7.3 5.9-13.2 13.2-13.2h13.2V80zm6.6 0c0-7.3 5.9-13.2 13.2-13.2 7.3 0 13.2 5.9 13.2 13.2v33c0 7.3-5.9 13.2-13.2 13.2-7.3 0-13.2-5.9-13.2-13.2V80z" fill="#E01E5A"/>
<path d="M47 27c-7.3 0-13.2-5.9-13.2-13.2C33.8 6.5 39.7.6 47 .6c7.3 0 13.2 5.9 13.2 13.2V27H47zm0 6.7c7.3 0 13.2 5.9 13.2 13.2 0 7.3-5.9 13.2-13.2 13.2H13.9C6.6 60.1.7 54.2.7 46.9c0-7.3 5.9-13.2 13.2-13.2H47z" fill="#36C5F0"/>
<path d="M99.9 46.9c0-7.3 5.9-13.2 13.2-13.2 7.3 0 13.2 5.9 13.2 13.2 0 7.3-5.9 13.2-13.2 13.2H99.9V46.9zm-6.6 0c0 7.3-5.9 13.2-13.2 13.2-7.3 0-13.2-5.9-13.2-13.2V13.8C66.9 6.5 72.8.6 80.1.6c7.3 0 13.2 5.9 13.2 13.2v33.1z" fill="#2EB67D"/>
<path d="M80.1 99.8c7.3 0 13.2 5.9 13.2 13.2 0 7.3-5.9 13.2-13.2 13.2-7.3 0-13.2-5.9-13.2-13.2V99.8h13.2zm0-6.6c-7.3 0-13.2-5.9-13.2-13.2 0-7.3 5.9-13.2 13.2-13.2h33.1c7.3 0 13.2 5.9 13.2 13.2 0 7.3-5.9 13.2-13.2 13.2H80.1z" fill="#ECB22E"/>
</svg>

After

Width:  |  Height:  |  Size: 1019 B