diff --git a/src/plugin_middleware.rs b/src/plugin_middleware.rs index 0a9e58c46..29e599eb2 100644 --- a/src/plugin_middleware.rs +++ b/src/plugin_middleware.rs @@ -1,4 +1,4 @@ -use actix_http::h1::Payload; +use actix_http::{body::BoxBody, h1::Payload}; use actix_web::{ body::MessageBody, dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, @@ -22,13 +22,12 @@ impl PluginMiddleware { PluginMiddleware {} } } -impl Transform for PluginMiddleware +impl Transform for PluginMiddleware where - S: Service, Error = Error> + 'static, + S: Service, Error = Error> + 'static, S::Future: 'static, - B: MessageBody + 'static, { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Transform = SessionService; type InitError = (); @@ -45,13 +44,12 @@ pub struct SessionService { service: Rc, } -impl Service for SessionService +impl Service for SessionService where - S: Service, Error = Error> + 'static, + S: Service, Error = Error> + 'static, S::Future: 'static, - B: 'static, { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = Error; type Future = LocalBoxFuture<'static, Result>; @@ -61,27 +59,39 @@ where let svc = self.service.clone(); Box::pin(async move { - let method = service_req.method(); + let method = service_req.method().clone(); let path = service_req.path().replace("/api/v3/", "").replace("/", "_"); // TODO: naming can be a bit silly, `POST /api/v3/post` becomes `api_before_post_post` - let plugin_hook = format!("api_before_{method}_{path}").to_lowercase(); + let before_plugin_hook = format!("api_before_{method}_{path}").to_lowercase(); - info!("Calling plugin hook {}", &plugin_hook); + info!("Calling plugin hook {}", &before_plugin_hook); if let Some(mut plugins) = load_plugins()? { - if plugins.function_exists(&plugin_hook) { + if plugins.function_exists(&before_plugin_hook) { let payload = service_req.extract::().await?; let mut json: Value = serde_json::from_slice(&payload.to_vec())?; - call_plugin(plugins, &plugin_hook, &mut json)?; + call_plugin(plugins, &before_plugin_hook, &mut json)?; let (_, mut new_payload) = Payload::create(true); - new_payload.unread_data(Bytes::from(serde_json::to_vec_pretty(&json)?)); + new_payload.unread_data(Bytes::from(serde_json::to_vec(&json)?)); service_req.set_payload(new_payload.into()); } } - let res = svc.call(service_req).await?; + let mut res = svc.call(service_req).await?; // TODO: add after hook + let after_plugin_hook = format!("api_after_{method}_{path}").to_lowercase(); + info!("Calling plugin hook {}", &after_plugin_hook); + if let Some(mut plugins) = load_plugins()? { + if plugins.function_exists(&before_plugin_hook) { + res = res.map_body(|_, body| { + let mut json: Value = + serde_json::from_slice(&body.try_into_bytes().unwrap().to_vec()).unwrap(); + call_plugin(plugins, &after_plugin_hook, &mut json).unwrap(); + BoxBody::new(Bytes::from(serde_json::to_vec(&json).unwrap())) + }); + } + } Ok(res) })