Add native support for the `X-Clacks-Overhead` header (#412)

* add tower middleware

* factor out logic into function

* integrate with kitsune

* add tests
This commit is contained in:
aumetra 2023-11-08 19:31:31 +01:00 committed by GitHub
parent 084b34bd4a
commit 6aaa2c25b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 245 additions and 28 deletions

View File

@ -13,4 +13,4 @@
"rust-analyzer.server.extraEnv": {
"CARGO_TARGET_DIR": "target-analyzer"
}
}
}

46
Cargo.lock generated
View File

@ -1207,7 +1207,7 @@ dependencies = [
"criterion-plot",
"futures",
"is-terminal",
"itertools",
"itertools 0.10.5",
"num-traits 0.2.17",
"once_cell",
"oorandom",
@ -1227,7 +1227,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools",
"itertools 0.10.5",
]
[[package]]
@ -2192,9 +2192,9 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.10"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
dependencies = [
"cfg-if",
"js-sys",
@ -2713,6 +2713,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.9"
@ -2808,6 +2817,7 @@ dependencies = [
"tokio-util",
"tower",
"tower-http",
"tower-x-clacks-overhead",
"tracing",
"typed-builder",
"url",
@ -3406,9 +3416,9 @@ dependencies = [
[[package]]
name = "linux-raw-sys"
version = "0.4.10"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f"
checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829"
[[package]]
name = "lock_api"
@ -3988,7 +3998,7 @@ dependencies = [
"ed25519-dalek",
"hmac",
"http",
"itertools",
"itertools 0.10.5",
"log",
"oauth2",
"p256",
@ -4670,7 +4680,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
dependencies = [
"anyhow",
"itertools",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn 1.0.109",
@ -6202,6 +6212,19 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tower-x-clacks-overhead"
version = "0.0.1-pre.4"
dependencies = [
"futures",
"http",
"itertools 0.11.0",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tracing"
version = "0.1.40"
@ -6281,8 +6304,9 @@ dependencies = [
[[package]]
name = "tracing-opentelemetry"
version = "0.21.0"
source = "git+https://github.com/tokio-rs/tracing-opentelemetry.git?rev=2156c236db4c488eb6fce08bdd710dde290f5f52#2156c236db4c488eb6fce08bdd710dde290f5f52"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c67ac25c5407e7b961fafc6f7e9aa5958fd297aada2d20fa2ae1737357e55596"
dependencies = [
"js-sys",
"once_cell",
@ -6973,7 +6997,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "103fa851fff70ea29af380e87c25c48ff7faac5c530c70bd0e65366d4e0c94e4"
dependencies = [
"fancy-regex",
"itertools",
"itertools 0.10.5",
"js-sys",
"lazy_static",
"quick-error",

View File

@ -1,8 +1,6 @@
[profile.dev.package.backtrace]
opt-level = 3
[profile.dev.package.num-bigint-dig]
opt-level = 3
[profile.dev.package]
backtrace = { opt-level = 3 }
num-bigint-dig = { opt-level = 3 }
[profile.release]
codegen-units = 1
@ -43,6 +41,7 @@ members = [
"lib/masto-id-convert",
"lib/post-process",
"lib/speedy-uuid",
"lib/tower-x-clacks-overhead",
]
resolver = "2"
@ -77,5 +76,3 @@ pr-run-mode = "plan"
isolang = { git = "https://github.com/humenda/isolang-rs.git", rev = "f015b8cce82b6168303c84543fdd25f57005141c" }
# Patch `redis` for up-to-date `ahash`
redis = { git = "https://github.com/aumetra/redis-rs.git", rev = "3c4ee09d432a69e1d87d66dcba14c519467c9b81" }
# Patch `tracing-opentelemetry` for up-to-date `opentelemetry`
tracing-opentelemetry = { git = "https://github.com/tokio-rs/tracing-opentelemetry.git", rev = "2156c236db4c488eb6fce08bdd710dde290f5f52" }

View File

@ -153,6 +153,9 @@ type = "in-process"
#
# This configuration changes the general behaviour that you'd mostly attribute to the underlying HTTP server
[server]
# Values for the `X-Clacks-Overhead` header
# You can use this as a sort of silent memorial
# clacks-overhead = ["Natalie Nguyen", "John \"Soap\" MacTavish"]
# Path the frontend you want to use is located at
# Note: This path is not canonicalized and does not support Unix shortcuts such as the tilde (~)
frontend-dir = "./kitsune-fe/dist"

View File

@ -162,6 +162,8 @@ pub enum SearchConfiguration {
#[derive(Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct ServerConfiguration {
#[serde(default)]
pub clacks_overhead: Vec<SmolStr>,
pub frontend_dir: SmolStr,
pub max_upload_size: usize,
pub media_proxy_enabled: bool,

View File

@ -30,5 +30,5 @@ opentelemetry_sdk = { version = "0.21.0", default-features = false, features = [
] }
tracing = "0.1.40"
tracing-error = "0.2.0"
tracing-opentelemetry = { version = "0.21.0", default-features = false }
tracing-opentelemetry = { version = "0.22.0", default-features = false }
tracing-subscriber = "0.3.17"

View File

@ -69,6 +69,7 @@ thiserror = "1.0.50"
time = "0.3.30"
tokio = { version = "1.33.0", features = ["full"] }
tokio-util = { version = "0.7.10", features = ["compat"] }
tower-x-clacks-overhead = { path = "../lib/tower-x-clacks-overhead" }
tower-http = { version = "0.4.4", features = [
"catch-panic",
"cors",

View File

@ -4,6 +4,7 @@ use self::{
};
use crate::state::Zustand;
use axum::{extract::DefaultBodyLimit, Router};
use eyre::Context;
use kitsune_config::ServerConfiguration;
use std::time::Duration;
use tower_http::{
@ -13,6 +14,7 @@ use tower_http::{
timeout::TimeoutLayer,
trace::TraceLayer,
};
use tower_x_clacks_overhead::XClacksOverheadLayer;
use utoipa_swagger_ui::SwaggerUi;
#[cfg(feature = "graphql-api")]
@ -27,7 +29,7 @@ mod util;
pub mod extractor;
pub fn create_router(state: Zustand, server_config: &ServerConfiguration) -> Router {
pub fn create_router(state: Zustand, server_config: &ServerConfiguration) -> eyre::Result<Router> {
let frontend_dir = &server_config.frontend_dir;
let frontend_index_path = {
let mut tmp = frontend_dir.to_string();
@ -71,7 +73,15 @@ pub fn create_router(state: Zustand, server_config: &ServerConfiguration) -> Rou
ServeDir::new(frontend_dir.as_str()).fallback(ServeFile::new(frontend_index_path)),
);
router
if !server_config.clacks_overhead.is_empty() {
let clacks_overhead_layer =
XClacksOverheadLayer::new(server_config.clacks_overhead.iter().map(AsRef::as_ref))
.context("Invalid clacks overhead values")?;
router = router.layer(clacks_overhead_layer);
}
Ok(router
.layer(CatchPanicLayer::new())
.layer(CorsLayer::permissive())
.layer(DefaultBodyLimit::max(server_config.max_upload_size))
@ -79,14 +89,15 @@ pub fn create_router(state: Zustand, server_config: &ServerConfiguration) -> Rou
server_config.request_timeout_secs,
)))
.layer(TraceLayer::new_for_http())
.with_state(state)
.with_state(state))
}
#[instrument(skip_all, fields(port = %server_config.port))]
pub async fn run(state: Zustand, server_config: ServerConfiguration) {
let router = create_router(state, &server_config);
pub async fn run(state: Zustand, server_config: ServerConfiguration) -> eyre::Result<()> {
let router = create_router(state, &server_config)?;
axum::Server::bind(&([0, 0, 0, 0], server_config.port).into())
.serve(router.into_make_service())
.await
.unwrap();
.await?;
Ok(())
}

View File

@ -1,6 +1,9 @@
#![forbid(rust_2018_idioms)]
#![warn(clippy::all, clippy::pedantic)]
#[macro_use]
extern crate tracing;
use clap::Parser;
use color_eyre::{config::HookBuilder, Help};
use eyre::Context;
@ -97,7 +100,15 @@ async fn boot() -> eyre::Result<()> {
.context("Failed to connect to the Redis instance for the job scheduler")?;
let state = kitsune::initialise_state(&config, conn, job_queue.clone()).await?;
tokio::spawn(kitsune::http::run(state.clone(), config.server.clone()));
tokio::spawn({
let server_fut = kitsune::http::run(state.clone(), config.server.clone());
async move {
if let Err(error) = server_fut.await {
error!(?error, "failed to run http server");
}
}
});
tokio::spawn(kitsune_job_runner::run_dispatcher(
job_queue,
state.core.clone(),

View File

@ -0,0 +1,17 @@
[package]
name = "tower-x-clacks-overhead"
edition.workspace = true
version.workspace = true
[dependencies]
http = "0.2.9"
itertools = { version = "0.11.0", default-features = false }
pin-project-lite = "0.2.13"
tower-layer = "0.3.2"
tower-service = "0.3.2"
[dev-dependencies]
futures = { version = "0.3.29", default-features = false, features = [
"executor",
] }
tower = { version = "0.4.13", default-features = false, features = ["util"] }

View File

@ -0,0 +1,151 @@
use http::{header::InvalidHeaderValue, HeaderName, HeaderValue, Response};
use itertools::Itertools;
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{self, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
static HEADER_NAME: HeaderName = HeaderName::from_static("x-clacks-overhead");
#[inline]
fn build_names_value<'a, I>(names: I) -> Result<Arc<HeaderValue>, InvalidHeaderValue>
where
I: IntoIterator<Item = &'a str>,
{
let names = format!(
"GNU {}",
Itertools::intersperse(names.into_iter(), ", ").collect::<String>()
)
.parse()?;
Ok(Arc::new(names))
}
pin_project! {
pub struct XClacksOverheadFuture<F> {
#[pin]
future: F,
names: Arc<HeaderValue>,
}
}
impl<F, B, E> Future for XClacksOverheadFuture<F>
where
F: Future<Output = Result<Response<B>, E>>,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.future.poll(cx).map_ok(|mut response| {
response
.headers_mut()
.insert(HEADER_NAME.clone(), (**this.names).clone());
response
})
}
}
#[derive(Clone)]
pub struct XClacksOverheadService<S> {
inner: S,
names: Arc<HeaderValue>,
}
impl<S> XClacksOverheadService<S> {
pub fn new<'a, I>(inner: S, names: I) -> Result<Self, InvalidHeaderValue>
where
I: IntoIterator<Item = &'a str>,
{
Ok(Self {
inner,
names: build_names_value(names)?,
})
}
}
impl<S, Request, ResBody> Service<Request> for XClacksOverheadService<S>
where
S: Service<Request, Response = Response<ResBody>>,
{
type Response = Response<ResBody>;
type Error = S::Error;
type Future = XClacksOverheadFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
XClacksOverheadFuture {
future: self.inner.call(req),
names: Arc::clone(&self.names),
}
}
}
#[derive(Clone)]
pub struct XClacksOverheadLayer {
names: Arc<HeaderValue>,
}
impl XClacksOverheadLayer {
pub fn new<'a, I>(names: I) -> Result<Self, InvalidHeaderValue>
where
I: IntoIterator<Item = &'a str>,
{
Ok(Self {
names: build_names_value(names)?,
})
}
}
impl<S> Layer<S> for XClacksOverheadLayer {
type Service = XClacksOverheadService<S>;
fn layer(&self, inner: S) -> Self::Service {
XClacksOverheadService {
inner,
names: Arc::clone(&self.names),
}
}
}
#[cfg(test)]
mod test {
use crate::{XClacksOverheadLayer, HEADER_NAME};
use http::{Request, Response};
use std::convert::Infallible;
use tower::{service_fn, ServiceExt};
use tower_layer::Layer;
use tower_service::Service;
#[test]
fn add_header() {
let mut service = XClacksOverheadLayer::new(["Johnny"])
.unwrap()
.layer(service_fn(|_req: Request<()>| async move {
Ok::<_, Infallible>(Response::new(()))
}));
let response = futures::executor::block_on(async move {
service
.ready()
.await
.unwrap()
.call(Request::new(()))
.await
.unwrap()
});
let clacks_overhead = response.headers().get(&HEADER_NAME).unwrap();
assert_eq!(clacks_overhead.as_bytes(), b"GNU Johnny");
}
}