mirror of https://github.com/JSH32/Backpack.git
Put OAuth logic into seperate module.
This commit is contained in:
parent
a194b8be9e
commit
2040f18cd2
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
models::{auth::BasicAuthForm, AuthRequest, TokenResponse},
|
||||
services::{
|
||||
auth::{AuthService, OAuthProvider},
|
||||
auth::{oauth::OAuthProvider, AuthService},
|
||||
ServiceError, ToResponse,
|
||||
},
|
||||
};
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
|||
database::entity::{files, settings},
|
||||
models::{AppInfo, OAuthProviders},
|
||||
services::{
|
||||
auth::{AuthService, OAuthProvider},
|
||||
auth::{oauth::OAuthProvider, AuthService},
|
||||
user::UserService,
|
||||
ServiceError,
|
||||
},
|
||||
|
|
|
@ -1,13 +1,7 @@
|
|||
use actix_http::Uri;
|
||||
use anyhow::anyhow;
|
||||
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
|
||||
use chrono::Utc;
|
||||
use derive_more::Display;
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use oauth2::{
|
||||
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
||||
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
|
||||
};
|
||||
use rand::rngs::OsRng;
|
||||
use sea_orm::{ColumnTrait, Condition};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -16,145 +10,17 @@ use std::sync::{Arc, RwLock};
|
|||
use crate::{
|
||||
config::OAuthConfig,
|
||||
database::entity::{applications, users},
|
||||
models::AuthRequest,
|
||||
models::{AuthRequest, TokenResponse},
|
||||
};
|
||||
|
||||
use self::oauth::{OAuthClient, OAuthProvider};
|
||||
|
||||
use super::{
|
||||
application::ApplicationService, prelude::DataService, user::UserService, ServiceError,
|
||||
ServiceResult,
|
||||
};
|
||||
|
||||
/// All OAuth providers.
|
||||
#[derive(Debug, Display)]
|
||||
pub enum OAuthProvider {
|
||||
Google,
|
||||
Github,
|
||||
Discord,
|
||||
}
|
||||
|
||||
struct EmailRequest {
|
||||
request_endpoint: RequestEndpoint,
|
||||
/// Email retriever using result data.
|
||||
email_retrieve: fn(serde_json::Value) -> Option<String>,
|
||||
}
|
||||
|
||||
/// User request endpoint configuration for getting email.
|
||||
enum RequestEndpoint {
|
||||
/// Format URL with token as argument.
|
||||
FormatUrl(fn(&str) -> String),
|
||||
/// Automatically use token in `Authorization` header.
|
||||
Bearer(String),
|
||||
}
|
||||
|
||||
struct OAuthClient {
|
||||
http_client: reqwest::Client,
|
||||
client: BasicClient,
|
||||
scopes: Vec<Scope>,
|
||||
email_request: EmailRequest,
|
||||
}
|
||||
|
||||
impl OAuthClient {
|
||||
pub fn new(
|
||||
oauth_config: OAuthConfig,
|
||||
auth_url: &str,
|
||||
token_url: &str,
|
||||
redirect_url: &str,
|
||||
scopes: &[&str],
|
||||
email_request: EmailRequest,
|
||||
) -> Self {
|
||||
let auth_url = AuthUrl::new(auth_url.to_string()).unwrap();
|
||||
let token_url = TokenUrl::new(token_url.to_string()).unwrap();
|
||||
|
||||
Self {
|
||||
http_client: reqwest::Client::builder()
|
||||
.user_agent("Backpack")
|
||||
.build()
|
||||
.unwrap(),
|
||||
client: BasicClient::new(
|
||||
ClientId::new(oauth_config.client_id),
|
||||
Some(ClientSecret::new(oauth_config.client_secret)),
|
||||
auth_url,
|
||||
Some(token_url),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url.into()).expect("Invalid redirect URL")),
|
||||
scopes: scopes
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|f| Scope::new(f.to_string()))
|
||||
.collect(),
|
||||
email_request,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initiate an oauth login.
|
||||
/// Start the login session by redirecting the user to the provider URL.
|
||||
fn login(&self) -> ServiceResult<oauth2::url::Url> {
|
||||
// TODO: PKCE verification.
|
||||
|
||||
// Generate the authorization URL to which we'll redirect the user.
|
||||
let (authorize_url, _csrf_state) = self
|
||||
.client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scopes(self.scopes.clone())
|
||||
.url();
|
||||
|
||||
Ok(authorize_url)
|
||||
}
|
||||
|
||||
/// Use auth params provided by the provider to get a JWT token.
|
||||
/// Returns user email.
|
||||
async fn auth(&self, oauth_request: &AuthRequest) -> ServiceResult<String> {
|
||||
let code = AuthorizationCode::new(oauth_request.code.clone());
|
||||
|
||||
// Exchange the code with a token.
|
||||
let token = match self
|
||||
.client
|
||||
.exchange_code(code)
|
||||
.request_async(async_http_client)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(e) => return Err(ServiceError::ServerError(e.into())),
|
||||
};
|
||||
|
||||
let response = match &self.email_request.request_endpoint {
|
||||
RequestEndpoint::FormatUrl(formatter) => self
|
||||
.http_client
|
||||
.get(formatter(token.access_token().secret())),
|
||||
RequestEndpoint::Bearer(url) => self
|
||||
.http_client
|
||||
.get(url)
|
||||
.bearer_auth(token.access_token().secret()),
|
||||
}
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ServiceError::ServerError(e.into()))?
|
||||
.json::<serde_json::Value>()
|
||||
.await
|
||||
.map_err(|e| ServiceError::ServerError(e.into()))?;
|
||||
|
||||
match (self.email_request.email_retrieve)(response) {
|
||||
Some(v) => Ok(v),
|
||||
None => Err(ServiceError::ServerError(anyhow!(
|
||||
"OAuth provider was misconfigured."
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract any string field from the root of a [`serde_json::Value`].
|
||||
/// This returns [`None`] if this fails.
|
||||
fn root_json_str_parse(object: serde_json::Value, field: &str) -> Option<String> {
|
||||
let object = serde_json::from_value::<serde_json::Value>(object);
|
||||
|
||||
if let Ok(serde_json::Value::Object(object)) = object {
|
||||
if let Some(serde_json::Value::String(str)) = object.get(field) {
|
||||
return Some(str.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
pub mod oauth;
|
||||
|
||||
/// Handles authentication and validation.
|
||||
pub struct AuthService {
|
||||
|
@ -190,68 +56,25 @@ impl AuthService {
|
|||
client_url: client_url.into(),
|
||||
google_oauth_client: match google_oauth {
|
||||
Some(config) => Some(OAuthClient::new(
|
||||
OAuthProvider::Google,
|
||||
config,
|
||||
"https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"https://www.googleapis.com/oauth2/v3/token",
|
||||
&format!("{}/api/auth/google/auth", api_url),
|
||||
&[&"https://www.googleapis.com/auth/userinfo.email"],
|
||||
EmailRequest {
|
||||
request_endpoint: RequestEndpoint::FormatUrl(|token| {
|
||||
format!(
|
||||
"https://www.googleapis.com/oauth2/v1/userinfo?access_token={}",
|
||||
token
|
||||
)
|
||||
}),
|
||||
email_retrieve: |obj| root_json_str_parse(obj, "email"),
|
||||
},
|
||||
)),
|
||||
None => None,
|
||||
},
|
||||
github_oauth_client: match github_oauth {
|
||||
Some(config) => Some(OAuthClient::new(
|
||||
OAuthProvider::Github,
|
||||
config,
|
||||
"https://github.com/login/oauth/authorize",
|
||||
"https://github.com/login/oauth/access_token",
|
||||
&format!("{}/api/auth/github/auth", api_url),
|
||||
&[&"user"],
|
||||
EmailRequest {
|
||||
request_endpoint: RequestEndpoint::Bearer(
|
||||
"https://api.github.com/user/emails".into(),
|
||||
),
|
||||
email_retrieve: |res| {
|
||||
#[derive(Deserialize)]
|
||||
struct EmailResponse {
|
||||
primary: bool,
|
||||
email: String,
|
||||
}
|
||||
|
||||
if let Ok(emails) = serde_json::from_value::<Vec<EmailResponse>>(res) {
|
||||
for email in emails {
|
||||
if email.primary {
|
||||
return Some(email.email);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
},
|
||||
},
|
||||
)),
|
||||
None => None,
|
||||
},
|
||||
discord_oauth_client: match discord_oauth {
|
||||
Some(config) => Some(OAuthClient::new(
|
||||
OAuthProvider::Discord,
|
||||
config,
|
||||
"https://discord.com/oauth2/authorize",
|
||||
"https://discord.com/api/oauth2/token",
|
||||
&format!("{}/api/auth/discord/auth", api_url),
|
||||
&[&"identify", "email"],
|
||||
EmailRequest {
|
||||
request_endpoint: RequestEndpoint::Bearer(
|
||||
"https://discord.com/api/v10/users/@me".into(),
|
||||
),
|
||||
email_retrieve: |obj| root_json_str_parse(obj, "email"),
|
||||
},
|
||||
)),
|
||||
None => None,
|
||||
},
|
||||
|
@ -264,11 +87,7 @@ impl AuthService {
|
|||
/// * `password` - User password
|
||||
///
|
||||
/// Returns JWT token response.
|
||||
pub async fn password_auth(
|
||||
&self,
|
||||
auth: &str,
|
||||
password: &str,
|
||||
) -> ServiceResult<crate::models::TokenResponse> {
|
||||
pub async fn password_auth(&self, auth: &str, password: &str) -> ServiceResult<TokenResponse> {
|
||||
let user = self.user_service.get_by_identifier(auth).await?;
|
||||
|
||||
validate_password(&user.password, password)?;
|
||||
|
@ -350,7 +169,7 @@ impl AuthService {
|
|||
&self,
|
||||
user_id: &str,
|
||||
application_id: Option<String>,
|
||||
) -> ServiceResult<crate::models::TokenResponse> {
|
||||
) -> ServiceResult<TokenResponse> {
|
||||
let expire_time = (Utc::now() + chrono::Duration::weeks(1)).timestamp();
|
||||
|
||||
let claims = JwtClaims {
|
||||
|
@ -372,18 +191,18 @@ impl AuthService {
|
|||
)
|
||||
.map_err(|e| ServiceError::ServerError(e.into()))?;
|
||||
|
||||
Ok(crate::models::TokenResponse { token: jwt })
|
||||
Ok(TokenResponse { token: jwt })
|
||||
}
|
||||
|
||||
/// Check if the OAuth provider is enabled.
|
||||
pub fn oauth_enabled(&self, provider_type: OAuthProvider) -> bool {
|
||||
self.get_client(provider_type).is_ok()
|
||||
self.get_oauth_client(provider_type).is_ok()
|
||||
}
|
||||
|
||||
/// Initiate an oauth login.
|
||||
/// Start the login session by redirecting the user to the provider URL.
|
||||
pub fn oauth_login(&self, provider_type: OAuthProvider) -> ServiceResult<oauth2::url::Url> {
|
||||
self.get_client(provider_type)?.login()
|
||||
self.get_oauth_client(provider_type)?.login()
|
||||
}
|
||||
|
||||
/// Use auth params provided by the provider to get a JWT token.
|
||||
|
@ -392,8 +211,11 @@ impl AuthService {
|
|||
&self,
|
||||
provider_type: OAuthProvider,
|
||||
auth_request: &AuthRequest,
|
||||
) -> ServiceResult<crate::models::TokenResponse> {
|
||||
let email = self.get_client(provider_type)?.auth(auth_request).await?;
|
||||
) -> ServiceResult<TokenResponse> {
|
||||
let email = self
|
||||
.get_oauth_client(provider_type)?
|
||||
.get_email(auth_request)
|
||||
.await?;
|
||||
|
||||
let user = self
|
||||
.user_service
|
||||
|
@ -403,7 +225,7 @@ impl AuthService {
|
|||
self.new_jwt(&user.id, None)
|
||||
}
|
||||
|
||||
fn get_client(&self, provider_type: OAuthProvider) -> ServiceResult<&OAuthClient> {
|
||||
fn get_oauth_client(&self, provider_type: OAuthProvider) -> ServiceResult<&OAuthClient> {
|
||||
let provider = match provider_type {
|
||||
OAuthProvider::Google => &self.google_oauth_client,
|
||||
OAuthProvider::Github => &self.github_oauth_client,
|
|
@ -0,0 +1,212 @@
|
|||
use derive_more::Display;
|
||||
use oauth2::basic::BasicClient;
|
||||
use oauth2::reqwest::async_http_client;
|
||||
use oauth2::{
|
||||
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope,
|
||||
TokenResponse, TokenUrl,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::config::OAuthConfig;
|
||||
use crate::models::AuthRequest;
|
||||
use crate::services::{ServiceError, ServiceResult};
|
||||
|
||||
/// All OAuth providers.
|
||||
#[derive(Debug, Display)]
|
||||
pub enum OAuthProvider {
|
||||
Google,
|
||||
Github,
|
||||
Discord,
|
||||
}
|
||||
|
||||
struct EmailRequest {
|
||||
request_endpoint: RequestEndpoint,
|
||||
/// Email retriever using result data.
|
||||
email_retrieve: fn(serde_json::Value) -> Option<String>,
|
||||
}
|
||||
|
||||
/// User request endpoint configuration for getting email.
|
||||
enum RequestEndpoint {
|
||||
/// Format URL with token as argument.
|
||||
FormatUrl(fn(&str) -> String),
|
||||
/// Automatically use token in `Authorization` header.
|
||||
Bearer(String),
|
||||
}
|
||||
|
||||
pub struct OAuthClient {
|
||||
http_client: reqwest::Client,
|
||||
client: BasicClient,
|
||||
scopes: Vec<Scope>,
|
||||
email_request: EmailRequest,
|
||||
}
|
||||
|
||||
impl OAuthClient {
|
||||
/// Create new client for the provider.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - OAuth config.
|
||||
/// * `redirect_url` - Redirect URL.
|
||||
pub fn new(provider: OAuthProvider, config: OAuthConfig, redirect_url: &str) -> Self {
|
||||
match provider {
|
||||
OAuthProvider::Google => Self::new_client(
|
||||
config,
|
||||
"https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"https://www.googleapis.com/oauth2/v3/token",
|
||||
redirect_url,
|
||||
&[&"https://www.googleapis.com/auth/userinfo.email"],
|
||||
EmailRequest {
|
||||
request_endpoint: RequestEndpoint::FormatUrl(|token| {
|
||||
format!(
|
||||
"https://www.googleapis.com/oauth2/v1/userinfo?access_token={}",
|
||||
token
|
||||
)
|
||||
}),
|
||||
email_retrieve: |obj| root_json_str_parse(obj, "email"),
|
||||
},
|
||||
),
|
||||
OAuthProvider::Github => Self::new_client(
|
||||
config,
|
||||
"https://github.com/login/oauth/authorize",
|
||||
"https://github.com/login/oauth/access_token",
|
||||
redirect_url,
|
||||
&[&"user"],
|
||||
EmailRequest {
|
||||
request_endpoint: RequestEndpoint::Bearer(
|
||||
"https://api.github.com/user/emails".into(),
|
||||
),
|
||||
email_retrieve: |res| {
|
||||
#[derive(Deserialize)]
|
||||
struct EmailResponse {
|
||||
primary: bool,
|
||||
email: String,
|
||||
}
|
||||
|
||||
if let Ok(emails) = serde_json::from_value::<Vec<EmailResponse>>(res) {
|
||||
for email in emails {
|
||||
if email.primary {
|
||||
return Some(email.email);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
},
|
||||
},
|
||||
),
|
||||
OAuthProvider::Discord => Self::new_client(
|
||||
config,
|
||||
"https://discord.com/oauth2/authorize",
|
||||
"https://discord.com/api/oauth2/token",
|
||||
redirect_url,
|
||||
&[&"identify", "email"],
|
||||
EmailRequest {
|
||||
request_endpoint: RequestEndpoint::Bearer(
|
||||
"https://discord.com/api/v10/users/@me".into(),
|
||||
),
|
||||
email_retrieve: |obj| root_json_str_parse(obj, "email"),
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_client(
|
||||
oauth_config: OAuthConfig,
|
||||
auth_url: &str,
|
||||
token_url: &str,
|
||||
redirect_url: &str,
|
||||
scopes: &[&str],
|
||||
email_request: EmailRequest,
|
||||
) -> Self {
|
||||
let auth_url = AuthUrl::new(auth_url.to_string()).unwrap();
|
||||
let token_url = TokenUrl::new(token_url.to_string()).unwrap();
|
||||
|
||||
Self {
|
||||
http_client: reqwest::Client::builder()
|
||||
.user_agent("Backpack")
|
||||
.build()
|
||||
.unwrap(),
|
||||
client: BasicClient::new(
|
||||
ClientId::new(oauth_config.client_id),
|
||||
Some(ClientSecret::new(oauth_config.client_secret)),
|
||||
auth_url,
|
||||
Some(token_url),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url.into()).expect("Invalid redirect URL")),
|
||||
scopes: scopes
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|f| Scope::new(f.to_string()))
|
||||
.collect(),
|
||||
email_request,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initiate an oauth login with provided scopes.
|
||||
/// Start the login session by redirecting the user to the provider URL.
|
||||
pub fn login(&self) -> ServiceResult<oauth2::url::Url> {
|
||||
// TODO: PKCE verification.
|
||||
|
||||
// Generate the authorization URL to which we'll redirect the user.
|
||||
let (authorize_url, _csrf_state) = self
|
||||
.client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scopes(self.scopes.clone())
|
||||
.url();
|
||||
|
||||
Ok(authorize_url)
|
||||
}
|
||||
|
||||
/// Use auth params provided by the provider to get the email.
|
||||
pub async fn get_email(&self, oauth_request: &AuthRequest) -> ServiceResult<String> {
|
||||
let code = AuthorizationCode::new(oauth_request.code.clone());
|
||||
|
||||
// Exchange the code with a token.
|
||||
let token = match self
|
||||
.client
|
||||
.exchange_code(code)
|
||||
.request_async(async_http_client)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(e) => return Err(ServiceError::ServerError(e.into())),
|
||||
};
|
||||
|
||||
let response = match &self.email_request.request_endpoint {
|
||||
RequestEndpoint::FormatUrl(formatter) => self
|
||||
.http_client
|
||||
.get(formatter(token.access_token().secret())),
|
||||
RequestEndpoint::Bearer(url) => self
|
||||
.http_client
|
||||
.get(url)
|
||||
.bearer_auth(token.access_token().secret()),
|
||||
}
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ServiceError::ServerError(e.into()))?
|
||||
.json::<serde_json::Value>()
|
||||
.await
|
||||
.map_err(|e| ServiceError::ServerError(e.into()))?;
|
||||
|
||||
match (self.email_request.email_retrieve)(response) {
|
||||
Some(v) => Ok(v),
|
||||
None => Err(ServiceError::ServerError(anyhow::anyhow!(
|
||||
"OAuth provider was misconfigured."
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract any string field from the root of a [`serde_json::Value`].
|
||||
/// This returns [`None`] if this fails.
|
||||
fn root_json_str_parse(object: serde_json::Value, field: &str) -> Option<String> {
|
||||
let object = serde_json::from_value::<serde_json::Value>(object);
|
||||
|
||||
if let Ok(serde_json::Value::Object(object)) = object {
|
||||
if let Some(serde_json::Value::String(str)) = object.get(field) {
|
||||
return Some(str.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
Loading…
Reference in New Issue