use crate::config::{Data, FederationConfig, FederationMiddleware}; use axum::{async_trait, body::Body, extract::FromRequestParts, http::Request, response::Response}; use http::{request::Parts, StatusCode}; use std::task::{Context, Poll}; use tower::{Layer, Service}; impl Layer for FederationMiddleware { type Service = FederationService; fn layer(&self, inner: S) -> Self::Service { FederationService { inner, config: self.0.clone(), } } } /// Passes [FederationConfig] to HTTP handlers, converting it to [Data] in the process #[doc(hidden)] #[derive(Clone)] pub struct FederationService { inner: S, config: FederationConfig, } impl Service> for FederationService where S: Service, Response = Response> + Send + 'static, S::Future: Send + 'static, T: Clone + Send + Sync + 'static, { type Response = S::Response; type Error = S::Error; type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut request: Request) -> Self::Future { request.extensions_mut().insert(self.config.clone()); self.inner.call(request) } } #[async_trait] impl FromRequestParts for Data where S: Send + Sync, T: Send + Sync, { type Rejection = (StatusCode, &'static str); async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { match parts.extensions.get::>() { Some(c) => Ok(c.to_request_data()), None => Err(( StatusCode::INTERNAL_SERVER_ERROR, "Missing extension, did you register FederationMiddleware?", )), } } }