mirror of https://github.com/coder/coder.git
fix: allow for alternate usernames on conflict (#4614)
This commit is contained in:
parent
3c40698033
commit
61683f1961
|
@ -29,6 +29,7 @@ import (
|
|||
"time"
|
||||
|
||||
"cloud.google.com/go/compute/metadata"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/fullsailor/pkcs7"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
|
@ -36,6 +37,7 @@ import (
|
|||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/api/idtoken"
|
||||
"google.golang.org/api/option"
|
||||
|
@ -725,6 +727,80 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
|
|||
}
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
key *rsa.PrivateKey
|
||||
issuer string
|
||||
}
|
||||
|
||||
func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
|
||||
t.Helper()
|
||||
|
||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
if issuer == "" {
|
||||
issuer = "https://coder.com"
|
||||
}
|
||||
|
||||
return &OIDCConfig{
|
||||
key: pkey,
|
||||
issuer: issuer,
|
||||
}
|
||||
}
|
||||
|
||||
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
token, err := base64.StdEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode code: %w", err)
|
||||
}
|
||||
return (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": string(token),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = o.issuer
|
||||
}
|
||||
|
||||
if _, ok := claims["sub"]; !ok {
|
||||
claims["sub"] = "testme"
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return base64.StdEncoding.EncodeToString([]byte(signed))
|
||||
}
|
||||
|
||||
func (o *OIDCConfig) OIDCConfig() *coderd.OIDCConfig {
|
||||
return &coderd.OIDCConfig{
|
||||
OAuth2Config: o,
|
||||
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{o.key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
|
||||
// instance authentication for Azure.
|
||||
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
|
||||
|
@ -805,3 +881,19 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
|
|||
require.True(t, errors.As(err, &cerr))
|
||||
return cerr
|
||||
}
|
||||
|
||||
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
|
||||
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
|
||||
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
|
||||
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
|
||||
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
|
||||
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
|
||||
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
|
||||
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
|
||||
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
|
||||
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
|
||||
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
|
||||
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
|
||||
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
|
|
@ -2221,6 +2221,12 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
|
|||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for _, user := range q.users {
|
||||
if user.Username == arg.Username && !user.Deleted {
|
||||
return database.User{}, errDuplicateKey
|
||||
}
|
||||
}
|
||||
|
||||
user := database.User{
|
||||
ID: arg.ID,
|
||||
Email: arg.Email,
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
|
@ -390,6 +391,38 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
|||
organizationID = organizations[0].ID
|
||||
}
|
||||
|
||||
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
|
||||
Username: params.Username,
|
||||
})
|
||||
if err == nil {
|
||||
var (
|
||||
original = params.Username
|
||||
validUsername bool
|
||||
)
|
||||
for i := 0; i < 10; i++ {
|
||||
alternate := fmt.Sprintf("%s-%s", original, namesgenerator.GetRandomName(1))
|
||||
|
||||
params.Username = httpapi.UsernameFrom(alternate)
|
||||
|
||||
_, err := tx.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
|
||||
Username: params.Username,
|
||||
})
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
validUsername = true
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user by email/username: %w", err)
|
||||
}
|
||||
}
|
||||
if !validUsername {
|
||||
return httpError{
|
||||
code: http.StatusConflict,
|
||||
msg: fmt.Sprintf("exhausted alternatives for taken username %q", original),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
user, _, err = api.CreateUser(ctx, tx, CreateUserRequest{
|
||||
CreateUserRequest: codersdk.CreateUserRequest{
|
||||
Email: params.Email,
|
||||
|
|
|
@ -3,13 +3,12 @@ package coderd_test
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
@ -450,17 +449,19 @@ func TestUserOIDC(t *testing.T) {
|
|||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := createOIDCConfig(t, tc.Claims)
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
config := conf.OIDCConfig()
|
||||
config.AllowSignups = tc.AllowSignups
|
||||
config.EmailDomain = tc.EmailDomain
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
})
|
||||
resp := oidcCallback(t, client)
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, tc.Claims))
|
||||
assert.Equal(t, tc.StatusCode, resp.StatusCode)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
ctx, _ := testutil.Context(t)
|
||||
|
||||
if tc.Username != "" {
|
||||
client.SessionToken = authCookieValue(resp.Cookies())
|
||||
|
@ -478,10 +479,50 @@ func TestUserOIDC(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
t.Run("AlternateUsername", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf := coderdtest.NewOIDCConfig(t, "")
|
||||
|
||||
config := conf.OIDCConfig()
|
||||
config.AllowSignups = true
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
})
|
||||
|
||||
code := conf.EncodeClaims(t, jwt.MapClaims{
|
||||
"email": "jon@coder.com",
|
||||
})
|
||||
resp := oidcCallback(t, client, code)
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
ctx, _ := testutil.Context(t)
|
||||
|
||||
client.SessionToken = authCookieValue(resp.Cookies())
|
||||
user, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "jon", user.Username)
|
||||
|
||||
// Pass a different subject field so that we prompt creating a
|
||||
// new user.
|
||||
code = conf.EncodeClaims(t, jwt.MapClaims{
|
||||
"email": "jon@example2.com",
|
||||
"sub": "diff",
|
||||
})
|
||||
resp = oidcCallback(t, client, code)
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
client.SessionToken = authCookieValue(resp.Cookies())
|
||||
user, err = client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.HasPrefix(user.Username, "jon-"), "username %q should have prefix %q", user.Username, "jon-")
|
||||
})
|
||||
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
resp := oidcCallback(t, client)
|
||||
resp := oidcCallback(t, client, "asdf")
|
||||
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode)
|
||||
})
|
||||
|
||||
|
@ -492,7 +533,7 @@ func TestUserOIDC(t *testing.T) {
|
|||
OAuth2Config: &oauth2Config{},
|
||||
},
|
||||
})
|
||||
resp := oidcCallback(t, client)
|
||||
resp := oidcCallback(t, client, "asdf")
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
|
@ -514,48 +555,16 @@ func TestUserOIDC(t *testing.T) {
|
|||
Verifier: verifier,
|
||||
},
|
||||
})
|
||||
resp := oidcCallback(t, client)
|
||||
resp := oidcCallback(t, client, "asdf")
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
// createOIDCConfig generates a new OIDCConfig that returns a static token
|
||||
// with the claims provided.
|
||||
func createOIDCConfig(t *testing.T, claims jwt.MapClaims) *coderd.OIDCConfig {
|
||||
t.Helper()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
claims["iss"] = "https://coder.com"
|
||||
claims["sub"] = "hello"
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifier := oidc.NewVerifier("https://coder.com", &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
|
||||
return &coderd.OIDCConfig{
|
||||
OAuth2Config: &oauth2Config{
|
||||
token: (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": signed,
|
||||
}),
|
||||
},
|
||||
Verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
state := "somestate"
|
||||
oauthURL, err := client.URL.Parse("/api/v2/users/oauth2/github/callback?code=asd&state=" + state)
|
||||
require.NoError(t, err)
|
||||
|
@ -573,19 +582,18 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
|||
return res
|
||||
}
|
||||
|
||||
func oidcCallback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
|
||||
t.Helper()
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
state := "somestate"
|
||||
oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback?code=asd&state=" + state)
|
||||
oauthURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/users/oidc/callback?code=%s&state=somestate", code))
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: codersdk.OAuth2StateKey,
|
||||
Value: state,
|
||||
Value: "somestate",
|
||||
})
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -140,8 +140,12 @@ func TestEntitlements(t *testing.T) {
|
|||
t.Run("TooManyUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
db.InsertUser(context.Background(), database.InsertUserParams{})
|
||||
db.InsertUser(context.Background(), database.InsertUserParams{})
|
||||
db.InsertUser(context.Background(), database.InsertUserParams{
|
||||
Username: "test1",
|
||||
})
|
||||
db.InsertUser(context.Background(), database.InsertUserParams{
|
||||
Username: "test2",
|
||||
})
|
||||
db.InsertLicense(context.Background(), database.InsertLicenseParams{
|
||||
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
|
||||
UserLimit: 1,
|
||||
|
|
Loading…
Reference in New Issue