OAuth functionality

This commit is contained in:
JSH32 2022-09-19 00:35:39 -05:00
parent 67b9fe57f1
commit 1e909ece07
8 changed files with 541 additions and 35 deletions

102
Cargo.lock generated
View File

@ -558,9 +558,11 @@ dependencies = [
"migration",
"nanoid",
"num_cpus",
"oauth2",
"paste",
"rand",
"regex",
"reqwest",
"rusoto_core",
"rusoto_s3",
"sea-orm",
@ -1621,6 +1623,19 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac"
dependencies = [
"http",
"hyper",
"rustls",
"tokio",
"tokio-rustls",
]
[[package]]
name = "hyper-tls"
version = "0.5.0"
@ -1714,6 +1729,12 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "ipnet"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b"
[[package]]
name = "itertools"
version = "0.10.3"
@ -2150,6 +2171,26 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "oauth2"
version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d62c436394991641b970a92e23e8eeb4eb9bca74af4f5badc53bcd568daadbd"
dependencies = [
"base64",
"chrono",
"getrandom",
"http",
"rand",
"reqwest",
"serde",
"serde_json",
"serde_path_to_error",
"sha2 0.10.2",
"thiserror",
"url",
]
[[package]]
name = "once_cell"
version = "1.13.0"
@ -2626,6 +2667,48 @@ dependencies = [
"winapi",
]
[[package]]
name = "reqwest"
version = "0.11.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b75aa69a3f06bbcc66ede33af2af253c6f7a86b1ca0033f60c580a27074fbf92"
dependencies = [
"base64",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"hyper-rustls",
"hyper-tls",
"ipnet",
"js-sys",
"lazy_static",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"rustls",
"rustls-pemfile",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"webpki-roots",
"winreg",
]
[[package]]
name = "ring"
version = "0.16.20"
@ -3058,6 +3141,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "184c643044780f7ceb59104cef98a5a6f12cb2288a7bc701ab93a362b49fd47d"
dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@ -3711,6 +3803,7 @@ dependencies = [
"idna",
"matches",
"percent-encoding",
"serde",
]
[[package]]
@ -3999,6 +4092,15 @@ version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680"
[[package]]
name = "winreg"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
dependencies = [
"winapi",
]
[[package]]
name = "xml-rs"
version = "0.8.4"

View File

@ -62,3 +62,5 @@ actix-multipart-extract = "0.1.4"
num_cpus = "1.0"
heck = "0.4.0"
macro_rules_attribute = "0.1.2"
oauth2 = "4.2.3"
reqwest = { version = "0.11.11", features = [ "json" ] }

View File

@ -34,37 +34,57 @@ import GithubSVG from "assets/icons/github.svg"
import styles from "styles/login.module.scss"
import { BasicAuthForm } from "@/client"
import api from "helpers/api"
import getConfig from "next/config"
const Login: NextPage = () => {
const [postLoginUnverifiedEmail, setPostLoginUnverifiedEmail] = React.useState<string | null>(null)
const router = useRouter()
const { token, fail } = router.query
const { publicRuntimeConfig } = getConfig()
const { register, handleSubmit } = useForm()
const toast = useToast()
React.useEffect(() => {
if (store.userData != null)
if (fail) {
toast({
title: "Authentication Error",
description: fail,
status: "error",
duration: 5000,
isClosable: true
})
} else if (token != null) {
tokenLogin(token as string)
} else if (store.userData != null) {
router.replace("/user/uploads")
}
}, [])
const tokenLogin = React.useCallback((token: string) => {
localStorage.setItem("token", token)
api.user.info().then(userInfo => {
store.setUserInfo(userInfo)
userInfo.verified
? router.replace("/user/uploads")
: setPostLoginUnverifiedEmail(userInfo.email)
toast({
title: "Logged in",
description: `Welcome ${userInfo.username}`,
status: "success",
duration: 5000,
isClosable: true
})
})
}, [])
const formSubmit = (data: BasicAuthForm) => {
api.authentication.basic(data)
.then(tokenRes => {
localStorage.setItem("token", tokenRes.token)
api.user.info().then(userInfo => {
store.setUserInfo(userInfo)
userInfo.verified
? router.replace("/user/uploads")
: setPostLoginUnverifiedEmail(userInfo.email)
toast({
title: "Logged in",
description: `Welcome ${userInfo.username}`,
status: "success",
duration: 5000,
isClosable: true
})
})
tokenLogin(tokenRes.token)
})
.catch(error => {
toast({
@ -77,6 +97,10 @@ const Login: NextPage = () => {
})
}
const oauthSignIn = React.useCallback((provider: string) => {
window.location.replace(`${publicRuntimeConfig.apiRoot}/api/auth/${provider}/login`)
}, [])
if (postLoginUnverifiedEmail != null)
return <VerificationMessage email={postLoginUnverifiedEmail} />
@ -99,12 +123,12 @@ const Login: NextPage = () => {
boxShadow="lg"
p={8}>
<Stack spacing={2}>
<Button w="full" variant="outline" leftIcon={<Icon as={GoogleSVG} />}>
<Button w="full" variant="outline" leftIcon={<Icon as={GoogleSVG} />} onClick={() => oauthSignIn("google")}>
<Center>
<Text>Sign in with Google</Text>
</Center>
</Button>
<Button w="full" colorScheme="blackAlpha" color="white" bg="black" variant="solid" leftIcon={<Icon color="white.500" as={GithubSVG} />}>
<Button w="full" colorScheme="blackAlpha" color="white" bg="black" variant="solid" leftIcon={<Icon color="white.500" as={GithubSVG} />} onClick={() => oauthSignIn("github")}>
<Center>
<Text>Sign in with Github</Text>
</Center>

View File

@ -21,6 +21,14 @@ pub struct Config {
pub smtp_config: Option<SMTPConfig>,
pub invite_only: bool,
pub run_migrations: bool,
pub google_oauth: Option<OAuthConfig>,
pub github_oauth: Option<OAuthConfig>,
}
#[derive(Clone)]
pub struct OAuthConfig {
pub client_id: String,
pub client_secret: String,
}
#[derive(Clone)]
@ -100,6 +108,24 @@ impl Config {
false => None,
}
},
google_oauth: {
match get_env_or("GOOGLE_OAUTH_ENABLED", false) {
true => Some(OAuthConfig {
client_id: get_env("GOOGLE_CLIENT_ID"),
client_secret: get_env("GOOGLE_CLIENT_SECRET"),
}),
false => None,
}
},
github_oauth: {
match get_env_or("GITHUB_OAUTH_ENABLED", false) {
true => Some(OAuthConfig {
client_id: get_env("GITHUB_CLIENT_ID"),
client_secret: get_env("GITHUB_CLIENT_SECRET"),
}),
false => None,
}
},
}
}
}

View File

@ -7,7 +7,6 @@ use crate::{
registration_key::RegistrationKeyService, user::UserService,
},
};
use actix_http::Uri;
use actix_multipart_extract::MultipartConfig;
use clap::Parser;
use colored::*;
@ -127,8 +126,11 @@ async fn main() -> std::io::Result<()> {
let auth_service = Data::new(AuthService::new(
user_service.clone().into_inner(),
application_service_container.clone(),
config.api_url.parse::<Uri>().unwrap(),
config.jwt_key,
&config.api_url,
&config.jwt_key,
&config.client_url,
config.google_oauth,
config.github_oauth,
));
// Application service.

View File

@ -6,3 +6,10 @@ pub struct BasicAuthForm {
pub auth: String,
pub password: String,
}
/// OAuth redirect request parameters.
#[derive(Deserialize, ToSchema)]
pub struct AuthRequest {
pub code: String,
pub state: String,
}

View File

@ -1,15 +1,24 @@
use crate::{
models::{auth::BasicAuthForm, TokenResponse},
services::{auth::AuthService, ToResponse},
models::{auth::BasicAuthForm, AuthRequest, TokenResponse},
services::{
auth::{AuthService, OAuthProvider},
ServiceError, ToResponse,
},
};
use actix_web::{http::StatusCode, post, web, Responder, Scope};
use actix_http::header;
use actix_web::{get, http::StatusCode, post, web, HttpResponse, Responder, Scope};
pub fn get_routes() -> Scope {
web::scope("/auth").service(basic)
web::scope("/auth")
.service(basic)
.service(google_login)
.service(google_auth)
.service(github_login)
.service(github_auth)
}
/// Login with email and password
/// Login with email and password.
#[utoipa::path(
context_path = "/api/auth",
tag = "authentication",
@ -26,3 +35,97 @@ async fn basic(service: web::Data<AuthService>, form: web::Json<BasicAuthForm>)
.await
.to_response::<TokenResponse>(StatusCode::OK)
}
macro_rules! define_oauth_route_auth {
($auth_service:expr, $auth_params:expr, $variant:expr) => {
HttpResponse::Found()
.append_header((
header::LOCATION,
match $auth_service
.oauth_authenticate($variant, $auth_params)
.await
{
// TODO: Allow frontend to be disabled and instead just return the raw token response.
// Frontend should get this parameter on load and put the token into headers.
Ok(v) => format!("{}/user/login?token={}", $auth_service.client_url, v.token),
Err(e) => format!(
"{}/user/login?fail={}",
$auth_service.client_url,
match e {
ServiceError::ServerError(_) | ServiceError::DbErr(_) =>
return e.to_response(),
_ => e.to_string(),
}
),
},
))
.finish()
}
}
macro_rules! define_oauth_route_login {
($auth_service:expr, $variant:expr) => {
match $auth_service.oauth_login($variant) {
Ok(v) => HttpResponse::Found()
.append_header((header::LOCATION, v.to_string()))
.finish(),
Err(e) => e.to_response(),
}
};
}
/// Initiate Google OAuth authentication.
/// This redirects to google to authenticate the user.
#[utoipa::path(
context_path = "/api/auth",
tag = "authentication",
responses((status = 200, body = TokenResponse)),
)]
#[get("/google/login")]
async fn google_login(service: web::Data<AuthService>) -> impl Responder {
define_oauth_route_login!(service, OAuthProvider::Google)
}
/// Google OAuth redirect URL.
/// This redirects to frontend with token if a valid user was found with the parameters.
#[utoipa::path(
context_path = "/api/auth",
tag = "authentication",
responses((status = 200, body = TokenResponse)),
request_body(content = AuthRequest)
)]
#[get("/google/auth")]
async fn google_auth(
service: web::Data<AuthService>,
params: web::Query<AuthRequest>,
) -> impl Responder {
define_oauth_route_auth!(service, &params, OAuthProvider::Google)
}
/// Initiate Github OAuth authentication.
/// This redirects to google to authenticate the user.
#[utoipa::path(
context_path = "/api/auth",
tag = "authentication",
responses((status = 200, body = TokenResponse)),
)]
#[get("/github/login")]
async fn github_login(service: web::Data<AuthService>) -> impl Responder {
define_oauth_route_login!(service, OAuthProvider::Github)
}
/// Github OAuth redirect URL.
/// This redirects to frontend with token if a valid user was found with the parameters.
#[utoipa::path(
context_path = "/api/auth",
tag = "authentication",
responses((status = 200, body = TokenResponse)),
request_body(content = AuthRequest)
)]
#[get("/github/auth")]
async fn github_auth(
service: web::Data<AuthService>,
params: web::Query<AuthRequest>,
) -> impl Responder {
define_oauth_route_auth!(service, &params, OAuthProvider::Github)
}

View File

@ -1,14 +1,25 @@
use actix_http::Uri;
use anyhow::anyhow;
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use chrono::Utc;
use derive_more::Display;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use rand::rngs::OsRng;
use sea_orm::{ColumnTrait, Condition};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use crate::{
config::OAuthConfig,
database::entity::{applications, users},
models::TokenResponse,
models::AuthRequest,
};
use super::{
@ -16,6 +27,123 @@ use super::{
ServiceResult,
};
/// All OAuth providers.
#[derive(Debug, Display)]
pub enum OAuthProvider {
Google,
Github,
}
struct EmailRequest {
request_endpoint: RequestEndpoint,
/// Email retriever using result data.
email_retrieve: fn(serde_json::Value) -> Option<String>,
}
/// User request endpoint configuration for getting email.
enum RequestEndpoint {
/// Format URL with token as argument.
FormatUrl(fn(&str) -> String),
/// Automatically use token in `Authorization` header.
Bearer(String),
}
struct OAuthClient {
http_client: reqwest::Client,
client: BasicClient,
scopes: Vec<Scope>,
email_request: EmailRequest,
}
impl OAuthClient {
pub fn new(
oauth_config: OAuthConfig,
auth_url: &str,
token_url: &str,
redirect_url: &str,
scopes: &[&str],
email_request: EmailRequest,
) -> Self {
let auth_url = AuthUrl::new(auth_url.to_string()).unwrap();
let token_url = TokenUrl::new(token_url.to_string()).unwrap();
Self {
http_client: reqwest::Client::builder()
.user_agent("Backpack")
.build()
.unwrap(),
client: BasicClient::new(
ClientId::new(oauth_config.client_id),
Some(ClientSecret::new(oauth_config.client_secret)),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url.into()).expect("Invalid redirect URL")),
scopes: scopes
.to_vec()
.iter()
.map(|f| Scope::new(f.to_string()))
.collect(),
email_request,
}
}
/// Initiate an oauth login.
/// Start the login session by redirecting the user to the provider URL.
fn login(&self) -> ServiceResult<oauth2::url::Url> {
// TODO: PKCE verification.
// Generate the authorization URL to which we'll redirect the user.
let (authorize_url, _csrf_state) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(self.scopes.clone())
.url();
Ok(authorize_url)
}
/// Use auth params provided by the provider to get a JWT token.
/// Returns user email.
async fn auth(&self, oauth_request: &AuthRequest) -> ServiceResult<String> {
let code = AuthorizationCode::new(oauth_request.code.clone());
// Exchange the code with a token.
let token = match self
.client
.exchange_code(code)
.request_async(async_http_client)
.await
{
Ok(v) => v,
Err(e) => return Err(ServiceError::ServerError(e.into())),
};
let response = match &self.email_request.request_endpoint {
RequestEndpoint::FormatUrl(formatter) => self
.http_client
.get(formatter(token.access_token().secret())),
RequestEndpoint::Bearer(url) => self
.http_client
.get(url)
.bearer_auth(token.access_token().secret()),
}
.send()
.await
.map_err(|e| ServiceError::ServerError(e.into()))?
.json::<serde_json::Value>()
.await
.map_err(|e| ServiceError::ServerError(e.into()))?;
match (self.email_request.email_retrieve)(response) {
Some(v) => Ok(v),
None => Err(ServiceError::ServerError(anyhow!(
"OAuth provider was misconfigured."
))),
}
}
}
/// Handles authentication and validation.
pub struct AuthService {
user_service: Arc<UserService>,
@ -23,20 +151,90 @@ pub struct AuthService {
application_service: Arc<RwLock<Option<Arc<ApplicationService>>>>,
api_url: Uri,
jwt_key: String,
/// Root URL of client.
pub client_url: String,
google_oauth_client: Option<OAuthClient>,
github_oauth_client: Option<OAuthClient>,
}
impl AuthService {
pub fn new(
user_service: Arc<UserService>,
application_service: Arc<RwLock<Option<Arc<ApplicationService>>>>,
api_url: Uri,
jwt_key: String,
api_url: &str,
jwt_key: &str,
client_url: &str,
google_oauth: Option<OAuthConfig>,
github_oauth: Option<OAuthConfig>,
) -> Self {
Self {
user_service,
application_service,
api_url,
jwt_key,
api_url: api_url.parse::<Uri>().unwrap(),
jwt_key: jwt_key.into(),
client_url: client_url.into(),
google_oauth_client: match google_oauth {
Some(config) => Some(OAuthClient::new(
config,
"https://accounts.google.com/o/oauth2/v2/auth",
"https://www.googleapis.com/oauth2/v3/token",
&format!("{}/api/auth/google/auth", api_url),
&[&"https://www.googleapis.com/auth/userinfo.email"],
EmailRequest {
request_endpoint: RequestEndpoint::FormatUrl(|token| {
format!(
"https://www.googleapis.com/oauth2/v1/userinfo?access_token={}",
token
)
}),
email_retrieve: |res| {
if let serde_json::Value::Object(obj) = res {
if let Some(serde_json::Value::String(email)) = obj.get("email") {
Some(email.to_owned())
} else {
None
}
} else {
None
}
},
},
)),
None => None,
},
github_oauth_client: match github_oauth {
Some(config) => Some(OAuthClient::new(
config,
"https://github.com/login/oauth/authorize",
"https://github.com/login/oauth/access_token",
&format!("{}/api/auth/github/auth", api_url),
&[&"user"],
EmailRequest {
request_endpoint: RequestEndpoint::Bearer(
"https://api.github.com/user/emails".into(),
),
email_retrieve: |res| {
#[derive(Deserialize)]
struct EmailResponse {
primary: bool,
email: String,
}
if let Ok(emails) = serde_json::from_value::<Vec<EmailResponse>>(res) {
for email in emails {
if email.primary {
return Some(email.email);
}
}
}
None
},
},
)),
None => None,
},
}
}
@ -46,7 +244,11 @@ impl AuthService {
/// * `password` - User password
///
/// Returns JWT token response.
pub async fn password_auth(&self, auth: &str, password: &str) -> ServiceResult<TokenResponse> {
pub async fn password_auth(
&self,
auth: &str,
password: &str,
) -> ServiceResult<crate::models::TokenResponse> {
let user = self.user_service.get_by_identifier(auth).await?;
validate_password(&user.password, password)?;
@ -128,7 +330,7 @@ impl AuthService {
&self,
user_id: &str,
application_id: Option<String>,
) -> ServiceResult<TokenResponse> {
) -> ServiceResult<crate::models::TokenResponse> {
let expire_time = (Utc::now() + chrono::Duration::weeks(1)).timestamp();
let claims = JwtClaims {
@ -150,7 +352,45 @@ impl AuthService {
)
.map_err(|e| ServiceError::ServerError(e.into()))?;
Ok(TokenResponse { token: jwt })
Ok(crate::models::TokenResponse { token: jwt })
}
/// Initiate an oauth login.
/// Start the login session by redirecting the user to the provider URL.
pub fn oauth_login(&self, provider_type: OAuthProvider) -> ServiceResult<oauth2::url::Url> {
self.get_client(provider_type)?.login()
}
/// Use auth params provided by the provider to get a JWT token.
/// Returns new JWT key.
pub async fn oauth_authenticate(
&self,
provider_type: OAuthProvider,
auth_request: &AuthRequest,
) -> ServiceResult<crate::models::TokenResponse> {
let email = self.get_client(provider_type)?.auth(auth_request).await?;
let user = self
.user_service
.by_condition(Condition::all().add(users::Column::Email.eq(email)))
.await?;
self.new_jwt(&user.id, None)
}
fn get_client(&self, provider_type: OAuthProvider) -> ServiceResult<&OAuthClient> {
let provider = match provider_type {
OAuthProvider::Google => &self.google_oauth_client,
OAuthProvider::Github => &self.github_oauth_client,
};
match provider {
Some(v) => Ok(v),
None => Err(ServiceError::InvalidData(format!(
"{} OAuth provider was not enabled for this service.",
provider_type.to_string()
))),
}
}
}