fix: allow for alternate usernames on conflict (#4614)

This commit is contained in:
Jon Ayers 2022-10-17 22:07:11 -05:00 committed by GitHub
parent 3c40698033
commit 61683f1961
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 192 additions and 49 deletions

View File

@ -29,6 +29,7 @@ import (
"time"
"cloud.google.com/go/compute/metadata"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/fullsailor/pkcs7"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
@ -36,6 +37,7 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
@ -725,6 +727,80 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
}
}
type OIDCConfig struct {
key *rsa.PrivateKey
issuer string
}
func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
t.Helper()
block, _ := pem.Decode([]byte(testRSAPrivateKey))
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)
if issuer == "" {
issuer = "https://coder.com"
}
return &OIDCConfig{
key: pkey,
issuer: issuer,
}
}
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "/?state=" + url.QueryEscape(state)
}
func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return nil
}
func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
token, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return nil, xerrors.Errorf("decode code: %w", err)
}
return (&oauth2.Token{
AccessToken: "token",
}).WithExtra(map[string]interface{}{
"id_token": string(token),
}), nil
}
func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
t.Helper()
if _, ok := claims["exp"]; !ok {
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
}
if _, ok := claims["iss"]; !ok {
claims["iss"] = o.issuer
}
if _, ok := claims["sub"]; !ok {
claims["sub"] = "testme"
}
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString([]byte(signed))
}
func (o *OIDCConfig) OIDCConfig() *coderd.OIDCConfig {
return &coderd.OIDCConfig{
OAuth2Config: o,
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{o.key.Public()},
}, &oidc.Config{
SkipClientIDCheck: true,
}),
}
}
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
// instance authentication for Azure.
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
@ -805,3 +881,19 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
require.True(t, errors.As(err, &cerr))
return cerr
}
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
-----END RSA PRIVATE KEY-----`

View File

@ -2221,6 +2221,12 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
q.mutex.Lock()
defer q.mutex.Unlock()
for _, user := range q.users {
if user.Username == arg.Username && !user.Deleted {
return database.User{}, errDuplicateKey
}
}
user := database.User{
ID: arg.ID,
Email: arg.Email,

View File

@ -13,6 +13,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
@ -390,6 +391,38 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
organizationID = organizations[0].ID
}
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
Username: params.Username,
})
if err == nil {
var (
original = params.Username
validUsername bool
)
for i := 0; i < 10; i++ {
alternate := fmt.Sprintf("%s-%s", original, namesgenerator.GetRandomName(1))
params.Username = httpapi.UsernameFrom(alternate)
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
Username: params.Username,
})
if xerrors.Is(err, sql.ErrNoRows) {
validUsername = true
break
}
if err != nil {
return xerrors.Errorf("get user by email/username: %w", err)
}
}
if !validUsername {
return httpError{
code: http.StatusConflict,
msg: fmt.Sprintf("exhausted alternatives for taken username %q", original),
}
}
}
user, _, err = api.CreateUser(ctx, tx, CreateUserRequest{
CreateUserRequest: codersdk.CreateUserRequest{
Email: params.Email,

View File

@ -3,13 +3,12 @@ package coderd_test
import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
@ -450,17 +449,19 @@ func TestUserOIDC(t *testing.T) {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
config := createOIDCConfig(t, tc.Claims)
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig()
config.AllowSignups = tc.AllowSignups
config.EmailDomain = tc.EmailDomain
client := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: config,
})
resp := oidcCallback(t, client)
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.Claims))
assert.Equal(t, tc.StatusCode, resp.StatusCode)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
ctx, _ := testutil.Context(t)
if tc.Username != "" {
client.SessionToken = authCookieValue(resp.Cookies())
@ -478,10 +479,50 @@ func TestUserOIDC(t *testing.T) {
})
}
t.Run("AlternateUsername", func(t *testing.T) {
t.Parallel()
conf := coderdtest.NewOIDCConfig(t, "")
config := conf.OIDCConfig()
config.AllowSignups = true
client := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: config,
})
code := conf.EncodeClaims(t, jwt.MapClaims{
"email": "jon@coder.com",
})
resp := oidcCallback(t, client, code)
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
ctx, _ := testutil.Context(t)
client.SessionToken = authCookieValue(resp.Cookies())
user, err := client.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, "jon", user.Username)
// Pass a different subject field so that we prompt creating a
// new user.
code = conf.EncodeClaims(t, jwt.MapClaims{
"email": "jon@example2.com",
"sub": "diff",
})
resp = oidcCallback(t, client, code)
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
client.SessionToken = authCookieValue(resp.Cookies())
user, err = client.User(ctx, "me")
require.NoError(t, err)
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")
})
t.Run("Disabled", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
resp := oidcCallback(t, client)
resp := oidcCallback(t, client, "asdf")
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode)
})
@ -492,7 +533,7 @@ func TestUserOIDC(t *testing.T) {
OAuth2Config: &oauth2Config{},
},
})
resp := oidcCallback(t, client)
resp := oidcCallback(t, client, "asdf")
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
@ -514,48 +555,16 @@ func TestUserOIDC(t *testing.T) {
Verifier: verifier,
},
})
resp := oidcCallback(t, client)
resp := oidcCallback(t, client, "asdf")
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
}
// createOIDCConfig generates a new OIDCConfig that returns a static token
// with the claims provided.
func createOIDCConfig(t *testing.T, claims jwt.MapClaims) *coderd.OIDCConfig {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
claims["iss"] = "https://coder.com"
claims["sub"] = "hello"
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key)
require.NoError(t, err)
verifier := oidc.NewVerifier("https://coder.com", &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{key.Public()},
}, &oidc.Config{
SkipClientIDCheck: true,
})
return &coderd.OIDCConfig{
OAuth2Config: &oauth2Config{
token: (&oauth2.Token{
AccessToken: "token",
}).WithExtra(map[string]interface{}{
"id_token": signed,
}),
},
Verifier: verifier,
}
}
func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
state := "somestate"
oauthURL, err := client.URL.Parse("/api/v2/users/oauth2/github/callback?code=asd&state=" + state)
require.NoError(t, err)
@ -573,19 +582,18 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
return res
}
func oidcCallback(t *testing.T, client *codersdk.Client) *http.Response {
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
t.Helper()
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
state := "somestate"
oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback?code=asd&state=" + state)
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code))
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: codersdk.OAuth2StateKey,
Value: state,
Value: "somestate",
})
res, err := client.HTTPClient.Do(req)
require.NoError(t, err)

View File

@ -140,8 +140,12 @@ func TestEntitlements(t *testing.T) {
t.Run("TooManyUsers", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertUser(context.Background(), database.InsertUserParams{})
db.InsertUser(context.Background(), database.InsertUserParams{})
db.InsertUser(context.Background(), database.InsertUserParams{
Username: "test1",
})
db.InsertUser(context.Background(), database.InsertUserParams{
Username: "test2",
})
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
UserLimit: 1,