Put OAuth logic into seperate module.

This commit is contained in:
JSH32 2022-09-19 11:23:34 -05:00
parent a194b8be9e
commit 2040f18cd2
4 changed files with 232 additions and 198 deletions

View File

@ -1,7 +1,7 @@
use crate::{
models::{auth::BasicAuthForm, AuthRequest, TokenResponse},
services::{
auth::{AuthService, OAuthProvider},
auth::{oauth::OAuthProvider, AuthService},
ServiceError, ToResponse,
},
};

View File

@ -2,7 +2,7 @@ use crate::{
database::entity::{files, settings},
models::{AppInfo, OAuthProviders},
services::{
auth::{AuthService, OAuthProvider},
auth::{oauth::OAuthProvider, AuthService},
user::UserService,
ServiceError,
},

View File

@ -1,13 +1,7 @@
use actix_http::Uri;
use anyhow::anyhow;
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use chrono::Utc;
use derive_more::Display;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use rand::rngs::OsRng;
use sea_orm::{ColumnTrait, Condition};
use serde::{Deserialize, Serialize};
@ -16,145 +10,17 @@ use std::sync::{Arc, RwLock};
use crate::{
config::OAuthConfig,
database::entity::{applications, users},
models::AuthRequest,
models::{AuthRequest, TokenResponse},
};
use self::oauth::{OAuthClient, OAuthProvider};
use super::{
application::ApplicationService, prelude::DataService, user::UserService, ServiceError,
ServiceResult,
};
/// All OAuth providers.
#[derive(Debug, Display)]
pub enum OAuthProvider {
Google,
Github,
Discord,
}
struct EmailRequest {
request_endpoint: RequestEndpoint,
/// Email retriever using result data.
email_retrieve: fn(serde_json::Value) -> Option<String>,
}
/// User request endpoint configuration for getting email.
enum RequestEndpoint {
/// Format URL with token as argument.
FormatUrl(fn(&str) -> String),
/// Automatically use token in `Authorization` header.
Bearer(String),
}
struct OAuthClient {
http_client: reqwest::Client,
client: BasicClient,
scopes: Vec<Scope>,
email_request: EmailRequest,
}
impl OAuthClient {
pub fn new(
oauth_config: OAuthConfig,
auth_url: &str,
token_url: &str,
redirect_url: &str,
scopes: &[&str],
email_request: EmailRequest,
) -> Self {
let auth_url = AuthUrl::new(auth_url.to_string()).unwrap();
let token_url = TokenUrl::new(token_url.to_string()).unwrap();
Self {
http_client: reqwest::Client::builder()
.user_agent("Backpack")
.build()
.unwrap(),
client: BasicClient::new(
ClientId::new(oauth_config.client_id),
Some(ClientSecret::new(oauth_config.client_secret)),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url.into()).expect("Invalid redirect URL")),
scopes: scopes
.to_vec()
.iter()
.map(|f| Scope::new(f.to_string()))
.collect(),
email_request,
}
}
/// Initiate an oauth login.
/// Start the login session by redirecting the user to the provider URL.
fn login(&self) -> ServiceResult<oauth2::url::Url> {
// TODO: PKCE verification.
// Generate the authorization URL to which we'll redirect the user.
let (authorize_url, _csrf_state) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(self.scopes.clone())
.url();
Ok(authorize_url)
}
/// Use auth params provided by the provider to get a JWT token.
/// Returns user email.
async fn auth(&self, oauth_request: &AuthRequest) -> ServiceResult<String> {
let code = AuthorizationCode::new(oauth_request.code.clone());
// Exchange the code with a token.
let token = match self
.client
.exchange_code(code)
.request_async(async_http_client)
.await
{
Ok(v) => v,
Err(e) => return Err(ServiceError::ServerError(e.into())),
};
let response = match &self.email_request.request_endpoint {
RequestEndpoint::FormatUrl(formatter) => self
.http_client
.get(formatter(token.access_token().secret())),
RequestEndpoint::Bearer(url) => self
.http_client
.get(url)
.bearer_auth(token.access_token().secret()),
}
.send()
.await
.map_err(|e| ServiceError::ServerError(e.into()))?
.json::<serde_json::Value>()
.await
.map_err(|e| ServiceError::ServerError(e.into()))?;
match (self.email_request.email_retrieve)(response) {
Some(v) => Ok(v),
None => Err(ServiceError::ServerError(anyhow!(
"OAuth provider was misconfigured."
))),
}
}
}
/// Extract any string field from the root of a [`serde_json::Value`].
/// This returns [`None`] if this fails.
fn root_json_str_parse(object: serde_json::Value, field: &str) -> Option<String> {
let object = serde_json::from_value::<serde_json::Value>(object);
if let Ok(serde_json::Value::Object(object)) = object {
if let Some(serde_json::Value::String(str)) = object.get(field) {
return Some(str.to_owned());
}
}
None
}
pub mod oauth;
/// Handles authentication and validation.
pub struct AuthService {
@ -190,68 +56,25 @@ impl AuthService {
client_url: client_url.into(),
google_oauth_client: match google_oauth {
Some(config) => Some(OAuthClient::new(
OAuthProvider::Google,
config,
"https://accounts.google.com/o/oauth2/v2/auth",
"https://www.googleapis.com/oauth2/v3/token",
&format!("{}/api/auth/google/auth", api_url),
&[&"https://www.googleapis.com/auth/userinfo.email"],
EmailRequest {
request_endpoint: RequestEndpoint::FormatUrl(|token| {
format!(
"https://www.googleapis.com/oauth2/v1/userinfo?access_token={}",
token
)
}),
email_retrieve: |obj| root_json_str_parse(obj, "email"),
},
)),
None => None,
},
github_oauth_client: match github_oauth {
Some(config) => Some(OAuthClient::new(
OAuthProvider::Github,
config,
"https://github.com/login/oauth/authorize",
"https://github.com/login/oauth/access_token",
&format!("{}/api/auth/github/auth", api_url),
&[&"user"],
EmailRequest {
request_endpoint: RequestEndpoint::Bearer(
"https://api.github.com/user/emails".into(),
),
email_retrieve: |res| {
#[derive(Deserialize)]
struct EmailResponse {
primary: bool,
email: String,
}
if let Ok(emails) = serde_json::from_value::<Vec<EmailResponse>>(res) {
for email in emails {
if email.primary {
return Some(email.email);
}
}
}
None
},
},
)),
None => None,
},
discord_oauth_client: match discord_oauth {
Some(config) => Some(OAuthClient::new(
OAuthProvider::Discord,
config,
"https://discord.com/oauth2/authorize",
"https://discord.com/api/oauth2/token",
&format!("{}/api/auth/discord/auth", api_url),
&[&"identify", "email"],
EmailRequest {
request_endpoint: RequestEndpoint::Bearer(
"https://discord.com/api/v10/users/@me".into(),
),
email_retrieve: |obj| root_json_str_parse(obj, "email"),
},
)),
None => None,
},
@ -264,11 +87,7 @@ impl AuthService {
/// * `password` - User password
///
/// Returns JWT token response.
pub async fn password_auth(
&self,
auth: &str,
password: &str,
) -> ServiceResult<crate::models::TokenResponse> {
pub async fn password_auth(&self, auth: &str, password: &str) -> ServiceResult<TokenResponse> {
let user = self.user_service.get_by_identifier(auth).await?;
validate_password(&user.password, password)?;
@ -350,7 +169,7 @@ impl AuthService {
&self,
user_id: &str,
application_id: Option<String>,
) -> ServiceResult<crate::models::TokenResponse> {
) -> ServiceResult<TokenResponse> {
let expire_time = (Utc::now() + chrono::Duration::weeks(1)).timestamp();
let claims = JwtClaims {
@ -372,18 +191,18 @@ impl AuthService {
)
.map_err(|e| ServiceError::ServerError(e.into()))?;
Ok(crate::models::TokenResponse { token: jwt })
Ok(TokenResponse { token: jwt })
}
/// Check if the OAuth provider is enabled.
pub fn oauth_enabled(&self, provider_type: OAuthProvider) -> bool {
self.get_client(provider_type).is_ok()
self.get_oauth_client(provider_type).is_ok()
}
/// Initiate an oauth login.
/// Start the login session by redirecting the user to the provider URL.
pub fn oauth_login(&self, provider_type: OAuthProvider) -> ServiceResult<oauth2::url::Url> {
self.get_client(provider_type)?.login()
self.get_oauth_client(provider_type)?.login()
}
/// Use auth params provided by the provider to get a JWT token.
@ -392,8 +211,11 @@ impl AuthService {
&self,
provider_type: OAuthProvider,
auth_request: &AuthRequest,
) -> ServiceResult<crate::models::TokenResponse> {
let email = self.get_client(provider_type)?.auth(auth_request).await?;
) -> ServiceResult<TokenResponse> {
let email = self
.get_oauth_client(provider_type)?
.get_email(auth_request)
.await?;
let user = self
.user_service
@ -403,7 +225,7 @@ impl AuthService {
self.new_jwt(&user.id, None)
}
fn get_client(&self, provider_type: OAuthProvider) -> ServiceResult<&OAuthClient> {
fn get_oauth_client(&self, provider_type: OAuthProvider) -> ServiceResult<&OAuthClient> {
let provider = match provider_type {
OAuthProvider::Google => &self.google_oauth_client,
OAuthProvider::Github => &self.github_oauth_client,

212
src/services/auth/oauth.rs Normal file
View File

@ -0,0 +1,212 @@
use derive_more::Display;
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope,
TokenResponse, TokenUrl,
};
use serde::Deserialize;
use crate::config::OAuthConfig;
use crate::models::AuthRequest;
use crate::services::{ServiceError, ServiceResult};
/// All OAuth providers.
#[derive(Debug, Display)]
pub enum OAuthProvider {
Google,
Github,
Discord,
}
struct EmailRequest {
request_endpoint: RequestEndpoint,
/// Email retriever using result data.
email_retrieve: fn(serde_json::Value) -> Option<String>,
}
/// User request endpoint configuration for getting email.
enum RequestEndpoint {
/// Format URL with token as argument.
FormatUrl(fn(&str) -> String),
/// Automatically use token in `Authorization` header.
Bearer(String),
}
pub struct OAuthClient {
http_client: reqwest::Client,
client: BasicClient,
scopes: Vec<Scope>,
email_request: EmailRequest,
}
impl OAuthClient {
/// Create new client for the provider.
///
/// # Arguments
///
/// * `config` - OAuth config.
/// * `redirect_url` - Redirect URL.
pub fn new(provider: OAuthProvider, config: OAuthConfig, redirect_url: &str) -> Self {
match provider {
OAuthProvider::Google => Self::new_client(
config,
"https://accounts.google.com/o/oauth2/v2/auth",
"https://www.googleapis.com/oauth2/v3/token",
redirect_url,
&[&"https://www.googleapis.com/auth/userinfo.email"],
EmailRequest {
request_endpoint: RequestEndpoint::FormatUrl(|token| {
format!(
"https://www.googleapis.com/oauth2/v1/userinfo?access_token={}",
token
)
}),
email_retrieve: |obj| root_json_str_parse(obj, "email"),
},
),
OAuthProvider::Github => Self::new_client(
config,
"https://github.com/login/oauth/authorize",
"https://github.com/login/oauth/access_token",
redirect_url,
&[&"user"],
EmailRequest {
request_endpoint: RequestEndpoint::Bearer(
"https://api.github.com/user/emails".into(),
),
email_retrieve: |res| {
#[derive(Deserialize)]
struct EmailResponse {
primary: bool,
email: String,
}
if let Ok(emails) = serde_json::from_value::<Vec<EmailResponse>>(res) {
for email in emails {
if email.primary {
return Some(email.email);
}
}
}
None
},
},
),
OAuthProvider::Discord => Self::new_client(
config,
"https://discord.com/oauth2/authorize",
"https://discord.com/api/oauth2/token",
redirect_url,
&[&"identify", "email"],
EmailRequest {
request_endpoint: RequestEndpoint::Bearer(
"https://discord.com/api/v10/users/@me".into(),
),
email_retrieve: |obj| root_json_str_parse(obj, "email"),
},
),
}
}
fn new_client(
oauth_config: OAuthConfig,
auth_url: &str,
token_url: &str,
redirect_url: &str,
scopes: &[&str],
email_request: EmailRequest,
) -> Self {
let auth_url = AuthUrl::new(auth_url.to_string()).unwrap();
let token_url = TokenUrl::new(token_url.to_string()).unwrap();
Self {
http_client: reqwest::Client::builder()
.user_agent("Backpack")
.build()
.unwrap(),
client: BasicClient::new(
ClientId::new(oauth_config.client_id),
Some(ClientSecret::new(oauth_config.client_secret)),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url.into()).expect("Invalid redirect URL")),
scopes: scopes
.to_vec()
.iter()
.map(|f| Scope::new(f.to_string()))
.collect(),
email_request,
}
}
/// Initiate an oauth login with provided scopes.
/// Start the login session by redirecting the user to the provider URL.
pub fn login(&self) -> ServiceResult<oauth2::url::Url> {
// TODO: PKCE verification.
// Generate the authorization URL to which we'll redirect the user.
let (authorize_url, _csrf_state) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(self.scopes.clone())
.url();
Ok(authorize_url)
}
/// Use auth params provided by the provider to get the email.
pub async fn get_email(&self, oauth_request: &AuthRequest) -> ServiceResult<String> {
let code = AuthorizationCode::new(oauth_request.code.clone());
// Exchange the code with a token.
let token = match self
.client
.exchange_code(code)
.request_async(async_http_client)
.await
{
Ok(v) => v,
Err(e) => return Err(ServiceError::ServerError(e.into())),
};
let response = match &self.email_request.request_endpoint {
RequestEndpoint::FormatUrl(formatter) => self
.http_client
.get(formatter(token.access_token().secret())),
RequestEndpoint::Bearer(url) => self
.http_client
.get(url)
.bearer_auth(token.access_token().secret()),
}
.send()
.await
.map_err(|e| ServiceError::ServerError(e.into()))?
.json::<serde_json::Value>()
.await
.map_err(|e| ServiceError::ServerError(e.into()))?;
match (self.email_request.email_retrieve)(response) {
Some(v) => Ok(v),
None => Err(ServiceError::ServerError(anyhow::anyhow!(
"OAuth provider was misconfigured."
))),
}
}
}
/// Extract any string field from the root of a [`serde_json::Value`].
/// This returns [`None`] if this fails.
fn root_json_str_parse(object: serde_json::Value, field: &str) -> Option<String> {
let object = serde_json::from_value::<serde_json::Value>(object);
if let Ok(serde_json::Value::Object(object)) = object {
if let Some(serde_json::Value::String(str)) = object.get(field) {
return Some(str.to_owned());
}
}
None
}