mirror of https://github.com/coder/coder.git
feat: Add GitHub OAuth (#1050)
* Initial oauth * Add Github authentication * Add AuthMethods endpoint * Add frontend * Rename basic authentication to password * Add flags for configuring GitHub auth * Remove name from API keys * Fix authmethods in test * Add stories and display auth methods error
This commit is contained in:
parent
3976994781
commit
7496c3da81
|
@ -35,6 +35,7 @@
|
|||
"nolint",
|
||||
"nosec",
|
||||
"ntqry",
|
||||
"OIDC",
|
||||
"oneof",
|
||||
"parameterscopeid",
|
||||
"pqtype",
|
||||
|
@ -46,6 +47,7 @@
|
|||
"ptytest",
|
||||
"retrier",
|
||||
"sdkproto",
|
||||
"Signup",
|
||||
"stretchr",
|
||||
"TCGETS",
|
||||
"tcpip",
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
@ -27,6 +28,14 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
|
|||
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
|
||||
}
|
||||
|
||||
func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
|
||||
val, ok := os.LookupEnv(env)
|
||||
if ok {
|
||||
def = strings.Split(val, ",")
|
||||
}
|
||||
flagset.StringArrayVarP(ptr, name, shorthand, def, usage)
|
||||
}
|
||||
|
||||
// Uint8VarP sets a uint8 flag on the given flag set.
|
||||
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
|
||||
val, ok := os.LookupEnv(env)
|
||||
|
|
|
@ -54,6 +54,26 @@ func TestCliflag(t *testing.T) {
|
|||
require.NotContains(t, flagset.FlagUsages(), " - consumes")
|
||||
})
|
||||
|
||||
t.Run("StringArrayDefault", func(t *testing.T) {
|
||||
var ptr []string
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
def := []string{"hello"}
|
||||
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, def, usage)
|
||||
got, err := flagset.GetStringArray(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, def, got)
|
||||
})
|
||||
|
||||
t.Run("StringArrayEnvVar", func(t *testing.T) {
|
||||
var ptr []string
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
t.Setenv(env, "wow,test")
|
||||
cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage)
|
||||
got, err := flagset.GetStringArray(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"wow", "test"}, got)
|
||||
})
|
||||
|
||||
t.Run("IntDefault", func(t *testing.T) {
|
||||
var ptr uint8
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
|
|
|
@ -18,8 +18,11 @@ import (
|
|||
|
||||
"github.com/briandowns/spinner"
|
||||
"github.com/coreos/go-systemd/daemon"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/pion/turn/v2"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/oauth2"
|
||||
xgithub "golang.org/x/oauth2/github"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/api/idtoken"
|
||||
"google.golang.org/api/option"
|
||||
|
@ -51,19 +54,23 @@ func server() *cobra.Command {
|
|||
dev bool
|
||||
postgresURL string
|
||||
// provisionerDaemonCount is a uint8 to ensure a number > 0.
|
||||
provisionerDaemonCount uint8
|
||||
tlsCertFile string
|
||||
tlsClientCAFile string
|
||||
tlsClientAuth string
|
||||
tlsEnable bool
|
||||
tlsKeyFile string
|
||||
tlsMinVersion string
|
||||
turnRelayAddress string
|
||||
skipTunnel bool
|
||||
traceDatadog bool
|
||||
secureAuthCookie bool
|
||||
sshKeygenAlgorithmRaw string
|
||||
spooky bool
|
||||
provisionerDaemonCount uint8
|
||||
oauth2GithubClientID string
|
||||
oauth2GithubClientSecret string
|
||||
oauth2GithubAllowedOrganizations []string
|
||||
oauth2GithubAllowSignups bool
|
||||
tlsCertFile string
|
||||
tlsClientCAFile string
|
||||
tlsClientAuth string
|
||||
tlsEnable bool
|
||||
tlsKeyFile string
|
||||
tlsMinVersion string
|
||||
turnRelayAddress string
|
||||
skipTunnel bool
|
||||
traceDatadog bool
|
||||
secureAuthCookie bool
|
||||
sshKeygenAlgorithmRaw string
|
||||
spooky bool
|
||||
)
|
||||
|
||||
root := &cobra.Command{
|
||||
|
@ -180,6 +187,13 @@ func server() *cobra.Command {
|
|||
TURNServer: turnServer,
|
||||
}
|
||||
|
||||
if oauth2GithubClientSecret != "" {
|
||||
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("configure github oauth2: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "access-url: %s\n", accessURL)
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "provisioner-daemons: %d\n", provisionerDaemonCount)
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr())
|
||||
|
@ -373,6 +387,14 @@ func server() *cobra.Command {
|
|||
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
|
||||
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
|
||||
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
|
||||
cliflag.StringVarP(root.Flags(), &oauth2GithubClientID, "oauth2-github-client-id", "", "CODER_OAUTH2_GITHUB_CLIENT_ID", "",
|
||||
"Specifies a client ID to use for oauth2 with GitHub.")
|
||||
cliflag.StringVarP(root.Flags(), &oauth2GithubClientSecret, "oauth2-github-client-secret", "", "CODER_OAUTH2_GITHUB_CLIENT_SECRET", "",
|
||||
"Specifies a client secret to use for oauth2 with GitHub.")
|
||||
cliflag.StringArrayVarP(root.Flags(), &oauth2GithubAllowedOrganizations, "oauth2-github-allowed-orgs", "", "CODER_OAUTH2_GITHUB_ALLOWED_ORGS", nil,
|
||||
"Specifies organizations the user must be a member of to authenticate with GitHub.")
|
||||
cliflag.BoolVarP(root.Flags(), &oauth2GithubAllowSignups, "oauth2-github-allow-signups", "", "CODER_OAUTH2_GITHUB_ALLOW_SIGNUPS", false,
|
||||
"Specifies whether new users can sign up with GitHub.")
|
||||
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
|
||||
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
|
||||
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
|
||||
|
@ -572,6 +594,42 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi
|
|||
return tls.NewListener(listener, tlsConfig), nil
|
||||
}
|
||||
|
||||
func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string) (*coderd.GithubOAuth2Config, error) {
|
||||
redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse github oauth callback url: %w", err)
|
||||
}
|
||||
return &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: xgithub.Endpoint,
|
||||
RedirectURL: redirectURL.String(),
|
||||
Scopes: []string{
|
||||
"read:user",
|
||||
"read:org",
|
||||
"user:email",
|
||||
},
|
||||
},
|
||||
AllowSignups: allowSignups,
|
||||
AllowOrganizations: allowOrgs,
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
user, _, err := github.NewClient(client).Users.Get(ctx, "")
|
||||
return user, err
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
emails, _, err := github.NewClient(client).Users.ListEmails(ctx, &github.ListOptions{})
|
||||
return emails, err
|
||||
},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
memberships, _, err := github.NewClient(client).Organizations.ListOrgMemberships(ctx, &github.ListOrgMembershipsOptions{
|
||||
State: "active",
|
||||
})
|
||||
return memberships, err
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type datadogLogger struct {
|
||||
logger slog.Logger
|
||||
}
|
||||
|
|
|
@ -42,6 +42,7 @@ type Options struct {
|
|||
AWSCertificates awsidentity.Certificates
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GoogleTokenValidator *idtoken.Validator
|
||||
GithubOAuth2Config *GithubOAuth2Config
|
||||
ICEServers []webrtc.ICEServer
|
||||
SecureAuthCookie bool
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
|
@ -62,6 +63,9 @@ func New(options *Options) (http.Handler, func()) {
|
|||
api := &api{
|
||||
Options: options,
|
||||
}
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
|
||||
Github: options.GithubOAuth2Config,
|
||||
})
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Route("/api/v2", func(r chi.Router) {
|
||||
|
@ -86,7 +90,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
// This number is arbitrary, but reading/writing
|
||||
// file content is expensive so it should be small.
|
||||
httpmw.RateLimitPerMinute(12),
|
||||
|
@ -96,7 +100,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
r.Route("/organizations/{organization}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractOrganizationParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.organization)
|
||||
|
@ -109,7 +113,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
})
|
||||
r.Route("/parameters/{scope}/{id}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Post("/", api.postParameter)
|
||||
r.Get("/", api.parameters)
|
||||
r.Route("/{name}", func(r chi.Router) {
|
||||
|
@ -118,7 +122,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
r.Route("/templates/{template}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractTemplateParam(options.Database),
|
||||
httpmw.ExtractOrganizationParam(options.Database),
|
||||
)
|
||||
|
@ -132,7 +136,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
r.Route("/templateversions/{templateversion}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractTemplateVersionParam(options.Database),
|
||||
httpmw.ExtractOrganizationParam(options.Database),
|
||||
)
|
||||
|
@ -154,8 +158,15 @@ func New(options *Options) (http.Handler, func()) {
|
|||
r.Post("/first", api.postFirstUser)
|
||||
r.Post("/login", api.postLogin)
|
||||
r.Post("/logout", api.postLogout)
|
||||
r.Get("/authmethods", api.userAuthMethods)
|
||||
r.Route("/oauth2", func(r chi.Router) {
|
||||
r.Route("/github", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config))
|
||||
r.Get("/callback", api.userOAuth2Github)
|
||||
})
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Post("/", api.postUsers)
|
||||
r.Get("/", api.users)
|
||||
r.Route("/{user}", func(r chi.Router) {
|
||||
|
@ -193,7 +204,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
r.Route("/{workspaceagent}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceAgentParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspaceAgent)
|
||||
|
@ -204,7 +215,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceResourceParam(options.Database),
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
|
@ -212,7 +223,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
r.Route("/workspaces/{workspace}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspace)
|
||||
|
@ -230,7 +241,7 @@ func New(options *Options) (http.Handler, func()) {
|
|||
})
|
||||
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceBuildParam(options.Database),
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
|
|
|
@ -53,6 +53,7 @@ import (
|
|||
type Options struct {
|
||||
AWSCertificates awsidentity.Certificates
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GithubOAuth2Config *coderd.GithubOAuth2Config
|
||||
GoogleTokenValidator *idtoken.Validator
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
APIRateLimit int
|
||||
|
@ -123,6 +124,7 @@ func New(t *testing.T, options *Options) *codersdk.Client {
|
|||
|
||||
AWSCertificates: options.AWSCertificates,
|
||||
AzureCertificates: options.AzureCertificates,
|
||||
GithubOAuth2Config: options.GithubOAuth2Config,
|
||||
GoogleTokenValidator: options.GoogleTokenValidator,
|
||||
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
|
||||
TURNServer: turnServer,
|
||||
|
|
|
@ -434,6 +434,16 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, req database.GetW
|
|||
return workspaces, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
if len(q.organizations) == 0 {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
return q.organizations, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
@ -856,21 +866,18 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
|
|||
|
||||
//nolint:gosimple
|
||||
key := database.APIKey{
|
||||
ID: arg.ID,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
UserID: arg.UserID,
|
||||
Application: arg.Application,
|
||||
Name: arg.Name,
|
||||
LastUsed: arg.LastUsed,
|
||||
ExpiresAt: arg.ExpiresAt,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
LoginType: arg.LoginType,
|
||||
OIDCAccessToken: arg.OIDCAccessToken,
|
||||
OIDCRefreshToken: arg.OIDCRefreshToken,
|
||||
OIDCIDToken: arg.OIDCIDToken,
|
||||
OIDCExpiry: arg.OIDCExpiry,
|
||||
DevurlToken: arg.DevurlToken,
|
||||
ID: arg.ID,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
UserID: arg.UserID,
|
||||
ExpiresAt: arg.ExpiresAt,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
LastUsed: arg.LastUsed,
|
||||
LoginType: arg.LoginType,
|
||||
OAuthAccessToken: arg.OAuthAccessToken,
|
||||
OAuthRefreshToken: arg.OAuthRefreshToken,
|
||||
OAuthIDToken: arg.OAuthIDToken,
|
||||
OAuthExpiry: arg.OAuthExpiry,
|
||||
}
|
||||
q.apiKeys = append(q.apiKeys, key)
|
||||
return key, nil
|
||||
|
@ -1185,9 +1192,9 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI
|
|||
}
|
||||
apiKey.LastUsed = arg.LastUsed
|
||||
apiKey.ExpiresAt = arg.ExpiresAt
|
||||
apiKey.OIDCAccessToken = arg.OIDCAccessToken
|
||||
apiKey.OIDCRefreshToken = arg.OIDCRefreshToken
|
||||
apiKey.OIDCExpiry = arg.OIDCExpiry
|
||||
apiKey.OAuthAccessToken = arg.OAuthAccessToken
|
||||
apiKey.OAuthRefreshToken = arg.OAuthRefreshToken
|
||||
apiKey.OAuthExpiry = arg.OAuthExpiry
|
||||
q.apiKeys[index] = apiKey
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -14,9 +14,8 @@ CREATE TYPE log_source AS ENUM (
|
|||
);
|
||||
|
||||
CREATE TYPE login_type AS ENUM (
|
||||
'built-in',
|
||||
'saml',
|
||||
'oidc'
|
||||
'password',
|
||||
'github'
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_destination_scheme AS ENUM (
|
||||
|
@ -67,18 +66,15 @@ CREATE TABLE api_keys (
|
|||
id text NOT NULL,
|
||||
hashed_secret bytea NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
application boolean NOT NULL,
|
||||
name text NOT NULL,
|
||||
last_used timestamp with time zone NOT NULL,
|
||||
expires_at timestamp with time zone NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
oidc_access_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_id_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
devurl_token boolean DEFAULT false NOT NULL
|
||||
oauth_access_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_id_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE files (
|
||||
|
|
|
@ -4,14 +4,9 @@
|
|||
-- All tables and types are stolen from:
|
||||
-- https://github.com/coder/m/blob/47b6fc383347b9f9fab424d829c482defd3e1fe2/product/coder/pkg/database/dump.sql
|
||||
|
||||
--
|
||||
-- Name: users; Type: TABLE; Schema: public; Owner: coder
|
||||
--
|
||||
|
||||
CREATE TYPE login_type AS ENUM (
|
||||
'built-in',
|
||||
'saml',
|
||||
'oidc'
|
||||
'password',
|
||||
'github'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
|
@ -31,10 +26,6 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users USING btree (email);
|
|||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users USING btree (username);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS users_username_lower_idx ON users USING btree (lower(username));
|
||||
|
||||
--
|
||||
-- Name: organizations; Type: TABLE; Schema: Owner: coder
|
||||
--
|
||||
|
||||
CREATE TABLE IF NOT EXISTS organizations (
|
||||
id uuid NOT NULL,
|
||||
name text NOT NULL,
|
||||
|
@ -68,18 +59,15 @@ CREATE TABLE IF NOT EXISTS api_keys (
|
|||
id text NOT NULL,
|
||||
hashed_secret bytea NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
application boolean NOT NULL,
|
||||
name text NOT NULL,
|
||||
last_used timestamp with time zone NOT NULL,
|
||||
expires_at timestamp with time zone NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
oidc_access_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_id_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
devurl_token boolean DEFAULT false NOT NULL,
|
||||
oauth_access_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_id_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
|
|
|
@ -56,9 +56,8 @@ func (e *LogSource) Scan(src interface{}) error {
|
|||
type LoginType string
|
||||
|
||||
const (
|
||||
LoginTypeBuiltIn LoginType = "built-in"
|
||||
LoginTypeSaml LoginType = "saml"
|
||||
LoginTypeOIDC LoginType = "oidc"
|
||||
LoginTypePassword LoginType = "password"
|
||||
LoginTypeGithub LoginType = "github"
|
||||
)
|
||||
|
||||
func (e *LoginType) Scan(src interface{}) error {
|
||||
|
@ -230,21 +229,18 @@ func (e *WorkspaceTransition) Scan(src interface{}) error {
|
|||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Application bool `db:"application" json:"application"`
|
||||
Name string `db:"name" json:"name"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
|
||||
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
|
||||
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
|
||||
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
|
||||
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
|
|
|
@ -18,6 +18,7 @@ type querier interface {
|
|||
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
|
||||
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
|
||||
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
|
||||
GetOrganizations(ctx context.Context) ([]Organization, error)
|
||||
GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error)
|
||||
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
|
||||
GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error)
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
const getAPIKeyByID = `-- name: GetAPIKeyByID :one
|
||||
SELECT
|
||||
id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token
|
||||
id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry
|
||||
FROM
|
||||
api_keys
|
||||
WHERE
|
||||
|
@ -31,18 +31,15 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
|
|||
&i.ID,
|
||||
&i.HashedSecret,
|
||||
&i.UserID,
|
||||
&i.Application,
|
||||
&i.Name,
|
||||
&i.LastUsed,
|
||||
&i.ExpiresAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.LoginType,
|
||||
&i.OIDCAccessToken,
|
||||
&i.OIDCRefreshToken,
|
||||
&i.OIDCIDToken,
|
||||
&i.OIDCExpiry,
|
||||
&i.DevurlToken,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthIDToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -53,55 +50,33 @@ INSERT INTO
|
|||
id,
|
||||
hashed_secret,
|
||||
user_id,
|
||||
application,
|
||||
"name",
|
||||
last_used,
|
||||
expires_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
login_type,
|
||||
oidc_access_token,
|
||||
oidc_refresh_token,
|
||||
oidc_id_token,
|
||||
oidc_expiry,
|
||||
devurl_token
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_id_token,
|
||||
oauth_expiry
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12,
|
||||
$13,
|
||||
$14,
|
||||
$15
|
||||
) RETURNING id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry
|
||||
`
|
||||
|
||||
type InsertAPIKeyParams struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Application bool `db:"application" json:"application"`
|
||||
Name string `db:"name" json:"name"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
|
||||
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
|
||||
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
|
||||
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
|
||||
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) {
|
||||
|
@ -109,36 +84,30 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
|
|||
arg.ID,
|
||||
arg.HashedSecret,
|
||||
arg.UserID,
|
||||
arg.Application,
|
||||
arg.Name,
|
||||
arg.LastUsed,
|
||||
arg.ExpiresAt,
|
||||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.LoginType,
|
||||
arg.OIDCAccessToken,
|
||||
arg.OIDCRefreshToken,
|
||||
arg.OIDCIDToken,
|
||||
arg.OIDCExpiry,
|
||||
arg.DevurlToken,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthIDToken,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
var i APIKey
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HashedSecret,
|
||||
&i.UserID,
|
||||
&i.Application,
|
||||
&i.Name,
|
||||
&i.LastUsed,
|
||||
&i.ExpiresAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.LoginType,
|
||||
&i.OIDCAccessToken,
|
||||
&i.OIDCRefreshToken,
|
||||
&i.OIDCIDToken,
|
||||
&i.OIDCExpiry,
|
||||
&i.DevurlToken,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthIDToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -149,20 +118,20 @@ UPDATE
|
|||
SET
|
||||
last_used = $2,
|
||||
expires_at = $3,
|
||||
oidc_access_token = $4,
|
||||
oidc_refresh_token = $5,
|
||||
oidc_expiry = $6
|
||||
oauth_access_token = $4,
|
||||
oauth_refresh_token = $5,
|
||||
oauth_expiry = $6
|
||||
WHERE
|
||||
id = $1
|
||||
`
|
||||
|
||||
type UpdateAPIKeyByIDParams struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
|
||||
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
|
||||
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
|
||||
ID string `db:"id" json:"id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error {
|
||||
|
@ -170,9 +139,9 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
|
|||
arg.ID,
|
||||
arg.LastUsed,
|
||||
arg.ExpiresAt,
|
||||
arg.OIDCAccessToken,
|
||||
arg.OIDCRefreshToken,
|
||||
arg.OIDCExpiry,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
@ -453,6 +422,42 @@ func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Or
|
|||
return i, err
|
||||
}
|
||||
|
||||
const getOrganizations = `-- name: GetOrganizations :many
|
||||
SELECT
|
||||
id, name, description, created_at, updated_at
|
||||
FROM
|
||||
organizations
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOrganizations(ctx context.Context) ([]Organization, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getOrganizations)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Organization
|
||||
for rows.Next() {
|
||||
var i Organization
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many
|
||||
SELECT
|
||||
id, name, description, created_at, updated_at
|
||||
|
|
|
@ -14,37 +14,18 @@ INSERT INTO
|
|||
id,
|
||||
hashed_secret,
|
||||
user_id,
|
||||
application,
|
||||
"name",
|
||||
last_used,
|
||||
expires_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
login_type,
|
||||
oidc_access_token,
|
||||
oidc_refresh_token,
|
||||
oidc_id_token,
|
||||
oidc_expiry,
|
||||
devurl_token
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_id_token,
|
||||
oauth_expiry
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12,
|
||||
$13,
|
||||
$14,
|
||||
$15
|
||||
) RETURNING *;
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING *;
|
||||
|
||||
-- name: UpdateAPIKeyByID :exec
|
||||
UPDATE
|
||||
|
@ -52,8 +33,8 @@ UPDATE
|
|||
SET
|
||||
last_used = $2,
|
||||
expires_at = $3,
|
||||
oidc_access_token = $4,
|
||||
oidc_refresh_token = $5,
|
||||
oidc_expiry = $6
|
||||
oauth_access_token = $4,
|
||||
oauth_refresh_token = $5,
|
||||
oauth_expiry = $6
|
||||
WHERE
|
||||
id = $1;
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
-- name: GetOrganizations :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organizations;
|
||||
|
||||
-- name: GetOrganizationByID :one
|
||||
SELECT
|
||||
*
|
||||
|
|
|
@ -21,10 +21,10 @@ overrides:
|
|||
rename:
|
||||
api_key: APIKey
|
||||
login_type_oidc: LoginTypeOIDC
|
||||
oidc_access_token: OIDCAccessToken
|
||||
oidc_expiry: OIDCExpiry
|
||||
oidc_id_token: OIDCIDToken
|
||||
oidc_refresh_token: OIDCRefreshToken
|
||||
oauth_access_token: OAuthAccessToken
|
||||
oauth_expiry: OAuthExpiry
|
||||
oauth_id_token: OAuthIDToken
|
||||
oauth_refresh_token: OAuthRefreshToken
|
||||
parameter_type_system_hcl: ParameterTypeSystemHCL
|
||||
userstatus: UserStatus
|
||||
gitsshkey: GitSSHKey
|
||||
|
|
|
@ -20,12 +20,6 @@ import (
|
|||
// AuthCookie represents the name of the cookie the API key is stored in.
|
||||
const AuthCookie = "session_token"
|
||||
|
||||
// OAuth2Config contains a subset of functions exposed from oauth2.Config.
|
||||
// It is abstracted for simple testing.
|
||||
type OAuth2Config interface {
|
||||
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
|
||||
}
|
||||
|
||||
type apiKeyContextKey struct{}
|
||||
|
||||
// APIKey returns the API key from the ExtractAPIKey handler.
|
||||
|
@ -37,10 +31,16 @@ func APIKey(r *http.Request) database.APIKey {
|
|||
return apiKey
|
||||
}
|
||||
|
||||
// OAuth2Configs is a collection of configurations for OAuth-based authentication.
|
||||
// This should be extended to support other authentication types in the future.
|
||||
type OAuth2Configs struct {
|
||||
Github OAuth2Config
|
||||
}
|
||||
|
||||
// ExtractAPIKey requires authentication using a valid API key.
|
||||
// It handles extending an API key if it comes close to expiry,
|
||||
// updating the last used time in the database.
|
||||
func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler {
|
||||
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(AuthCookie)
|
||||
|
@ -99,14 +99,24 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
|
|||
// Tracks if the API key has properties updated!
|
||||
changed := false
|
||||
|
||||
if key.LoginType == database.LoginTypeOIDC {
|
||||
// Check if the OIDC token is expired!
|
||||
if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() {
|
||||
if key.LoginType != database.LoginTypePassword {
|
||||
// Check if the OAuth token is expired!
|
||||
if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() {
|
||||
var oauthConfig OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = oauth.Github
|
||||
default:
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("unexpected authentication type %q", key.LoginType),
|
||||
})
|
||||
return
|
||||
}
|
||||
// If it is, let's refresh it from the provided config!
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: key.OIDCAccessToken,
|
||||
RefreshToken: key.OIDCRefreshToken,
|
||||
Expiry: key.OIDCExpiry,
|
||||
AccessToken: key.OAuthAccessToken,
|
||||
RefreshToken: key.OAuthRefreshToken,
|
||||
Expiry: key.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
|
@ -114,9 +124,9 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
|
|||
})
|
||||
return
|
||||
}
|
||||
key.OIDCAccessToken = token.AccessToken
|
||||
key.OIDCRefreshToken = token.RefreshToken
|
||||
key.OIDCExpiry = token.Expiry
|
||||
key.OAuthAccessToken = token.AccessToken
|
||||
key.OAuthRefreshToken = token.RefreshToken
|
||||
key.OAuthExpiry = token.Expiry
|
||||
key.ExpiresAt = token.Expiry
|
||||
changed = true
|
||||
}
|
||||
|
@ -136,21 +146,20 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
|
|||
changed = true
|
||||
}
|
||||
// Only update the ExpiresAt once an hour to prevent database spam.
|
||||
// We extend the ExpiresAt to reduce reauthentication.
|
||||
// We extend the ExpiresAt to reduce re-authentication.
|
||||
apiKeyLifetime := 24 * time.Hour
|
||||
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
|
||||
key.ExpiresAt = now.Add(apiKeyLifetime)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
LastUsed: key.LastUsed,
|
||||
OIDCAccessToken: key.OIDCAccessToken,
|
||||
OIDCRefreshToken: key.OIDCRefreshToken,
|
||||
OIDCExpiry: key.OIDCExpiry,
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
OAuthAccessToken: key.OAuthAccessToken,
|
||||
OAuthRefreshToken: key.OAuthRefreshToken,
|
||||
OAuthExpiry: key.OAuthExpiry,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
|
|
|
@ -189,7 +189,6 @@ func TestAPIKey(t *testing.T) {
|
|||
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -207,7 +206,6 @@ func TestAPIKey(t *testing.T) {
|
|||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
|
@ -277,7 +275,7 @@ func TestAPIKey(t *testing.T) {
|
|||
require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("OIDCNotExpired", func(t *testing.T) {
|
||||
t.Run("OAuthNotExpired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
|
@ -294,7 +292,7 @@ func TestAPIKey(t *testing.T) {
|
|||
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
})
|
||||
|
@ -311,7 +309,7 @@ func TestAPIKey(t *testing.T) {
|
|||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("OIDCRefresh", func(t *testing.T) {
|
||||
t.Run("OAuthRefresh", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
|
@ -328,9 +326,9 @@ func TestAPIKey(t *testing.T) {
|
|||
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
LastUsed: database.Now(),
|
||||
OIDCExpiry: database.Now().AddDate(0, 0, -1),
|
||||
OAuthExpiry: database.Now().AddDate(0, 0, -1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
token := &oauth2.Token{
|
||||
|
@ -338,11 +336,11 @@ func TestAPIKey(t *testing.T) {
|
|||
RefreshToken: "moo",
|
||||
Expiry: database.Now().AddDate(0, 0, 1),
|
||||
}
|
||||
httpmw.ExtractAPIKey(db, &oauth2Config{
|
||||
tokenSource: &oauth2TokenSource{
|
||||
token: func() (*oauth2.Token, error) {
|
||||
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{
|
||||
Github: &oauth2Config{
|
||||
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
|
||||
return token, nil
|
||||
},
|
||||
}),
|
||||
},
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
|
@ -354,22 +352,28 @@ func TestAPIKey(t *testing.T) {
|
|||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
|
||||
require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken)
|
||||
require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken)
|
||||
})
|
||||
}
|
||||
|
||||
type oauth2Config struct {
|
||||
tokenSource *oauth2TokenSource
|
||||
tokenSource oauth2TokenSource
|
||||
}
|
||||
|
||||
func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
|
||||
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return o.tokenSource
|
||||
}
|
||||
|
||||
type oauth2TokenSource struct {
|
||||
token func() (*oauth2.Token, error)
|
||||
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
|
||||
return o.token()
|
||||
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{}, nil
|
||||
}
|
||||
|
||||
type oauth2TokenSource func() (*oauth2.Token, error)
|
||||
|
||||
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
|
||||
return o()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2StateCookieName = "oauth_state"
|
||||
oauth2RedirectCookieName = "oauth_redirect"
|
||||
)
|
||||
|
||||
type oauth2StateKey struct{}
|
||||
|
||||
type OAuth2State struct {
|
||||
Token *oauth2.Token
|
||||
Redirect string
|
||||
}
|
||||
|
||||
// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing.
|
||||
// *oauth2.Config should be used instead of implementing this in production.
|
||||
type OAuth2Config interface {
|
||||
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
||||
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
|
||||
}
|
||||
|
||||
// OAuth2 returns the state from an oauth request.
|
||||
func OAuth2(r *http.Request) OAuth2State {
|
||||
oauth, ok := r.Context().Value(oauth2StateKey{}).(OAuth2State)
|
||||
if !ok {
|
||||
panic("developer error: oauth middleware not provided")
|
||||
}
|
||||
return oauth
|
||||
}
|
||||
|
||||
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
|
||||
// URLs, and handling the exchange inbound. Any route that does not have
|
||||
// a "code" URL parameter will be redirected.
|
||||
func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if config == nil {
|
||||
httpapi.Write(rw, http.StatusPreconditionRequired, httpapi.Response{
|
||||
Message: fmt.Sprintf("The oauth2 method requested is not configured!"),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
// If the code isn't provided, we'll redirect!
|
||||
state, err := cryptorand.String(32)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate state string: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: oauth2StateCookieName,
|
||||
Value: state,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
// Redirect must always be specified, otherwise
|
||||
// an old redirect could apply!
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: oauth2RedirectCookieName,
|
||||
Value: r.URL.Query().Get("redirect"),
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
http.Redirect(rw, r, config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "state must be provided",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
stateCookie, err := r.Cookie(oauth2StateCookieName)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: fmt.Sprintf("%q cookie must be provided", oauth2StateCookieName),
|
||||
})
|
||||
return
|
||||
}
|
||||
if stateCookie.Value != state {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: "state mismatched",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var redirect string
|
||||
stateRedirect, err := r.Cookie(oauth2RedirectCookieName)
|
||||
if err == nil {
|
||||
redirect = stateRedirect.Value
|
||||
}
|
||||
|
||||
oauthToken, err := config.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("exchange oauth code: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), oauth2StateKey{}, OAuth2State{
|
||||
Token: oauthToken,
|
||||
Redirect: redirect,
|
||||
})
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
type testOAuth2Provider struct {
|
||||
}
|
||||
|
||||
func (*testOAuth2Provider) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (*testOAuth2Provider) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{
|
||||
AccessToken: "hello",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*testOAuth2Provider) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NotSetup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("RedirectWithoutCode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
location := res.Header().Get("Location")
|
||||
if !assert.NotEmpty(t, location) {
|
||||
return
|
||||
}
|
||||
require.Len(t, res.Result().Cookies(), 2)
|
||||
cookie := res.Result().Cookies()[1]
|
||||
require.Equal(t, "/dashboard", cookie.Value)
|
||||
})
|
||||
t.Run("NoState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("NoStateCookie", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("MismatchedState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "mismatch",
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("ExchangeCodeAndState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=test&state=something", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "something",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_redirect",
|
||||
Value: "/dashboard",
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
require.Equal(t, "/dashboard", state.Redirect)
|
||||
})).ServeHTTP(res, req)
|
||||
})
|
||||
}
|
|
@ -41,7 +41,7 @@ func TestOrganizationParam(t *testing.T) {
|
|||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestTemplateParam(t *testing.T) {
|
|||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestTemplateVersionParam(t *testing.T) {
|
|||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
|
|||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
|
|||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestWorkspaceParam(t *testing.T) {
|
|||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// GithubOAuth2Provider exposes required functions for the Github authentication flow.
|
||||
type GithubOAuth2Config struct {
|
||||
httpmw.OAuth2Config
|
||||
AuthenticatedUser func(ctx context.Context, client *http.Client) (*github.User, error)
|
||||
ListEmails func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error)
|
||||
ListOrganizationMemberships func(ctx context.Context, client *http.Client) ([]*github.Membership, error)
|
||||
|
||||
AllowSignups bool
|
||||
AllowOrganizations []string
|
||||
}
|
||||
|
||||
func (api *api) userAuthMethods(rw http.ResponseWriter, _ *http.Request) {
|
||||
httpapi.Write(rw, http.StatusOK, codersdk.AuthMethods{
|
||||
Password: true,
|
||||
Github: api.GithubOAuth2Config != nil,
|
||||
})
|
||||
}
|
||||
|
||||
func (api *api) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
|
||||
oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token))
|
||||
memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get authenticated github user organizations: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
var selectedMembership *github.Membership
|
||||
for _, membership := range memberships {
|
||||
for _, allowed := range api.GithubOAuth2Config.AllowOrganizations {
|
||||
if *membership.Organization.Login != allowed {
|
||||
continue
|
||||
}
|
||||
selectedMembership = membership
|
||||
break
|
||||
}
|
||||
}
|
||||
if selectedMembership == nil {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: fmt.Sprintf("You aren't a member of the authorized Github organizations!"),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get personal github user: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var user database.User
|
||||
// Search for existing users with matching and verified emails.
|
||||
// If a verified GitHub email matches a Coder user, we will return.
|
||||
for _, email := range emails {
|
||||
if email.Verified == nil {
|
||||
continue
|
||||
}
|
||||
user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
Email: *email.Email,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get user by email: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !*email.Verified {
|
||||
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
|
||||
Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// If the user doesn't exist, create a new one!
|
||||
if user.ID == uuid.Nil {
|
||||
if !api.GithubOAuth2Config.AllowSignups {
|
||||
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
|
||||
Message: "Signups are disabled for Github authentication!",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var organizationID uuid.UUID
|
||||
organizations, _ := api.Database.GetOrganizations(r.Context())
|
||||
if len(organizations) > 0 {
|
||||
// Add the user to the first organization. Once multi-organization
|
||||
// support is added, we should enable a configuration map of user
|
||||
// email to organization.
|
||||
organizationID = organizations[0].ID
|
||||
}
|
||||
ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get authenticated github user: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
|
||||
Email: *ghUser.Email,
|
||||
Username: *ghUser.Login,
|
||||
OrganizationID: organizationID,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("create user: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
OAuthAccessToken: state.Token.AccessToken,
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
}
|
||||
|
||||
redirect := state.Redirect
|
||||
if redirect == "" {
|
||||
redirect = "/"
|
||||
}
|
||||
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
|
@ -0,0 +1,205 @@
|
|||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
type oauth2Config struct{}
|
||||
|
||||
func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUserAuthMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Password", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
methods, err := client.AuthMethods(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.True(t, methods.Password)
|
||||
require.False(t, methods.Github)
|
||||
})
|
||||
t.Run("Github", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{},
|
||||
})
|
||||
methods, err := client.AuthMethods(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.True(t, methods.Password)
|
||||
require.True(t, methods.Github)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserOAuth2Github(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NotInAllowedOrganization", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("kyle"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
})
|
||||
t.Run("UnverifiedEmail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{{
|
||||
Email: github.String("testuser@coder.com"),
|
||||
Verified: github.Bool(false),
|
||||
}}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
})
|
||||
t.Run("BlockSignups", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
})
|
||||
t.Run("Signup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
AllowSignups: true,
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{
|
||||
Login: github.String("kyle"),
|
||||
Email: github.String("kyle@coder.com"),
|
||||
}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
})
|
||||
t.Run("Login", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{{
|
||||
Email: github.String("testuser@coder.com"),
|
||||
Verified: github.Bool(true),
|
||||
}}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
state := "somestate"
|
||||
oauthURL, err := client.URL.Parse("/api/v2/users/oauth2/github/callback?code=asd&state=" + state)
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequest("GET", oauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state,
|
||||
})
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = res.Body.Close()
|
||||
})
|
||||
return res
|
||||
}
|
297
coderd/users.go
297
coderd/users.go
|
@ -1,6 +1,7 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
@ -71,66 +72,10 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := userpassword.Hash(createUser.Password)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("hash password: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create the user, organization, and membership to the user.
|
||||
var user database.User
|
||||
var organization database.Organization
|
||||
err = api.Database.InTx(func(db database.Store) error {
|
||||
user, err = api.Database.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: uuid.New(),
|
||||
Email: createUser.Email,
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
Username: createUser.Username,
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate user gitsshkey: %w", err)
|
||||
}
|
||||
_, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user gitsshkey: %w", err)
|
||||
}
|
||||
|
||||
organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{
|
||||
ID: uuid.New(),
|
||||
Name: createUser.OrganizationName,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization: %w", err)
|
||||
}
|
||||
_, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
|
||||
OrganizationID: organization.ID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
Roles: []string{"organization-admin"},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization member: %w", err)
|
||||
}
|
||||
return nil
|
||||
user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{
|
||||
Email: createUser.Email,
|
||||
Username: createUser.Username,
|
||||
Password: createUser.Password,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
|
@ -141,7 +86,7 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
|||
|
||||
httpapi.Write(rw, http.StatusCreated, codersdk.CreateFirstUserResponse{
|
||||
UserID: user.ID,
|
||||
OrganizationID: organization.ID,
|
||||
OrganizationID: organizationID,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -262,56 +207,7 @@ func (api *api) postUsers(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := userpassword.Hash(createUser.Password)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("hash password: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var user database.User
|
||||
err = api.Database.InTx(func(db database.Store) error {
|
||||
user, err = db.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: uuid.New(),
|
||||
Email: createUser.Email,
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
Username: createUser.Username,
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate user gitsshkey: %w", err)
|
||||
}
|
||||
_, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user gitsshkey: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
|
||||
OrganizationID: organization.ID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
Roles: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization member: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
user, _, err := api.createUser(r.Context(), createUser)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: err.Error(),
|
||||
|
@ -542,41 +438,13 @@ func (api *api) postLogin(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
|
||||
})
|
||||
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: user.ID,
|
||||
ExpiresAt: database.Now().Add(24 * time.Hour),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("insert api key: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: api.SecureAuthCookie,
|
||||
})
|
||||
|
||||
httpapi.Write(rw, http.StatusCreated, codersdk.LoginWithPasswordResponse{
|
||||
SessionToken: sessionToken,
|
||||
|
@ -595,35 +463,15 @@ func (api *api) postAPIKey(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: apiKey.UserID,
|
||||
ExpiresAt: database.Now().AddDate(1, 0, 0), // Expire after 1 year (same as v1)
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("insert api key: %s", err.Error()),
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
}
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
generatedAPIKey := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
|
||||
httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: generatedAPIKey})
|
||||
httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: sessionToken})
|
||||
}
|
||||
|
||||
// Clear the user's session cookie
|
||||
|
@ -984,6 +832,117 @@ func generateAPIKeyIDSecret() (id string, secret string, err error) {
|
|||
return id, secret, nil
|
||||
}
|
||||
|
||||
func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) {
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: params.UserID,
|
||||
ExpiresAt: database.Now().Add(24 * time.Hour),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: params.LoginType,
|
||||
OAuthAccessToken: params.OAuthAccessToken,
|
||||
OAuthRefreshToken: params.OAuthRefreshToken,
|
||||
OAuthIDToken: params.OAuthIDToken,
|
||||
OAuthExpiry: params.OAuthExpiry,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("insert api key: %s", err.Error()),
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: api.SecureAuthCookie,
|
||||
})
|
||||
return sessionToken, true
|
||||
}
|
||||
|
||||
func (api *api) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) {
|
||||
var user database.User
|
||||
return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error {
|
||||
// If no organization is provided, create a new one for the user.
|
||||
if req.OrganizationID == uuid.Nil {
|
||||
organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{
|
||||
ID: uuid.New(),
|
||||
Name: req.Username,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization: %w", err)
|
||||
}
|
||||
req.OrganizationID = organization.ID
|
||||
}
|
||||
|
||||
params := database.InsertUserParams{
|
||||
ID: uuid.New(),
|
||||
Email: req.Email,
|
||||
Username: req.Username,
|
||||
LoginType: database.LoginTypePassword,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
}
|
||||
// If a user signs up with OAuth, they can have no password!
|
||||
if req.Password != "" {
|
||||
hashedPassword, err := userpassword.Hash(req.Password)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("hash password: %w", err)
|
||||
}
|
||||
params.HashedPassword = []byte(hashedPassword)
|
||||
}
|
||||
|
||||
var err error
|
||||
user, err = db.InsertUser(ctx, params)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate user gitsshkey: %w", err)
|
||||
}
|
||||
_, err = db.InsertGitSSHKey(ctx, database.InsertGitSSHKeyParams{
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user gitsshkey: %w", err)
|
||||
}
|
||||
_, err = db.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
|
||||
OrganizationID: req.OrganizationID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
Roles: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization member: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func convertUser(user database.User) codersdk.User {
|
||||
return codersdk.User{
|
||||
ID: user.ID,
|
||||
|
|
|
@ -241,13 +241,14 @@ func TestUpdateUserProfile(t *testing.T) {
|
|||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
existentUser, _ := client.CreateUser(context.Background(), codersdk.CreateUserRequest{
|
||||
existentUser, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{
|
||||
Email: "bruno@coder.com",
|
||||
Username: "bruno",
|
||||
Password: "password",
|
||||
OrganizationID: user.OrganizationID,
|
||||
})
|
||||
_, err := client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{
|
||||
require.NoError(t, err)
|
||||
_, err = client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{
|
||||
Username: existentUser.Username,
|
||||
Email: "newemail@coder.com",
|
||||
})
|
||||
|
|
|
@ -92,6 +92,12 @@ type CreateWorkspaceRequest struct {
|
|||
ParameterValues []CreateParameterRequest `json:"parameter_values"`
|
||||
}
|
||||
|
||||
// AuthMethods contains whether authentication types are enabled or not.
|
||||
type AuthMethods struct {
|
||||
Password bool `json:"password"`
|
||||
Github bool `json:"github"`
|
||||
}
|
||||
|
||||
// HasFirstUser returns whether the first user has been created.
|
||||
func (c *Client) HasFirstUser(ctx context.Context) (bool, error) {
|
||||
res, err := c.request(ctx, http.MethodGet, "/api/v2/users/first", nil)
|
||||
|
@ -330,6 +336,22 @@ func (c *Client) WorkspaceByName(ctx context.Context, userID uuid.UUID, name str
|
|||
return workspace, json.NewDecoder(res.Body).Decode(&workspace)
|
||||
}
|
||||
|
||||
// AuthMethods returns types of authentication available to the user.
|
||||
func (c *Client) AuthMethods(ctx context.Context) (AuthMethods, error) {
|
||||
res, err := c.request(ctx, http.MethodGet, "/api/v2/users/authmethods", nil)
|
||||
if err != nil {
|
||||
return AuthMethods{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return AuthMethods{}, readBodyAsError(res)
|
||||
}
|
||||
|
||||
var userAuth AuthMethods
|
||||
return userAuth, json.NewDecoder(res.Body).Decode(&userAuth)
|
||||
}
|
||||
|
||||
// uuidOrMe returns the provided uuid as a string if it's valid, ortherwise
|
||||
// `me`.
|
||||
func uuidOrMe(id uuid.UUID) string {
|
||||
|
|
2
go.mod
2
go.mod
|
@ -61,6 +61,7 @@ require (
|
|||
github.com/gohugoio/hugo v0.97.2
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/golang-migrate/migrate/v4 v4.15.1
|
||||
github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/hashicorp/go-version v1.4.0
|
||||
github.com/hashicorp/hc-install v0.3.1
|
||||
|
@ -157,6 +158,7 @@ require (
|
|||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/google/go-cmp v0.5.7 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/gorilla/mux v1.8.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
|
|
4
go.sum
4
go.sum
|
@ -784,7 +784,11 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
|
||||
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
|
||||
github.com/google/go-github/v35 v35.2.0/go.mod h1:s0515YVTI+IMrDoy9Y4pHt9ShGpzHvHO8rZ7L7acgvs=
|
||||
github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405 h1:DdHws/YnnPrSywrjNYu2lEHqYHWp/LnEx56w59esd54=
|
||||
github.com/google/go-github/v43 v43.0.1-0.20220414155304-00e42332e405/go.mod h1:4RgUDSnsxP19d65zJWqvqJ/poJxBCvmna50eXmIvoR8=
|
||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
|
|
|
@ -2,6 +2,7 @@ import axios, { AxiosRequestHeaders } from "axios"
|
|||
import { mutate } from "swr"
|
||||
import { MockPager, MockUser, MockUser2 } from "../testHelpers/entities"
|
||||
import * as Types from "./types"
|
||||
import * as TypesGen from "./typesGenerated"
|
||||
|
||||
const CONTENT_TYPE_JSON: AxiosRequestHeaders = {
|
||||
"Content-Type": "application/json",
|
||||
|
@ -65,6 +66,11 @@ export const getUser = async (): Promise<Types.UserResponse> => {
|
|||
return response.data
|
||||
}
|
||||
|
||||
export const getAuthMethods = async (): Promise<TypesGen.AuthMethods> => {
|
||||
const response = await axios.get<TypesGen.AuthMethods>("/api/v2/users/authmethods")
|
||||
return response.data
|
||||
}
|
||||
|
||||
export const getApiKey = async (): Promise<Types.APIKeyResponse> => {
|
||||
const response = await axios.post<Types.APIKeyResponse>("/api/v2/users/me/keys")
|
||||
return response.data
|
||||
|
|
|
@ -132,6 +132,12 @@ export interface CreateWorkspaceRequest {
|
|||
readonly name: string
|
||||
}
|
||||
|
||||
// From codersdk/users.go:96:6.
|
||||
export interface AuthMethods {
|
||||
readonly password: boolean
|
||||
readonly github: boolean
|
||||
}
|
||||
|
||||
// From codersdk/workspaceagents.go:31:6.
|
||||
export interface GoogleInstanceIdentityToken {
|
||||
readonly json_web_token: string
|
||||
|
|
|
@ -24,7 +24,26 @@ SignedOut.args = {
|
|||
}
|
||||
|
||||
export const Loading = Template.bind({})
|
||||
Loading.args = { ...SignedOut.args, isLoading: true }
|
||||
Loading.args = {
|
||||
...SignedOut.args,
|
||||
isLoading: true,
|
||||
authMethods: {
|
||||
github: true,
|
||||
password: true,
|
||||
},
|
||||
}
|
||||
|
||||
export const WithError = Template.bind({})
|
||||
WithError.args = { ...SignedOut.args, authErrorMessage: "Email or password was invalid" }
|
||||
export const WithLoginError = Template.bind({})
|
||||
WithLoginError.args = { ...SignedOut.args, authErrorMessage: "Email or password was invalid" }
|
||||
|
||||
export const WithAuthMethodsError = Template.bind({})
|
||||
WithAuthMethodsError.args = { ...SignedOut.args, methodsErrorMessage: "Failed to fetch auth methods" }
|
||||
|
||||
export const WithGithub = Template.bind({})
|
||||
WithGithub.args = {
|
||||
...SignedOut.args,
|
||||
authMethods: {
|
||||
password: true,
|
||||
github: true,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
import Button from "@material-ui/core/Button"
|
||||
import FormHelperText from "@material-ui/core/FormHelperText"
|
||||
import Link from "@material-ui/core/Link"
|
||||
import { makeStyles } from "@material-ui/core/styles"
|
||||
import TextField from "@material-ui/core/TextField"
|
||||
import { FormikContextType, useFormik } from "formik"
|
||||
import React from "react"
|
||||
import * as Yup from "yup"
|
||||
import { AuthMethods } from "../../api/typesGenerated"
|
||||
import { getFormHelpers, onChangeTrimmed } from "../../util/formUtils"
|
||||
import { Welcome } from "../Welcome/Welcome"
|
||||
import { LoadingButton } from "./../LoadingButton/LoadingButton"
|
||||
|
@ -24,7 +27,9 @@ export const Language = {
|
|||
emailInvalid: "Please enter a valid email address.",
|
||||
emailRequired: "Please enter an email address.",
|
||||
authErrorMessage: "Incorrect email or password.",
|
||||
signIn: "Sign In",
|
||||
methodsErrorMessage: "Unable to fetch auth methods.",
|
||||
passwordSignIn: "Sign In",
|
||||
githubSignIn: "GitHub",
|
||||
}
|
||||
|
||||
const validationSchema = Yup.object({
|
||||
|
@ -49,10 +54,18 @@ const useStyles = makeStyles((theme) => ({
|
|||
export interface SignInFormProps {
|
||||
isLoading: boolean
|
||||
authErrorMessage?: string
|
||||
methodsErrorMessage?: string
|
||||
authMethods?: AuthMethods
|
||||
onSubmit: ({ email, password }: { email: string; password: string }) => Promise<void>
|
||||
}
|
||||
|
||||
export const SignInForm: React.FC<SignInFormProps> = ({ isLoading, authErrorMessage, onSubmit }) => {
|
||||
export const SignInForm: React.FC<SignInFormProps> = ({
|
||||
authMethods,
|
||||
isLoading,
|
||||
authErrorMessage,
|
||||
methodsErrorMessage,
|
||||
onSubmit,
|
||||
}) => {
|
||||
const styles = useStyles()
|
||||
|
||||
const form: FormikContextType<BuiltInAuthFormValues> = useFormik<BuiltInAuthFormValues>({
|
||||
|
@ -76,6 +89,7 @@ export const SignInForm: React.FC<SignInFormProps> = ({ isLoading, authErrorMess
|
|||
className={styles.loginTextField}
|
||||
fullWidth
|
||||
label={Language.emailLabel}
|
||||
type="email"
|
||||
variant="outlined"
|
||||
/>
|
||||
<TextField
|
||||
|
@ -89,12 +103,22 @@ export const SignInForm: React.FC<SignInFormProps> = ({ isLoading, authErrorMess
|
|||
variant="outlined"
|
||||
/>
|
||||
{authErrorMessage && <FormHelperText error>{Language.authErrorMessage}</FormHelperText>}
|
||||
{methodsErrorMessage && <FormHelperText error>{Language.methodsErrorMessage}</FormHelperText>}
|
||||
<div className={styles.submitBtn}>
|
||||
<LoadingButton color="primary" loading={isLoading} fullWidth type="submit" variant="contained">
|
||||
{isLoading ? "" : Language.signIn}
|
||||
{isLoading ? "" : Language.passwordSignIn}
|
||||
</LoadingButton>
|
||||
</div>
|
||||
</form>
|
||||
{authMethods?.github && (
|
||||
<div className={styles.submitBtn}>
|
||||
<Link href="/api/v2/users/oauth2/github/callback">
|
||||
<Button color="primary" disabled={isLoading} fullWidth type="submit" variant="contained">
|
||||
{Language.githubSignIn}
|
||||
</Button>
|
||||
</Link>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ describe("LoginPage", () => {
|
|||
render(<LoginPage />)
|
||||
|
||||
// Then
|
||||
await screen.findByText(Language.signIn)
|
||||
await screen.findByText(Language.passwordSignIn)
|
||||
})
|
||||
|
||||
it("shows an error message if SignIn fails", async () => {
|
||||
|
@ -42,7 +42,7 @@ describe("LoginPage", () => {
|
|||
await userEvent.type(email, "test@coder.com")
|
||||
await userEvent.type(password, "password")
|
||||
// Click sign-in
|
||||
const signInButton = await screen.findByText(Language.signIn)
|
||||
const signInButton = await screen.findByText(Language.passwordSignIn)
|
||||
act(() => signInButton.click())
|
||||
|
||||
// Then
|
||||
|
@ -50,4 +50,43 @@ describe("LoginPage", () => {
|
|||
expect(errorMessage).toBeDefined()
|
||||
expect(history.location.pathname).toEqual("/login")
|
||||
})
|
||||
|
||||
it("shows an error if fetching auth methods fails", async () => {
|
||||
// Given
|
||||
server.use(
|
||||
// Make login fail
|
||||
rest.get("/api/v2/users/authmethods", async (req, res, ctx) => {
|
||||
return res(ctx.status(500), ctx.json({ message: "nope" }))
|
||||
}),
|
||||
)
|
||||
|
||||
// When
|
||||
render(<LoginPage />)
|
||||
|
||||
// Then
|
||||
const errorMessage = await screen.findByText(Language.methodsErrorMessage)
|
||||
expect(errorMessage).toBeDefined()
|
||||
})
|
||||
|
||||
it("shows github authentication when enabled", async () => {
|
||||
// Given
|
||||
server.use(
|
||||
rest.get("/api/v2/users/authmethods", async (req, res, ctx) => {
|
||||
return res(
|
||||
ctx.status(200),
|
||||
ctx.json({
|
||||
password: true,
|
||||
github: true,
|
||||
}),
|
||||
)
|
||||
}),
|
||||
)
|
||||
|
||||
// When
|
||||
render(<LoginPage />)
|
||||
|
||||
// Then
|
||||
await screen.findByText(Language.passwordSignIn)
|
||||
await screen.findByText(Language.githubSignIn)
|
||||
})
|
||||
})
|
||||
|
|
|
@ -35,6 +35,9 @@ export const LoginPage: React.FC = () => {
|
|||
const isLoading = authState.hasTag("loading")
|
||||
const redirectTo = retrieveRedirect(location.search)
|
||||
const authErrorMessage = authState.context.authError ? (authState.context.authError as Error).message : undefined
|
||||
const getMethodsError = authState.context.getMethodsError
|
||||
? (authState.context.getMethodsError as Error).message
|
||||
: undefined
|
||||
|
||||
const onSubmit = async ({ email, password }: { email: string; password: string }) => {
|
||||
authSend({ type: "SIGN_IN", email, password })
|
||||
|
@ -47,7 +50,13 @@ export const LoginPage: React.FC = () => {
|
|||
<div className={styles.root}>
|
||||
<div className={styles.layout}>
|
||||
<div className={styles.container}>
|
||||
<SignInForm isLoading={isLoading} authErrorMessage={authErrorMessage} onSubmit={onSubmit} />
|
||||
<SignInForm
|
||||
authMethods={authState.context.methods}
|
||||
isLoading={isLoading}
|
||||
authErrorMessage={authErrorMessage}
|
||||
methodsErrorMessage={getMethodsError}
|
||||
onSubmit={onSubmit}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Footer />
|
||||
|
|
|
@ -9,6 +9,7 @@ import {
|
|||
Workspace,
|
||||
WorkspaceAutostartRequest,
|
||||
} from "../api/types"
|
||||
import { AuthMethods } from "../api/typesGenerated"
|
||||
|
||||
export const MockSessionToken = { session_token: "my-session-token" }
|
||||
|
||||
|
@ -97,3 +98,8 @@ export const MockUserAgent: UserAgent = {
|
|||
ip_address: "11.22.33.44",
|
||||
os: "Windows 10",
|
||||
}
|
||||
|
||||
export const MockAuthMethods: AuthMethods = {
|
||||
password: true,
|
||||
github: false,
|
||||
}
|
||||
|
|
|
@ -42,6 +42,9 @@ export const handlers = [
|
|||
rest.get("/api/v2/users/me/keys", async (req, res, ctx) => {
|
||||
return res(ctx.status(200), ctx.json(M.MockAPIKey))
|
||||
}),
|
||||
rest.get("/api/v2/users/authmethods", async (req, res, ctx) => {
|
||||
return res(ctx.status(200), ctx.json(M.MockAuthMethods))
|
||||
}),
|
||||
|
||||
// workspaces
|
||||
rest.get("/api/v2/workspaces/:workspaceId", async (req, res, ctx) => {
|
||||
|
|
|
@ -1,16 +1,20 @@
|
|||
import { assign, createMachine } from "xstate"
|
||||
import * as API from "../../api"
|
||||
import * as Types from "../../api/types"
|
||||
import * as TypesGen from "../../api/typesGenerated"
|
||||
import { displaySuccess } from "../../components/GlobalSnackbar/utils"
|
||||
|
||||
export const Language = {
|
||||
successProfileUpdate: "Updated preferences.",
|
||||
}
|
||||
|
||||
export interface AuthContext {
|
||||
getUserError?: Error | unknown
|
||||
getMethodsError?: Error | unknown
|
||||
authError?: Error | unknown
|
||||
updateProfileError?: Error | unknown
|
||||
me?: Types.UserResponse
|
||||
methods?: TypesGen.AuthMethods
|
||||
}
|
||||
|
||||
export type AuthEvent =
|
||||
|
@ -19,10 +23,17 @@ export type AuthEvent =
|
|||
| { type: "UPDATE_PROFILE"; data: Types.UpdateProfileRequest }
|
||||
|
||||
export const authMachine =
|
||||
/** @xstate-layout N4IgpgJg5mDOIC5QEMCuAXAFgZXc9YAdLAJZQB2kA8hgMTYCSA4gHID6DLioADgPal0JPuW4gAHogCMABgCcAFkLyZAVilS5GgEwAOAMwA2ADQgAnogDsa5YcOXD2y7oWH92-QF9PptFlz4RKQUJORQDOS0ECJEoQBufADWQWTkEWL8gsKiSBKI7kraqpZSuh7aUvpy6vqmFghlcoT6qgb6GsUlctrevhg4eATEqaHhkWAAThN8E4Q8ADb4AGYzALbDFOm5mSRCImKSCLKKynJqGlpSekZ1iIaltvaOzq4FvSB+A4GEMOhCYQBVWCTKIxQjxJJEX4AWTAGQEu2yB0QrikhG0Mn0+l0ulUWJkhlatXMKKqynkCn0lOscnsUnenwCQ1+-ygQJBk2mswWyzWPzA6Fh8Ky+1yh202iacip2I8hjU9m0twQ1lUzSk9hkHlUCjU2kMDP6TJSFEgETm0yWJHmsQgNtoAIACgARACCABUAKJsR0AJSoADEGAAZT3CxGi0CHGTKmSG-yDE2UCDmniW61EVA8CD4UaO9P26KUcHkBLJQiMxMbZOpguZ7O5sL5vhWm0ICEAY1zIgA2jIALrhvY5KPSKSWNWKWQeaUKSXj5UY3SEUrdKqExy08fxr5DYI18gWlsZwhZnOs5utsC0TkzOaLdArCbrSvffdmw9p48208Ni919tSz4Lthz7QdtgRYdkQaLRCF0SxtAUcdpQnRxdGVXV9EIKdFBkGRdDkSxFGKHdjWrD96GYdgqABd0hyRMVpCxdFaSxVQZEsIxx0sSxlUI9Eql1DUJQnRQ5FIqt91GGh0FBYsIXLfcZPoyM8gQdwsIQuRdEMSlSm1DxlR1Zc1AIhDdJwrwfA+I1JJGMIZJvKY7x5R8+SUjAVJHNSijVdwCIUdQriJfReJJBAihkZpLAUJDikigwJW8azyD4CA4DEV891SahPIgkVvMOdRDEINxqTUbFCW05UJywhQDFUNd2gxQxxOsrKk1GLZeEghjRyOfUoqkPEOOlQj2mVVrtDgjUEPYkpcT0CTvhZUZ2QmLzoLnZVdA1OCiS0TjcU0VRluy00U0-OtwTtIhUs9ZyNvyiNCrHHi4InPCFE4wktQwvbQtiloWu0jFLDOpMPyPK8bp-W8np6groI0d74PYmRvqMdilXC1wSvYxRYsIwbrHB9rbLfHLLuhk8SFuzbGKOYasLRr6fux5UZWi2KCclXarL6BNKYu2tv3rc88zrBn+pxTTGsqULp2xKRlRO0qCJnDdJXuMnBd3SHqa-K9pbUgBaRDlVNzjyTwjpOM+1QDXJoXzoPE3DlN+rLakVwba1dpvvYjQIeraS8sRl7oPcQgicQicihxGQfaM9jlFaWdSgJDR6Wd-X3cQZdY8DhPdCThRLY8dUOPsBwDE4wjdGSzwgA */
|
||||
/** @xstate-layout N4IgpgJg5mDOIC5QEMCuAXAFgZXc9YAdLAJZQB2kA8hgMTYCSA4gHID6DLioADgPal0JPuW4gAHogDsABgCshOQA4AzABY5ARgBMUgJxrtcgDQgAnok0zNSwkpkrNKvQDY3aqS6kBfb6bRYuPhEpBQk5FAM5LQQIkThAG58ANYhZORRYvyCwqJIEohyMlKE2mrFKo7aLq5SJuaILtq2mjZqrboy2s4+fiABOHgExOnhkdFgAE6TfJOEPAA2+ABmswC2IxSZ+dkkQiJikgiyCsrqWroGRqYWCHoq2oQyDpouXa1qGmq+-hiDwYQYOghBEAKqwKYxOKERIpIhAgCyYCyAj2uUOiF0mieTik3UqMhqSleN0QhgUOhUbzUKhk9jKch+-T+QWGQJBUHBkKmMzmixW60BYHQSJROQO+SOUmxVKUniUcjkUgVLjlpOOCqecj0LyKmmVeiUTIGrPhwo5SKwfAgsChlBh5CSqSFIuFmGt8B2qP2eVARxUeNKCo02k0cllTRc6qUalsCtU731DikvV+gSGZuBY0t7pttB5s3mS3Qq0mG0Rbo9YrREr9iHUMkIUnKHSklQVKnqtz0BkImjUbmeShcVmejL6Jozm0oECi8xmyxIC3iEGXtFBAAUACIAQQAKgBRNgbgBKVAAYgwADIH6s+jEIOV6Qgj1oxnTBkfq7qNlxyT4aC4saaPcVLGiyU6hDOc48AuS5EKgPAQPgYwbnBa6xPasLOpOAJQZAMHoQhSEoREaF8Iuy4ILCADGKEiAA2jIAC6d7opKiBPi+rRtB+-5fg0CDqM+YafG8+jWLI2jgemeHpAR5DzhR8GEIhyEcuRlFgPm0yFvyJaCrhwz4bOimwcpy6qSRGlEdRjp8HRPpMaxXrir6BSPvo3Fvu0zT8Zo6oDmoTb-p8cjaFcsjqDJ-zGfJpn0Mw7BUKCe5sbWHm6CU2jWKGnhaFSeLqv2jZKMOehOE49i0v+MWmtOYw0OgdrxPZzpQU16XuUcAC0-aPGGSgVQOTS4ko2jfjGhC0gB1WeDIBguHVkGjBETU6byRYCmW06da5NbdZiKalLl+p-k4XgTYJNhxuVNKGCB77fEy5DWnAYhGWkFDUBgXUPoqLiKOoxIqGVejNKoxXPCUMgjZ8zj6sORoThBclhBE2y8N67F1o+xIvhoejavqEX-gFgl6IGFWOC2bzWNqy0AuyYxcpMf0cQghjqn+JT6GJeJynIqrI2msWZhalY2uzuN9dojwpn+epDvqHjRkUL7KBDEljiLzKyXF32mUpWkwquRCvQeuls-t94c606s8c0A5C00UjqqDja5cUXQRXiDiMwb0FmURpuWQW1tY25D7242jsxn+bi6IFhh2K4qjKP+fsqAHX1B8bKkkGb0seR0gNx87idu4JtIwzoxTdELzbheOov1SZhEWcR6moURxdHCOgPhsBoNDRDKju84hCfHS4ZAXPqg59OCn58ufeFODQPD2DY-FUqGsASmdShgvKP67nClrwgQu2EPIPb2V4+CVNf7w5TOphs22en2LDVrb9Ns4w8j1coU8iafDKJVWQagDDqmOsOKBVhXDgweNJb+ppL59TcH2ZQw1E5jSurcYK8CHCODfE0B4vRfBAA */
|
||||
createMachine(
|
||||
{
|
||||
context: { me: undefined, getUserError: undefined, authError: undefined, updateProfileError: undefined },
|
||||
context: {
|
||||
me: undefined,
|
||||
getUserError: undefined,
|
||||
authError: undefined,
|
||||
updateProfileError: undefined,
|
||||
methods: undefined,
|
||||
getMethodsError: undefined,
|
||||
},
|
||||
tsTypes: {} as import("./authXService.typegen").Typegen0,
|
||||
schema: {
|
||||
context: {} as AuthContext,
|
||||
|
@ -31,6 +42,9 @@ export const authMachine =
|
|||
getMe: {
|
||||
data: Types.UserResponse
|
||||
}
|
||||
getMethods: {
|
||||
data: TypesGen.AuthMethods
|
||||
}
|
||||
signIn: {
|
||||
data: Types.LoginResponse
|
||||
}
|
||||
|
@ -81,6 +95,25 @@ export const authMachine =
|
|||
onError: [
|
||||
{
|
||||
actions: "assignGetUserError",
|
||||
target: "gettingMethods",
|
||||
},
|
||||
],
|
||||
},
|
||||
tags: "loading",
|
||||
},
|
||||
gettingMethods: {
|
||||
invoke: {
|
||||
src: "getMethods",
|
||||
id: "getMethods",
|
||||
onDone: [
|
||||
{
|
||||
actions: ["assignMethods", "clearGetMethodsError"],
|
||||
target: "signedOut",
|
||||
},
|
||||
],
|
||||
onError: [
|
||||
{
|
||||
actions: "assignGetMethodsError",
|
||||
target: "signedOut",
|
||||
},
|
||||
],
|
||||
|
@ -139,7 +172,7 @@ export const authMachine =
|
|||
onDone: [
|
||||
{
|
||||
actions: ["unassignMe", "clearAuthError"],
|
||||
target: "signedOut",
|
||||
target: "gettingMethods",
|
||||
},
|
||||
],
|
||||
onError: [
|
||||
|
@ -160,6 +193,7 @@ export const authMachine =
|
|||
},
|
||||
signOut: API.logout,
|
||||
getMe: API.getUser,
|
||||
getMethods: API.getAuthMethods,
|
||||
updateProfile: async (context, event) => {
|
||||
if (!context.me) {
|
||||
throw new Error("No current user found")
|
||||
|
@ -176,6 +210,16 @@ export const authMachine =
|
|||
...context,
|
||||
me: undefined,
|
||||
})),
|
||||
assignMethods: assign({
|
||||
methods: (_, event) => event.data,
|
||||
}),
|
||||
assignGetMethodsError: assign({
|
||||
getMethodsError: (_, event) => event.data,
|
||||
}),
|
||||
clearGetMethodsError: assign((context: AuthContext) => ({
|
||||
...context,
|
||||
getMethodsError: undefined,
|
||||
})),
|
||||
assignGetUserError: assign({
|
||||
getUserError: (_, event) => event.data,
|
||||
}),
|
||||
|
|
Loading…
Reference in New Issue